diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py
index 84d2b2fdb..8e150fef1 100644
--- a/colossalai/shardformer/policies/qwen2.py
+++ b/colossalai/shardformer/policies/qwen2.py
@@ -13,6 +13,7 @@ from colossalai.shardformer.layer import (
     PaddingEmbedding,
     RMSNorm,
     VocabParallelEmbedding1D,
+    VocabParallelLMHead1D,
 )
 
 from ..modeling.qwen2 import (
@@ -446,7 +447,16 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
                             suffix="lm_head",
                             target_module=LinearWithGradAccum,
                             kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
-                        )
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="lm_head",
+                            target_module=VocabParallelLMHead1D,
+                            kwargs={
+                                "gather_output": not self.shard_config.parallel_output,
+                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
+                            },
+                        ),
                     ],
                     method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
                 )
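
For context, here is a minimal sketch (not part of this PR) of how the new `VocabParallelLMHead1D` replacement would be exercised through ShardFormer. The launch call and `ShardConfig` fields follow recent ColossalAI versions; the tiny `Qwen2Config` values and the `torchrun` invocation are illustrative assumptions, not taken from this diff.

```python
# Hypothetical demo, e.g. launched with `torchrun --nproc_per_node=2 demo.py`.
# Model dimensions are deliberately tiny and purely illustrative.
import torch.distributed as dist
from transformers import Qwen2Config, Qwen2ForCausalLM

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch()

# vocab_size is kept a multiple of make_vocab_size_divisible_by (default 64),
# so the vocab-parallel head needs no padding in this sketch.
model = Qwen2ForCausalLM(
    Qwen2Config(
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        vocab_size=32000,
    )
)

shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,
    enable_tensor_parallelism=True,
    # parallel_output=True keeps logits vocab-sharded, i.e. gather_output=False
    # in the replacement added by this diff.
    parallel_output=True,
)
sharded_model, _ = ShardFormer(shard_config).optimize(model)

# After sharding, lm_head is replaced per Qwen2ForCausalLMPolicy; which target
# module applies depends on the policy branch taken for this shard config.
print(type(sharded_model.lm_head).__name__)
```

Note the design choice visible in the hunk: `gather_output` is set to `not self.shard_config.parallel_output`, and the forward is swapped for `get_lm_forward_with_dist_cross_entropy(self.shard_config)`. With `parallel_output=True`, the loss is computed directly on the vocab-sharded logits, so the full-vocab logits never need to be all-gathered across tensor-parallel ranks.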