[fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2,gptj,bert, falcon,blip2

* mistral,opy,sam,t5,vit,whisper

* fix

* fix

* fix
This commit is contained in:
Wang Binluo
2024-08-12 18:17:05 +08:00
committed by GitHub
parent f1a3a326c4
commit b2483c8e31
27 changed files with 633 additions and 83 deletions

View File

@@ -117,23 +117,38 @@ class T5BasePolicy(Policy):
SubModuleReplacementDescription(
suffix="q",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="k",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="v",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="o",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(gather_output=False),
kwargs=dict(
gather_output=False,
fp8_communication=self.shard_config.fp8_communication,
),
ignore_if_not_exist=True,
),
],
@@ -151,13 +166,24 @@ class T5BasePolicy(Policy):
SubModuleReplacementDescription(
suffix="wi_0 ",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="wo",
target_module=Linear1D_Col,
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
),
),
SubModuleReplacementDescription(
suffix="dropout",
@@ -170,10 +196,16 @@ class T5BasePolicy(Policy):
SubModuleReplacementDescription(
suffix="wi",
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="wo",
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="dropout",
@@ -187,7 +219,14 @@ class T5BasePolicy(Policy):
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5Stack,
@@ -407,7 +446,14 @@ class T5ModelPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5Model,
@@ -451,7 +497,14 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5ForConditionalGeneration,
@@ -465,6 +518,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
},
),
policy=policy,
@@ -539,7 +593,14 @@ class T5EncoderPolicy(T5BasePolicy):
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
kwargs=(
{
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
"fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
),
policy=policy,
target_key=T5EncoderModel,