[fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2, gptj, bert, falcon, blip2

* mistral, opt, sam, t5, vit, whisper

* fix
Authored by Wang Binluo on 2024-08-12 18:17:05 +08:00; committed by GitHub
parent f1a3a326c4 · commit b2483c8e31
27 changed files with 633 additions and 83 deletions

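In substance, the PR threads one new flag, shard_config.fp8_communication, from the hybrid parallel plugin into every tensor-parallel submodule replacement across the 27 touched policy files; the Mistral policy diff below is representative. A minimal usage sketch, assuming the fp8_communication argument this PR series adds to HybridParallelPlugin (model and optimizer defined elsewhere):

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # Quantize tensor-parallel collectives (all-gather / reduce-scatter) to fp8;
    # the plugin forwards the flag to each policy as shard_config.fp8_communication.
    plugin = HybridParallelPlugin(tp_size=2, pp_size=1, fp8_communication=True)
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)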

@@ -88,30 +88,51 @@ class MistralPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.o_proj",
                         target_module=Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.gate_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.up_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.down_proj",
                         target_module=Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                 ],
             )
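
Every Linear1D_Col / Linear1D_Row replacement above now receives the flag through kwargs; when it is set, the layer's collectives are expected to run on fp8-quantized payloads rather than the model dtype. A rough sketch of such an fp8 all-gather, not the ColossalAI implementation (the helper name, per-tensor scaling, and uint8 byte-view are all assumptions):

    import torch
    import torch.distributed as dist

    def _all_gather_fp8(t: torch.Tensor, group=None) -> torch.Tensor:
        # Hypothetical sketch: quantize to float8_e4m3fn with one scale per
        # tensor, all-gather the raw bytes (NCCL has no fp8 dtype), dequantize.
        world_size = dist.get_world_size(group)
        scale = t.abs().amax().clamp(min=1e-12) / 448.0  # e4m3 max finite ~448
        q = (t / scale).to(torch.float8_e4m3fn).view(torch.uint8)
        scales = torch.empty(world_size, dtype=scale.dtype, device=t.device)
        dist.all_gather_into_tensor(scales, scale.reshape(1), group=group)
        buf = torch.empty(world_size * q.numel(), dtype=torch.uint8, device=t.device)
        dist.all_gather_into_tensor(buf, q.reshape(-1), group=group)
        chunks = buf.view(world_size, *t.shape).view(torch.float8_e4m3fn)
        # Dequantize each rank's chunk with its own scale, concat along dim 0.
        deq = chunks.to(t.dtype) * scales.to(t.dtype).view(world_size, *([1] * t.dim()))
        return deq.reshape(world_size * t.shape[0], *t.shape[1:])

The trade-off is roughly halved communication volume for sharded activations at the cost of quantization error, which is presumably why the flag stays opt-in.
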
@@ -121,7 +142,14 @@ class MistralPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=MistralModel,
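
The conditional above is a pattern that recurs across the touched policies: the vocab-padding kwarg is always passed, while fp8_communication is attached only when tensor parallelism actually replaces the embedding with a parallel class. In isolation (cfg standing in for self.shard_config):

    # Recurring pattern in this PR, shown with shortened names:
    kwargs = {"make_vocab_size_divisible_by": cfg.make_vocab_size_divisible_by}
    if cfg.enable_tensor_parallelism:
        kwargs["fp8_communication"] = cfg.fp8_communication
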
@@ -281,6 +309,7 @@ class MistralForCausalLMPolicy(MistralPolicy):
                         kwargs={
                             "gather_output": not self.shard_config.parallel_output,
                             "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     )
                 ]
@@ -297,7 +326,9 @@ class MistralForCausalLMPolicy(MistralPolicy):
                     SubModuleReplacementDescription(
                         suffix="lm_head",
                         target_module=PaddingLMHead,
-                        kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
+                        kwargs=dict(
+                            make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
+                        ),
                     )
                 ]
             )
@@ -350,7 +381,9 @@ class MistralForSequenceClassificationPolicy(MistralPolicy):
             MistralForSequenceClassification: ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
-                        suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        suffix="score",
+                        target_module=Linear1D_Col,
+                        kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                     )
                 ]
             )
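
When using ShardFormer directly rather than through the plugin, the same flag should be settable on ShardConfig itself; a hedged sketch, assuming the fp8_communication field this PR adds to ShardConfig (tp_group and model come from the surrounding setup):

    from colossalai.shardformer import ShardConfig, ShardFormer

    # Every policy in this PR reads the flag from shard_config, so setting it
    # here propagates to all replaced linear, embedding, and head modules.
    shard_config = ShardConfig(
        tensor_parallel_process_group=tp_group,
        enable_tensor_parallelism=True,
        fp8_communication=True,
    )
    model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)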