From fdcc3691fa963c64c3375264b4a2d8fa9018e6eb Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 29 Apr 2025 18:54:33 +0800 Subject: [PATCH 1/4] fix --- colossalai/shardformer/modeling/gpt2.py | 36 +++++++-- colossalai/shardformer/policies/gpt2.py | 32 ++++---- tests/kit/model_zoo/transformers/gpt.py | 80 +++++++++---------- .../test_model/test_shard_gpt2.py | 14 ++-- 4 files changed, 94 insertions(+), 68 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index d550484da..09f6e005c 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -39,6 +39,7 @@ def _get_attention_mask( attention_mask: Optional[torch.FloatTensor], encoder_hidden_states: Optional[torch.Tensor], encoder_attention_mask: Optional[torch.FloatTensor], + head_mask: Optional[torch.Tensor] = None, ) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]: # Received input is already split for non-first pipeline stages, # but attn mask isn't @@ -48,6 +49,9 @@ def _get_attention_mask( sp_mode = shard_config.sequence_parallelism_mode # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + _use_sdpa = self._attn_implementation == "sdpa" + print("_use_sdpa", _use_sdpa) + from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa if self.config.add_cross_attention and encoder_hidden_states is not None: assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only." encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() @@ -63,7 +67,12 @@ def _get_attention_mask( encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + if _use_sdpa: + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + mask=encoder_attention_mask, dtype=hidden_states.dtype, tgt_len=encoder_hidden_shape[-1] + ) + elif not self._attn_implementation == "flash_attention_2": + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: if shard_config.enable_flash_attention: encoder_attention_mask = {"attention_mask": None} @@ -77,7 +86,6 @@ def _get_attention_mask( if shard_config.enable_flash_attention: if attention_mask is not None: attention_mask = attention_mask.view(batch_size, -1) - attention_mask = ColoAttention.prepare_attn_kwargs( (batch_size, 1, seq_len, seq_len + past_key_values_length), hidden_states.dtype, @@ -85,6 +93,15 @@ def _get_attention_mask( attention_mask, is_causal=True, ) + elif self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif _use_sdpa: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask=attention_mask, + input_shape=(batch_size, hidden_states[-1]), + inputs_embeds=hidden_states, + past_key_values_length=past_key_values_length, + ) elif attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") @@ -207,8 +224,10 @@ class GPT2PipelineForwards: attention_mask, encoder_hidden_states, encoder_attention_mask, + head_mask ) + if self.gradient_checkpointing and 
self.training: if use_cache: logger.warning_once( @@ -835,9 +854,12 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) attention_mask = encoder_attention_mask else: query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) + + shape_q = (*query.shape[:-1], -1, self.head_dim) + shape_kv = (*key.shape[:-1], -1, self.head_dim) + query = query.view(shape_q).transpose(1, 2) + key = key.view(shape_kv).transpose(1, 2) + value = value.view(shape_kv).transpose(1, 2) if layer_past is not None: past_key, past_value = layer_past @@ -871,7 +893,9 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) ) else: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) - attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + + attn_output = attn_output.permute(0, 2, 1, 3).contiguous() + attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous() attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present, None) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index c57d33826..6806260f5 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -38,13 +38,8 @@ class GPT2Policy(Policy): def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model - ATTN_IMPLEMENTATION = { - "eager": GPT2Attention, - } - policy = {} - attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] embedding_cls = None if self.shard_config.enable_tensor_parallelism: @@ -53,6 +48,10 @@ class GPT2Policy(Policy): if self.tie_weight: embedding_cls = col_nn.PaddingEmbedding + + + print("embedding_cls", embedding_cls) + if self.shard_config.enable_fused_normalization: norm_cls = col_nn.FusedLayerNorm else: @@ -224,6 +223,7 @@ class GPT2Policy(Policy): if embedding_cls is not None: # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by + print("embedding_cls", embedding_cls) self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="wte", @@ -280,7 +280,7 @@ class GPT2Policy(Policy): "forward": get_gpt2_flash_attention_forward(shard_config=self.shard_config), }, policy=policy, - target_key=attn_cls, + target_key=GPT2Attention, ) if not self.shard_config.pipeline_stage_manager and self.shard_config.enable_sequence_parallelism: @@ -394,12 +394,13 @@ class GPT2LMHeadModelPolicy(GPT2Policy): module_policy = super().module_policy() module_policy[GPT2LMHeadModel] = ModulePolicyDescription() if self.shard_config.enable_tensor_parallelism: + print("self.shard_config.enable_tensor_parallelism", self.shard_config.enable_tensor_parallelism) self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="lm_head", target_module=col_nn.VocabParallelLMHead1D, kwargs={ - "gather_output": False, + "gather_output": True, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, }, ), @@ -417,19 +418,20 @@ class GPT2LMHeadModelPolicy(GPT2Policy): target_key=GPT2LMHeadModel, ) - if self.shard_config.parallel_output: - 
self.append_or_create_method_replacement( - description={ - "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config) - }, - policy=module_policy, - target_key=GPT2LMHeadModel, - ) + # if self.shard_config.parallel_output: + # self.append_or_create_method_replacement( + # description={ + # "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config) + # }, + # policy=module_policy, + # target_key=GPT2LMHeadModel, + # ) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( model_cls=GPT2LMHeadModel, new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + shard_config=self.shard_config, policy=module_policy, ) return module_policy diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index f2b139bec..8ebe283f9 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -127,43 +127,43 @@ model_zoo.register( loss_fn=loss_fn_for_gpt2_model, model_attribute=ModelAttribute(has_control_flow=True), ) -model_zoo.register( - name="transformers_gpt_lm", - model_fn=lambda: transformers.GPT2LMHeadModel(config), - data_gen_fn=data_gen_for_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True), -) -model_zoo.register( - name="transformers_gpt_double_heads", - model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), - data_gen_fn=date_gen_for_double_heads, - output_transform_fn=output_transform_fn, - loss_fn=lambda x: x.loss + x.mc_loss, - model_attribute=ModelAttribute(has_control_flow=True), -) -model_zoo.register( - name="transformers_gpt_for_question_answering", - model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), - data_gen_fn=data_gen_for_question_answering, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True), -) -model_zoo.register( - name="transformers_gpt_for_token_classification", - model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification), - data_gen_fn=data_gen_for_token_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True), -) -model_zoo.register( - name="transformers_gpt_for_sequence_classification", - model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification), - data_gen_fn=data_gen_for_sequence_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True), -) +# model_zoo.register( +# name="transformers_gpt_lm", +# model_fn=lambda: transformers.GPT2LMHeadModel(config), +# data_gen_fn=data_gen_for_lm, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) +# model_zoo.register( +# name="transformers_gpt_double_heads", +# model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), +# data_gen_fn=date_gen_for_double_heads, +# output_transform_fn=output_transform_fn, +# loss_fn=lambda x: x.loss + x.mc_loss, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) +# model_zoo.register( +# name="transformers_gpt_for_question_answering", +# model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), +# data_gen_fn=data_gen_for_question_answering, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) +# 
model_zoo.register( +# name="transformers_gpt_for_token_classification", +# model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification), +# data_gen_fn=data_gen_for_token_classification, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) +# model_zoo.register( +# name="transformers_gpt_for_sequence_classification", +# model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification), +# data_gen_fn=data_gen_for_sequence_classification, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 393f7ffca..0334f5c2c 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -157,19 +157,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "ring_attn", "num_microbatches": 1, - "enable_all_optimization": True, + # "enable_all_optimization": True, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, }, { - "tp_size": 4, + "tp_size": 2, "pp_size": 1, "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": True, - "use_lazy_init": True, + "enable_sequence_parallelism": False, + # "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": False, "precision": "fp16", "initial_scale": 1, }, @@ -180,7 +180,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "split_gather", "enable_flash_attention": True, - "use_lazy_init": True, + "use_lazy_init": False, "precision": "fp16", "initial_scale": 1, }, From f8caea7762b0600124caff88def213660157f35f Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 1 May 2025 09:00:13 +0800 Subject: [PATCH 2/4] fix --- colossalai/shardformer/modeling/gpt2.py | 22 +---- colossalai/shardformer/policies/gpt2.py | 20 ++--- tests/kit/model_zoo/transformers/gpt.py | 80 +++++++++---------- .../test_model/test_shard_gpt2.py | 16 ++-- 4 files changed, 61 insertions(+), 77 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 09f6e005c..bb493be3e 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -39,7 +39,6 @@ def _get_attention_mask( attention_mask: Optional[torch.FloatTensor], encoder_hidden_states: Optional[torch.Tensor], encoder_attention_mask: Optional[torch.FloatTensor], - head_mask: Optional[torch.Tensor] = None, ) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]: # Received input is already split for non-first pipeline stages, # but attn mask isn't @@ -49,9 +48,7 @@ def _get_attention_mask( sp_mode = shard_config.sequence_parallelism_mode # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - _use_sdpa = self._attn_implementation == "sdpa" - print("_use_sdpa", _use_sdpa) - from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa 
+ if self.config.add_cross_attention and encoder_hidden_states is not None: assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only." encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() @@ -59,7 +56,7 @@ def _get_attention_mask( encoder_attention_mask = ColoAttention.prepare_attn_kwargs( (encoder_batch_size, 1, seq_len, encoder_sequence_length), dtype=hidden_states.dtype, - dtype2=encoder_hidden_states.dtype, + device=encoder_hidden_states.device, q_padding_mask=attention_mask, kv_padding_mask=encoder_attention_mask, ) @@ -67,12 +64,7 @@ def _get_attention_mask( encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device) - if _use_sdpa: - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - mask=encoder_attention_mask, dtype=hidden_states.dtype, tgt_len=encoder_hidden_shape[-1] - ) - elif not self._attn_implementation == "flash_attention_2": - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: if shard_config.enable_flash_attention: encoder_attention_mask = {"attention_mask": None} @@ -95,13 +87,6 @@ def _get_attention_mask( ) elif self._attn_implementation == "flash_attention_2": attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif _use_sdpa: - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask=attention_mask, - input_shape=(batch_size, hidden_states[-1]), - inputs_embeds=hidden_states, - past_key_values_length=past_key_values_length, - ) elif attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") @@ -224,7 +209,6 @@ class GPT2PipelineForwards: attention_mask, encoder_hidden_states, encoder_attention_mask, - head_mask ) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 6806260f5..d5d97fd2d 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -400,7 +400,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy): suffix="lm_head", target_module=col_nn.VocabParallelLMHead1D, kwargs={ - "gather_output": True, + "gather_output": False, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, }, ), @@ -418,20 +418,20 @@ class GPT2LMHeadModelPolicy(GPT2Policy): target_key=GPT2LMHeadModel, ) - # if self.shard_config.parallel_output: - # self.append_or_create_method_replacement( - # description={ - # "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config) - # }, - # policy=module_policy, - # target_key=GPT2LMHeadModel, - # ) + if self.shard_config.parallel_output: + self.append_or_create_method_replacement( + description={ + "forward": partial(GPT2PipelineForwards.gpt2_lmhead_model_forward, shard_config=self.shard_config) + }, + policy=module_policy, + target_key=GPT2LMHeadModel, + ) if self.pipeline_stage_manager is not None: self.set_pipeline_forward( model_cls=GPT2LMHeadModel, new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, - shard_config=self.shard_config, + # shard_config=self.shard_config, policy=module_policy, ) return module_policy diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 8ebe283f9..f2b139bec 100644 --- 
a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -127,43 +127,43 @@ model_zoo.register( loss_fn=loss_fn_for_gpt2_model, model_attribute=ModelAttribute(has_control_flow=True), ) -# model_zoo.register( -# name="transformers_gpt_lm", -# model_fn=lambda: transformers.GPT2LMHeadModel(config), -# data_gen_fn=data_gen_for_lm, -# output_transform_fn=output_transform_fn, -# loss_fn=loss_fn, -# model_attribute=ModelAttribute(has_control_flow=True), -# ) -# model_zoo.register( -# name="transformers_gpt_double_heads", -# model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), -# data_gen_fn=date_gen_for_double_heads, -# output_transform_fn=output_transform_fn, -# loss_fn=lambda x: x.loss + x.mc_loss, -# model_attribute=ModelAttribute(has_control_flow=True), -# ) -# model_zoo.register( -# name="transformers_gpt_for_question_answering", -# model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), -# data_gen_fn=data_gen_for_question_answering, -# output_transform_fn=output_transform_fn, -# loss_fn=loss_fn, -# model_attribute=ModelAttribute(has_control_flow=True), -# ) -# model_zoo.register( -# name="transformers_gpt_for_token_classification", -# model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification), -# data_gen_fn=data_gen_for_token_classification, -# output_transform_fn=output_transform_fn, -# loss_fn=loss_fn, -# model_attribute=ModelAttribute(has_control_flow=True), -# ) -# model_zoo.register( -# name="transformers_gpt_for_sequence_classification", -# model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification), -# data_gen_fn=data_gen_for_sequence_classification, -# output_transform_fn=output_transform_fn, -# loss_fn=loss_fn, -# model_attribute=ModelAttribute(has_control_flow=True), -# ) +model_zoo.register( + name="transformers_gpt_lm", + model_fn=lambda: transformers.GPT2LMHeadModel(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_double_heads", + model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), + data_gen_fn=date_gen_for_double_heads, + output_transform_fn=output_transform_fn, + loss_fn=lambda x: x.loss + x.mc_loss, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_for_question_answering", + model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_for_token_classification", + model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_for_sequence_classification", + model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 0334f5c2c..3062eaf40 100644 --- 
a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -157,18 +157,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "ring_attn", "num_microbatches": 1, - # "enable_all_optimization": True, + "enable_all_optimization": True, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, }, { - "tp_size": 2, + "tp_size": 4, "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": False, - # "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": False, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, "use_lazy_init": False, "precision": "fp16", "initial_scale": 1, @@ -238,7 +238,7 @@ def run_gpt2_test(test_config): "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32", "initial_scale": 1, @@ -247,7 +247,7 @@ def run_gpt2_test(test_config): "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": False, + "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp16", "zero_stage": 1, From 840b9f3266e6229a2f688e61031d48854f485aa2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 1 May 2025 01:01:52 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/modeling/gpt2.py | 3 +-- colossalai/shardformer/policies/gpt2.py | 3 --- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index bb493be3e..c83142deb 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -211,7 +211,6 @@ class GPT2PipelineForwards: encoder_attention_mask, ) - if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( @@ -877,7 +876,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) ) else: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) - + attn_output = attn_output.permute(0, 2, 1, 3).contiguous() attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous() attn_output = self.c_proj(attn_output) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index d5d97fd2d..b6370d632 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -40,7 +40,6 @@ class GPT2Policy(Policy): policy = {} - embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = col_nn.VocabParallelEmbedding1D @@ -48,8 +47,6 @@ class GPT2Policy(Policy): if self.tie_weight: embedding_cls = col_nn.PaddingEmbedding - - print("embedding_cls", embedding_cls) if self.shard_config.enable_fused_normalization: From 519c2d0eab6ae96763404fddc96ca6b741dceda6 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 1 May 2025 09:04:24 +0800 Subject: [PATCH 4/4] fix --- colossalai/shardformer/modeling/gpt2.py | 2 -- colossalai/shardformer/policies/gpt2.py | 7 ++++--- tests/test_shardformer/test_model/test_shard_gpt2.py | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) 
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index c83142deb..d7d182762 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -85,8 +85,6 @@ def _get_attention_mask(
             attention_mask,
             is_causal=True,
         )
-    elif self._attn_implementation == "flash_attention_2":
-        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
     elif attention_mask is not None:
         if batch_size <= 0:
             raise ValueError("batch_size has to be defined and > 0")
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index b6370d632..a0c73a5e7 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -47,7 +47,5 @@ class GPT2Policy(Policy):
         if self.tie_weight:
             embedding_cls = col_nn.PaddingEmbedding
 
-        print("embedding_cls", embedding_cls)
-
         if self.shard_config.enable_fused_normalization:
             norm_cls = col_nn.FusedLayerNorm
@@ -220,7 +218,6 @@ class GPT2Policy(Policy):
 
         if embedding_cls is not None:
             # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
-            print("embedding_cls", embedding_cls)
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="wte",
@@ -391,7 +388,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         module_policy = super().module_policy()
         module_policy[GPT2LMHeadModel] = ModulePolicyDescription()
         if self.shard_config.enable_tensor_parallelism:
-            print("self.shard_config.enable_tensor_parallelism", self.shard_config.enable_tensor_parallelism)
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="lm_head",
@@ -428,7 +424,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
             self.set_pipeline_forward(
                 model_cls=GPT2LMHeadModel,
                 new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
-                # shard_config=self.shard_config,
+                shard_config=self.shard_config,
                 policy=module_policy,
             )
         return module_policy
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 3062eaf40..b67c494a6 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -165,11 +165,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 4,
             "pp_size": 1,
-            "num_microbatches": 2,
+            "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "split_gather",
             "enable_flash_attention": True,
-            "use_lazy_init": False,
+            "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
         },
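
For reference, a minimal standalone sketch (not part of the patch; the tensor names and sizes below are invented for the demo) of the view/transpose reshaping that the patched get_gpt2_flash_attention_forward uses in place of the removed GPT2Attention._split_heads / _merge_heads helpers: split (batch, seq, hidden) into (batch, num_heads, seq, head_dim) before attention, then merge back afterwards.

import torch

batch, seq_len, num_heads, head_dim = 2, 8, 4, 16
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

# Split: (batch, seq, hidden) -> (batch, num_heads, seq, head_dim),
# mirroring `query.view(shape_q).transpose(1, 2)` in the patch.
shape_q = (*hidden.shape[:-1], -1, head_dim)
query = hidden.view(shape_q).transpose(1, 2)
assert query.shape == (batch, num_heads, seq_len, head_dim)

# Merge: (batch, num_heads, seq, head_dim) -> (batch, seq, hidden),
# mirroring the permute/contiguous/reshape applied to attn_output.
merged = query.permute(0, 2, 1, 3).contiguous()
merged = merged.reshape(*merged.shape[:-2], -1)
assert torch.equal(merged, hidden)  # round-trips exactly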
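
The mask handling in _get_attention_mask is the part the series reworks most. As background, here is a plain-PyTorch sketch of the kind of 4D additive causal mask with padding (0 = attend, large negative = drop) that helpers such as ColoAttention.prepare_attn_kwargs or invert_attention_mask ultimately produce; this sketch does not call those APIs, and its tensor names and sizes are invented for illustration.

import torch

batch, seq_len = 2, 5
dtype = torch.float16
# 1 = real token, 0 = padding (second sequence has two padded positions)
padding_mask = torch.tensor([[1, 1, 1, 1, 1],
                             [1, 1, 1, 0, 0]])

# Lower-triangular causal pattern, broadcast against the key padding mask.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))       # (S, S)
visible = causal[None, None] & padding_mask[:, None, None, :].bool()      # (B, 1, S, S)

# Additive form expected by attention kernels: 0 where visible, -inf-like elsewhere.
additive = torch.zeros(batch, 1, seq_len, seq_len, dtype=dtype)
additive = additive.masked_fill(~visible, torch.finfo(dtype).min)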