mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 04:03:58 +00:00
[fp8] MoE support fp8 communication (#5977)
* fix
* support moe fp8
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* fix
* fix
* fix fix fi
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -118,18 +118,22 @@ class DeepseekPolicy(Policy):
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
+ kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
+ kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
+ kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
+ kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -138,7 +142,10 @@ class DeepseekPolicy(Policy):
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
- kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
+ kwargs={
|
||||
+ "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
+ "fp8_communication": self.shard_config.fp8_communication,
|
||||
+ },
|
||||
),
|
||||
policy=policy,
|
||||
target_key="DeepseekModel",
|
||||
@@ -155,6 +162,7 @@ class DeepseekPolicy(Policy):
|
||||
"ep_group": self.shard_config.ep_group,
|
||||
"tp_group": self.shard_config.tensor_parallel_process_group,
|
||||
"moe_dp_group": self.shard_config.moe_dp_group,
|
||||
+ "fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
)
|
||||
],
|
||||
@@ -305,7 +313,7 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
- kwargs=dict(gather_output=True),
|
||||
+ kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
Reference in New Issue
Block a user