[fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2,gptj,bert, falcon,blip2

* mistral,opy,sam,t5,vit,whisper

* fix

* fix

* fix
This commit is contained in:
Wang Binluo
2024-08-12 18:17:05 +08:00
committed by GitHub
parent f1a3a326c4
commit b2483c8e31
27 changed files with 633 additions and 83 deletions

View File

@@ -98,6 +98,7 @@ class BertPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@@ -106,6 +107,7 @@ class BertPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@@ -114,6 +116,7 @@ class BertPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
@@ -123,7 +126,10 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@@ -136,12 +142,16 @@ class BertPolicy(Policy):
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel_mode": sp_mode},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="output.dropout",
@@ -180,6 +190,13 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
kwargs=(
{
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {}
),
)
],
policy=policy,
@@ -249,6 +266,7 @@ class BertPolicy(Policy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=base_policy,