fix tp bug

commit 131eeceb5d
parent 704866a240
@@ -13,6 +13,7 @@ from colossalai.shardformer.layer import (
     PaddingEmbedding,
     RMSNorm,
     VocabParallelEmbedding1D,
+    VocabParallelLMHead1D,
 )
 
 from ..modeling.qwen2 import (
@@ -446,7 +447,16 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
                     suffix="lm_head",
                     target_module=LinearWithGradAccum,
                     kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
-                )
+                ),
+                SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=VocabParallelLMHead1D,
+                    kwargs={
+                        "gather_output": not self.shard_config.parallel_output,
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
+                ),
             ],
             method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
         )
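Taken together, the two hunks import VocabParallelLMHead1D and register it as an additional lm_head replacement in the Qwen2 causal-LM policy, so the output projection can be sharded along the vocabulary dimension under tensor parallelism instead of always being wrapped in LinearWithGradAccum. The snippet below is a minimal, hypothetical sketch (not code from this repository) of how such a choice is typically keyed off the shard configuration; the helper name build_lm_head_replacements and the shard_config.enable_tensor_parallelism flag are assumptions, while the two replacement descriptions mirror the diff above.

# Hypothetical sketch: pick the lm_head replacement based on the shard config.
# Assumes ShardConfig exposes enable_tensor_parallelism, parallel_output,
# make_vocab_size_divisible_by and fp8_communication, matching the names in the diff.
from colossalai.shardformer.layer import LinearWithGradAccum, VocabParallelLMHead1D
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription


def build_lm_head_replacements(shard_config, use_zbv: bool = False):
    """Return the sub-module replacement(s) for a causal LM head (illustrative only)."""
    if shard_config.enable_tensor_parallelism:
        # Shard the LM head over the vocab dimension; gather logits only when the
        # caller does not expect parallel (sharded) outputs downstream.
        return [
            SubModuleReplacementDescription(
                suffix="lm_head",
                target_module=VocabParallelLMHead1D,
                kwargs={
                    "gather_output": not shard_config.parallel_output,
                    "make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by,
                    "fp8_communication": shard_config.fp8_communication,
                },
            )
        ]
    # Without tensor parallelism, fall back to a plain linear head that still
    # supports gradient accumulation (and optional zero-bubble scheduling via use_zbv).
    return [
        SubModuleReplacementDescription(
            suffix="lm_head",
            target_module=LinearWithGradAccum,
            kwargs=dict(fp8_communication=shard_config.fp8_communication, use_zbv=use_zbv),
        )
    ]

Registering the tensor-parallel head under the same "lm_head" suffix keeps the policy's sub_module_replacement list as the single place where the TP and non-TP code paths diverge, which is the pattern the fix above follows.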