diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index c83142deb..d7d182762 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -85,8 +85,6 @@ def _get_attention_mask(
             attention_mask,
             is_causal=True,
         )
-    elif self._attn_implementation == "flash_attention_2":
-        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
     elif attention_mask is not None:
         if batch_size <= 0:
             raise ValueError("batch_size has to be defined and > 0")
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index b6370d632..a0c73a5e7 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -47,7 +47,6 @@ class GPT2Policy(Policy):
 
         if self.tie_weight:
             embedding_cls = col_nn.PaddingEmbedding
-        print("embedding_cls", embedding_cls)
 
         if self.shard_config.enable_fused_normalization:
             norm_cls = col_nn.FusedLayerNorm
@@ -220,7 +219,6 @@ class GPT2Policy(Policy):
 
         if embedding_cls is not None:
             # padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
-            print("embedding_cls", embedding_cls)
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="wte",
@@ -391,7 +389,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         module_policy = super().module_policy()
         module_policy[GPT2LMHeadModel] = ModulePolicyDescription()
         if self.shard_config.enable_tensor_parallelism:
-            print("self.shard_config.enable_tensor_parallelism", self.shard_config.enable_tensor_parallelism)
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="lm_head",
@@ -428,7 +425,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
             self.set_pipeline_forward(
                 model_cls=GPT2LMHeadModel,
                 new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
-                # shard_config=self.shard_config,
+                shard_config=self.shard_config,
                 policy=module_policy,
             )
         return module_policy
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 3062eaf40..b67c494a6 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -165,11 +165,11 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 4,
             "pp_size": 1,
-            "num_microbatches": 2,
+            "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "split_gather",
             "enable_flash_attention": True,
-            "use_lazy_init": False,
+            "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
         },