fix tp bug

Author: Tong Li
Date: 2025-03-13 14:52:09 +08:00
parent 704866a240
commit 131eeceb5d


@@ -13,6 +13,7 @@ from colossalai.shardformer.layer import (
     PaddingEmbedding,
     RMSNorm,
     VocabParallelEmbedding1D,
+    VocabParallelLMHead1D,
 )
 from ..modeling.qwen2 import (
@@ -446,7 +447,16 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
                             suffix="lm_head",
                             target_module=LinearWithGradAccum,
                             kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
-                        )
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="lm_head",
+                            target_module=VocabParallelLMHead1D,
+                            kwargs={
+                                "gather_output": not self.shard_config.parallel_output,
+                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
+                            },
+                        ),
                     ],
                     method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
                 )
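
The change restores the tensor-parallel path for the Qwen2 language-model head: besides the existing LinearWithGradAccum replacement, lm_head can now also be replaced with VocabParallelLMHead1D, which shards the vocabulary dimension and gathers the output only when shard_config.parallel_output is disabled. Below is a minimal sketch of the replacement list this hunk assembles; the helper function and its signature are hypothetical, and the import paths follow the module layout used elsewhere in shardformer policies, while the descriptions and kwargs are taken from the diff itself.

# Hypothetical helper mirroring the lm_head replacements built in this hunk.
# Import paths are an assumption based on the surrounding policy file.
from colossalai.shardformer.layer import LinearWithGradAccum, VocabParallelLMHead1D
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription


def build_lm_head_replacements(shard_config, use_zbv: bool):
    """Sketch of the lm_head sub-module replacements after this commit."""
    return [
        # Plain linear head with gradient accumulation (present before the fix).
        SubModuleReplacementDescription(
            suffix="lm_head",
            target_module=LinearWithGradAccum,
            kwargs=dict(fp8_communication=shard_config.fp8_communication, use_zbv=use_zbv),
        ),
        # Vocab-parallel head added by this commit: shards the vocabulary
        # dimension across tensor-parallel ranks and only gathers the output
        # when parallel_output is turned off.
        SubModuleReplacementDescription(
            suffix="lm_head",
            target_module=VocabParallelLMHead1D,
            kwargs={
                "gather_output": not shard_config.parallel_output,
                "make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by,
                "fp8_communication": shard_config.fp8_communication,
            },
        ),
    ]

Keeping both descriptions keyed on the same suffix presumably lets the policy hand the vocab-parallel head to shardformer when tensor parallelism is active, while the LinearWithGradAccum path continues to serve the non-TP / zero-bubble case.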