Mirror of https://github.com/hpcaitech/ColossalAI.git
[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)
* [example] pass use_fp8_comm flag to all plugins
* [example] add mixtral benchmark
* [moe] refine assertion and check
* [moe] fix mixtral & add more tests
* [moe] consider checking dp * sp group and moe_dp_group
* [mixtral] remove gate tp & add more tests
* [deepseek] fix tp & sp for deepseek
* [mixtral] minor fix
* [deepseek] add deepseek benchmark
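The first two bullets describe wiring an fp8-communication switch through the benchmark scripts into every plugin. A minimal sketch of that wiring, assuming an argparse flag named --use_fp8_comm and a plugin keyword fp8_communication (the flag name comes from the commit message and the keyword mirrors the shard_config.fp8_communication field used in the diff below; neither is presented here as verified ColossalAI API):

# Hypothetical sketch: forward a benchmark CLI flag into the kwargs shared by
# all plugins, so fp8 communication is toggled in one place.
import argparse

def build_plugin_kwargs(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_fp8_comm", action="store_true",
                        help="enable fp8 for collective communication")
    parser.add_argument("--tp_size", type=int, default=1)
    parser.add_argument("--ep_size", type=int, default=1)
    args = parser.parse_args(argv)
    # The same dict is splatted into whichever plugin the benchmark selects,
    # so every plugin receives the flag uniformly.
    return dict(
        tp_size=args.tp_size,
        ep_size=args.ep_size,
        fp8_communication=args.use_fp8_comm,  # assumed plugin keyword
    )

if __name__ == "__main__":
    print(build_plugin_kwargs(["--use_fp8_comm", "--tp_size", "2"]))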
@@ -10,6 +10,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
 from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
 from colossalai.shardformer.layer.linear import Linear1D_Row
 from colossalai.shardformer.modeling.deepseek import (
+    DeepseekMoEGate_Col,
     DeepseekPipelineForwards,
     EPDeepseekMoE,
     get_deepseek_flash_attention_forward,
@@ -56,16 +57,24 @@ class DeepseekPolicy(Policy):
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
+        tp_size = self.shard_config.tensor_parallel_size
+
+        # modified for both SP and TP
+        num_q_heads = self.model.config.num_attention_heads
+        num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
         if sp_mode == "all_to_all":
+            num_q_heads //= sp_size
             decoder_attribute_replacement = {
-                "num_heads": self.model.config.num_attention_heads // sp_size,
+                "num_heads": num_q_heads,
             }
             if getattr(self.model.config, "num_key_value_heads", False):
-                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+                num_kv_heads //= sp_size
+                decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads

             policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )

         if self.shard_config.enable_sequence_parallelism:
             if self.pipeline_stage_manager is not None:
                 # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
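Under all_to_all sequence parallelism the attention heads are partitioned across the sequence-parallel group, which is why the cached head counts above are floor-divided by sp_size (and later by tp_size). An illustrative standalone helper showing that bookkeeping and the divisibility it implicitly requires (hypothetical names, not ColossalAI API):

# Illustrative: head counts one rank keeps under all_to_all sequence
# parallelism; kv heads are optional because some configs define no
# num_key_value_heads (no grouped-query attention).
def sp_local_head_counts(num_q_heads, num_kv_heads=None, sp_size=1):
    if sp_size > 1:
        assert num_q_heads % sp_size == 0, "query heads must be divisible by sp_size"
        num_q_heads //= sp_size
        if num_kv_heads is not None:
            assert num_kv_heads % sp_size == 0, "kv heads must be divisible by sp_size"
            num_kv_heads //= sp_size
    return num_q_heads, num_kv_heads

# e.g. 32 query heads and 8 kv heads with sp_size=2 -> 16 and 4 per rank
print(sp_local_head_counts(32, 8, sp_size=2))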
@@ -97,6 +106,7 @@ class DeepseekPolicy(Policy):
         else:
             if self.tie_weight:
                 embedding_cls = PaddingEmbedding

         if self.shard_config.enable_tensor_parallelism:
             # tensor parallelism for non-moe params
             assert (
@@ -107,10 +117,15 @@ class DeepseekPolicy(Policy):
             ), f"The number of key_value heads must be divisible by tensor parallel size."
-            decoder_attribute_replacement = {
-                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
-                // self.shard_config.tensor_parallel_size,
-            }
+            num_q_heads //= tp_size
+            decoder_attribute_replacement = {
+                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "self_attn.num_heads": num_q_heads,
+            }
+            if num_kv_heads:
+                num_kv_heads //= tp_size
+                decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads

             policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
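The rewritten tensor-parallel block only sets self_attn.num_key_value_heads when the config actually defines it, instead of reading num_key_value_heads unconditionally as the old dict did. A hedged sketch of that guarded construction (plain dict and a hypothetical helper, not the policy code itself):

# Illustrative: build the attribute replacement only from attributes the model
# config really has, so configs without num_key_value_heads are not broken.
def build_decoder_attr_replacement(config, tp_size):
    replacement = {
        "self_attn.hidden_size": config["hidden_size"] // tp_size,
        "self_attn.num_heads": config["num_attention_heads"] // tp_size,
    }
    num_kv_heads = config.get("num_key_value_heads")  # may be absent
    if num_kv_heads:
        replacement["self_attn.num_key_value_heads"] = num_kv_heads // tp_size
    return replacement

# e.g. hidden 4096, 32 q heads, 8 kv heads, tp_size=4
# -> {'self_attn.hidden_size': 1024, 'self_attn.num_heads': 8, 'self_attn.num_key_value_heads': 2}
print(build_decoder_attr_replacement(
    {"hidden_size": 4096, "num_attention_heads": 32, "num_key_value_heads": 8},
    tp_size=4,
))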
@@ -135,8 +150,19 @@ class DeepseekPolicy(Policy):
                     target_module=Linear1D_Row,
                     kwargs={"fp8_communication": self.shard_config.fp8_communication},
                 ),
+                SubModuleReplacementDescription(
+                    suffix="mlp.gate",
+                    target_module=DeepseekMoEGate_Col,
+                    kwargs={
+                        "gather_output": True,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                        "config": self.model.config,
+                    },
+                    ignore_if_not_exist=True,
+                ),
             ],
         )

         if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
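The new mlp.gate entry shards the MoE gate column-wise but gathers its output (gather_output=True), so every tensor-parallel rank ends up with scores for all experts before routing; ignore_if_not_exist=True presumably skips decoder layers whose MLP has no gate. A single-process illustration of the gather_output idea in plain PyTorch (this is not the DeepseekMoEGate_Col implementation):

# Toy column-parallel gate: the (num_experts, hidden) weight is split across
# "ranks" along the expert dimension, each shard scores its slice of experts,
# and the partial outputs are concatenated, i.e. the single-process analogue
# of all-gathering the sharded output.
import torch

def column_parallel_gate(hidden, full_weight, world_size):
    assert full_weight.shape[0] % world_size == 0
    shards = full_weight.chunk(world_size, dim=0)   # one weight shard per rank
    partial = [hidden @ w.t() for w in shards]      # local expert scores
    return torch.cat(partial, dim=-1)               # gathered full score vector

hidden = torch.randn(4, 16)      # (tokens, hidden_size)
weight = torch.randn(8, 16)      # (num_experts, hidden_size)
gathered = column_parallel_gate(hidden, weight, world_size=2)
reference = hidden @ weight.t()  # unsharded gate for comparison
print(torch.allclose(gathered, reference, atol=1e-5))   # True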