diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index d550484da..09f6e005c 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -39,6 +39,7 @@ def _get_attention_mask(
     attention_mask: Optional[torch.FloatTensor],
     encoder_hidden_states: Optional[torch.Tensor],
     encoder_attention_mask: Optional[torch.FloatTensor],
+    head_mask: Optional[torch.Tensor] = None,
 ) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
     # Received input is already split for non-first pipeline stages,
     # but attn mask isn't
@@ -48,6 +49,9 @@
     sp_mode = shard_config.sequence_parallelism_mode
     # If a 2D or 3D attention mask is provided for the cross-attention
     # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+    _use_sdpa = self._attn_implementation == "sdpa"
+    print("_use_sdpa", _use_sdpa)
+    from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
     if self.config.add_cross_attention and encoder_hidden_states is not None:
         assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
         encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
@@ -63,7 +67,12 @@
         encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
         if encoder_attention_mask is None:
             encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device)
-        encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        if _use_sdpa:
+            encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                mask=encoder_attention_mask, dtype=hidden_states.dtype, tgt_len=encoder_hidden_shape[-1]
+            )
+        elif not self._attn_implementation == "flash_attention_2":
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
     else:
         if shard_config.enable_flash_attention:
             encoder_attention_mask = {"attention_mask": None}
@@ -77,7 +86,6 @@
     if shard_config.enable_flash_attention:
         if attention_mask is not None:
             attention_mask = attention_mask.view(batch_size, -1)
-
        attention_mask = ColoAttention.prepare_attn_kwargs(
            (batch_size, 1, seq_len, seq_len + past_key_values_length),
            hidden_states.dtype,
@@ -85,6 +93,15 @@
            attention_mask,
            is_causal=True,
        )
+    elif self._attn_implementation == "flash_attention_2":
+        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+    elif _use_sdpa:
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_len),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_key_values_length,
+        )
     elif attention_mask is not None:
         if batch_size <= 0:
             raise ValueError("batch_size has to be defined and > 0")
@@ -207,8 +224,10 @@ class GPT2PipelineForwards:
                 attention_mask,
                 encoder_hidden_states,
                 encoder_attention_mask,
+                head_mask
             )
+
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
@@ -835,9 +854,12 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-        query = self._split_heads(query, self.num_heads, self.head_dim)
-        key = self._split_heads(key, self.num_heads, self.head_dim)
-        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        shape_q = (*query.shape[:-1], -1, self.head_dim)
+        shape_kv = (*key.shape[:-1], -1, self.head_dim)
+        query = query.view(shape_q).transpose(1, 2)
+        key = key.view(shape_kv).transpose(1, 2)
+        value = value.view(shape_kv).transpose(1, 2)
 
        if layer_past is not None:
            past_key, past_value = layer_past
@@ -871,7 +893,9 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
            )
        else:
            attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+
+        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
+        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        outputs = (attn_output, present, None)
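
Note on the reshape in the attention forward above: the view/transpose pairs are meant to reproduce what the removed GPT2Attention._split_heads / _merge_heads helpers did (newer transformers releases no longer provide them). Below is a minimal standalone sanity check of that equivalence; it uses only torch, and all names and sizes are made up for illustration, not taken from the patch.

    import torch

    batch, seq, num_heads, head_dim = 2, 5, 4, 8
    hidden = torch.randn(batch, seq, num_heads * head_dim)

    # old-style split: (B, S, H) -> (B, num_heads, S, head_dim)
    ref = hidden.view(batch, seq, num_heads, head_dim).permute(0, 2, 1, 3)

    # patch-style split: infer the head count with -1, then swap the seq and head dims
    shape_q = (*hidden.shape[:-1], -1, head_dim)
    new = hidden.view(shape_q).transpose(1, 2)
    assert torch.equal(ref, new)

    # patch-style merge: (B, num_heads, S, head_dim) -> (B, S, H)
    merged = new.permute(0, 2, 1, 3).contiguous()
    merged = merged.reshape(*merged.shape[:-2], -1)
    assert torch.equal(merged, hidden)
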
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index c57d33826..6806260f5 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -38,13 +38,8 @@ class GPT2Policy(Policy):
     def module_policy(self):
         from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 
-        ATTN_IMPLEMENTATION = {
-            "eager": GPT2Attention,
-        }
-
         policy = {}
 
-        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
@@ -53,6 +48,10 @@ class GPT2Policy(Policy):
             if self.tie_weight:
                 embedding_cls = col_nn.PaddingEmbedding
+
+
+        print("embedding_cls", embedding_cls)
+
 
         if self.shard_config.enable_fused_normalization:
             norm_cls = col_nn.FusedLayerNorm
         else:
@@ -224,6 +223,7 @@ class GPT2Policy(Policy):
 
         if embedding_cls is not None:
             # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
+            print("embedding_cls", embedding_cls)
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="wte",
@@ -280,7 +280,7 @@ class GPT2Policy(Policy):
                     "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config),
                 },
                 policy=policy,
-                target_key=attn_cls,
+                target_key=GPT2Attention,
             )
 
         if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism:
@@ -394,12 +394,13 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         module_policy = super().module_policy()
         module_policy[GPT2LMHeadModel] = ModulePolicyDescription()
         if self.shard_config.enable_tensor_parallelism:
+            print("self.shard_config.enable_tensor_parallelism", self.shard_config.enable_tensor_parallelism)
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="lm_head",
                     target_module=col_nn.VocabParallelLMHead1D,
                     kwargs={
-                        "gather_output": False,
+                        "gather_output": True,
                         "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
                     },
                 ),
@@ -417,19 +418,20 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
                 target_key=GPT2LMHeadModel,
             )
 
-        if self.shard_config.parallel_output:
-            self.append_or_create_method_replacement(
-                description={
-                    "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
-                },
-                policy=module_policy,
-                target_key=GPT2LMHeadModel,
-            )
+        # if self.shard_config.parallel_output:
+        #     self.append_or_create_method_replacement(
+        #         description={
+        #             "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config)
+        #         },
+        #         policy=module_policy,
+        #         target_key=GPT2LMHeadModel,
+        #     )
 
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
                 model_cls=GPT2LMHeadModel,
                 new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
+                shard_config=self.shard_config,
                 policy=module_policy,
             )
         return module_policy
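
On the gather_output flip for lm_head above: with a vocab-sharded (column-parallel) head, gather_output=False leaves each tensor-parallel rank with logits for only its own vocabulary shard, which is normally consumed by a parallel cross-entropy, while gather_output=True all-gathers the shards so every rank sees logits over the full padded vocabulary. That presumably goes together with commenting out the parallel_output forward replacement just above. A single-process sketch of the shape difference, plain torch only, with illustrative sizes:

    import torch

    vocab, hidden, tp = 50304, 768, 2
    x = torch.randn(3, 16, hidden)                          # (batch, seq, hidden)
    w_shards = torch.randn(vocab, hidden).chunk(tp, dim=0)  # column-parallel split of the lm_head weight

    shard_logits = x @ w_shards[0].t()                      # gather_output=False: (3, 16, vocab // tp) per rank
    full_logits = torch.cat([x @ w.t() for w in w_shards], dim=-1)  # gather_output=True: (3, 16, vocab)
    print(shard_logits.shape, full_logits.shape)
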
diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py
index f2b139bec..8ebe283f9 100644
--- a/tests/kit/model_zoo/transformers/gpt.py
+++ b/tests/kit/model_zoo/transformers/gpt.py
@@ -127,43 +127,43 @@
     loss_fn=loss_fn_for_gpt2_model,
     model_attribute=ModelAttribute(has_control_flow=True),
 )
-model_zoo.register(
-    name="transformers_gpt_lm",
-    model_fn=lambda: transformers.GPT2LMHeadModel(config),
-    data_gen_fn=data_gen_for_lm,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
-model_zoo.register(
-    name="transformers_gpt_double_heads",
-    model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
-    data_gen_fn=date_gen_for_double_heads,
-    output_transform_fn=output_transform_fn,
-    loss_fn=lambda x: x.loss + x.mc_loss,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
-model_zoo.register(
-    name="transformers_gpt_for_question_answering",
-    model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
-    data_gen_fn=data_gen_for_question_answering,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
-model_zoo.register(
-    name="transformers_gpt_for_token_classification",
-    model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
-    data_gen_fn=data_gen_for_token_classification,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
-model_zoo.register(
-    name="transformers_gpt_for_sequence_classification",
-    model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
-    data_gen_fn=data_gen_for_sequence_classification,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
+# model_zoo.register(
+#     name="transformers_gpt_lm",
+#     model_fn=lambda: transformers.GPT2LMHeadModel(config),
+#     data_gen_fn=data_gen_for_lm,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=loss_fn,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+#     name="transformers_gpt_double_heads",
+#     model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
+#     data_gen_fn=date_gen_for_double_heads,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=lambda x: x.loss + x.mc_loss,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+#     name="transformers_gpt_for_question_answering",
+#     model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
+#     data_gen_fn=data_gen_for_question_answering,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=loss_fn,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+#     name="transformers_gpt_for_token_classification",
+#     model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
+#     data_gen_fn=data_gen_for_token_classification,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=loss_fn,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+#     name="transformers_gpt_for_sequence_classification",
+#     model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
+#     data_gen_fn=data_gen_for_sequence_classification,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=loss_fn,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 393f7ffca..0334f5c2c 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -157,19 +157,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "ring_attn",
             "num_microbatches": 1,
-            "enable_all_optimization": True,
+            # "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
         },
         {
-            "tp_size": 4,
+            "tp_size": 2,
             "pp_size": 1,
             "num_microbatches": 1,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": True,
-            "use_lazy_init": True,
+            "enable_sequence_parallelism": False,
+            # "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": False,
             "precision": "fp16",
             "initial_scale": 1,
         },
@@ -180,7 +180,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "split_gather",
             "enable_flash_attention": True,
-            "use_lazy_init": True,
+            "use_lazy_init": False,
             "precision": "fp16",
             "initial_scale": 1,
         },
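
The config dicts edited above are entries of the @parameterize("test_config", [...]) list that drives the GPT-2 shard test, so each dict describes one plugin setup per run. A hedged sketch of the pattern; the runner name and body are simplified for illustration and not copied from the test file:

    from colossalai.testing import parameterize

    @parameterize(
        "test_config",
        [
            {
                "tp_size": 2,
                "pp_size": 1,
                "num_microbatches": 1,
                "enable_sequence_parallelism": False,
                "enable_flash_attention": False,
                "use_lazy_init": False,
                "precision": "fp16",
                "initial_scale": 1,
            },
        ],
    )
    def run_gpt2_test(test_config):
        # the real test builds a hybrid-parallel plugin/booster from this dict
        # and compares sharded vs. unsharded forward/backward results
        ...
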