fix tp bug

Tong Li 2025-03-13 14:52:09 +08:00
parent 704866a240
commit 131eeceb5d


@@ -13,6 +13,7 @@ from colossalai.shardformer.layer import (
     PaddingEmbedding,
     RMSNorm,
     VocabParallelEmbedding1D,
+    VocabParallelLMHead1D,
 )
 from ..modeling.qwen2 import (
@@ -446,7 +447,16 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
                             suffix="lm_head",
                             target_module=LinearWithGradAccum,
                             kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
-                        )
+                        ),
+                        SubModuleReplacementDescription(
+                            suffix="lm_head",
+                            target_module=VocabParallelLMHead1D,
+                            kwargs={
+                                "gather_output": not self.shard_config.parallel_output,
+                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
+                            },
+                        ),
                     ],
                     method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
                 )
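For context on what this replacement buys: a vocab-parallel LM head splits the lm_head weight along the vocabulary dimension across tensor-parallel ranks, and gather_output decides whether the per-rank logit shards are all-gathered into full logits or left sharded for a distributed cross-entropy (here, the forward installed via get_lm_forward_with_dist_cross_entropy). The sketch below is an illustrative standalone module, not ColossalAI's VocabParallelLMHead1D; the class name ShardedLMHead and its constructor arguments are assumptions made for this example.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist


class ShardedLMHead(nn.Module):
    """Illustrative vocab-parallel LM head (not the ColossalAI implementation).

    The vocabulary dimension is padded so it divides evenly across the
    tensor-parallel ranks; each rank stores only its own slice of the
    projection weight.
    """

    def __init__(self, hidden_size, vocab_size, tp_group,
                 make_vocab_size_divisible_by=64, gather_output=True):
        super().__init__()
        self.tp_group = tp_group
        self.tp_size = dist.get_world_size(tp_group)
        self.gather_output = gather_output

        # Pad the vocab size to a multiple of (divisor * tp_size), mirroring
        # the make_vocab_size_divisible_by kwarg seen in the diff above.
        multiple = make_vocab_size_divisible_by * self.tp_size
        padded_vocab = ((vocab_size + multiple - 1) // multiple) * multiple
        self.vocab_per_rank = padded_vocab // self.tp_size

        # Each rank owns a [vocab_per_rank, hidden_size] slice of the weight.
        self.weight = nn.Parameter(torch.empty(self.vocab_per_rank, hidden_size))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, hidden_states):
        # The local matmul produces logits only for this rank's vocab shard.
        local_logits = F.linear(hidden_states, self.weight)
        if not self.gather_output:
            # Leave logits sharded; a distributed cross-entropy loss can
            # consume them without materializing the full vocab dimension.
            return local_logits
        # Otherwise all-gather the shards and concatenate along the vocab dim.
        shards = [torch.empty_like(local_logits) for _ in range(self.tp_size)]
        dist.all_gather(shards, local_logits, group=self.tp_group)
        return torch.cat(shards, dim=-1)

Because the diff passes gather_output=not self.shard_config.parallel_output, the no-gather path is taken exactly when parallel_output is enabled, which pairs with the distributed cross-entropy forward installed through method_replacement.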