Mirror of https://github.com/hpcaitech/ColossalAI.git — synced 2025-09-27 12:43:02 +00:00
[fp8] support hybrid parallel plugin (#5982)
* support fp8 comm for qwen2 model
* support fp8 comm for qwen2 model
* support fp8 comm for qwen2 model
* fp8
* fix
* bert and bloom
* chatglm and command
* gpt2, gptj, bert, falcon, blip2
* mistral, opt, sam, t5, vit, whisper
* fix
* fix
* fix
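The changes below thread a single `fp8_communication` flag from the shard config into every tensor-parallel submodule replacement. For orientation, here is a minimal sketch of how the flag is meant to be switched on from the user side; the plugin and kwarg names follow this PR series, but the exact constructor signature is an assumption, not a quote from this commit:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Sketch: ask the hybrid parallel plugin to run tensor-parallel collectives
# (all-gather / reduce-scatter / all-reduce) in FP8 instead of bf16/fp32.
# `fp8_communication` is the flag this PR series introduces; treat its exact
# placement in the constructor as an assumption.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    precision="bf16",
    fp8_communication=True,
)
booster = Booster(plugin=plugin)
# model, optimizer, *_ = booster.boost(model, optimizer)  # then wrap as usual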
colossalai/shardformer/policies/t5.py

@@ -117,23 +117,38 @@ class T5BasePolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="q",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="k",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="v",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="o",
                     target_module=Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="relative_attention_bias",
                     target_module=Embedding1D,
-                    kwargs=dict(gather_output=False),
+                    kwargs=dict(
+                        gather_output=False,
+                        fp8_communication=self.shard_config.fp8_communication,
+                    ),
                     ignore_if_not_exist=True,
                 ),
             ],
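What the flag buys: the column/row-parallel linears above perform collectives on their activations, and with `fp8_communication=True` those payloads travel as FP8. The following is a self-contained illustration of the idea, not ColossalAI's actual implementation (the helper name and the per-tensor scaling scheme are assumptions):

import torch
import torch.distributed as dist

def all_gather_fp8(t: torch.Tensor, group=None) -> list[torch.Tensor]:
    """Illustrative sketch: ship FP8 bytes over the wire, dequantize on arrival."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Per-tensor scale so values fit e4m3's dynamic range (scheme is an assumption).
    scale = t.abs().max().clamp(min=1e-12) / fp8_max
    q = (t / scale).to(torch.float8_e4m3fn).view(torch.uint8)  # NCCL-friendly bytes
    world = dist.get_world_size(group)
    shards = [torch.empty_like(q) for _ in range(world)]
    scales = [torch.empty_like(scale.reshape(1)) for _ in range(world)]
    dist.all_gather(shards, q, group=group)
    dist.all_gather(scales, scale.reshape(1), group=group)
    # Cast back to the working dtype and undo each sender's scale.
    return [s.view(torch.float8_e4m3fn).to(t.dtype) * sc for s, sc in zip(shards, scales)]

The real code paths also cover reduce-scatter and all-reduce; this sketch only shows the quantize-communicate-dequantize round trip.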
@@ -151,13 +166,24 @@ class T5BasePolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="wi_0",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="wi_1",
                     target_module=Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
-                    suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                    suffix="wo",
+                    target_module=Linear1D_Col,
+                    kwargs=dict(
+                        gather_output=True,
+                        fp8_communication=self.shard_config.fp8_communication,
+                    ),
                 ),
                 SubModuleReplacementDescription(
                     suffix="dropout",
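The `wo` entry is also reflowed so it can carry the new flag alongside `gather_output=True`. As a reminder of what that gather is (and why it is a natural FP8 target), here is a minimal column-parallel forward; this is a generic sketch, not the `Linear1D_Col` source:

import torch
import torch.distributed as dist

def column_parallel_forward(x: torch.Tensor, w_shard: torch.Tensor,
                            gather_output: bool, group=None) -> torch.Tensor:
    """Each rank owns a slice of the output features; gather_output=True
    reassembles the full activation for non-parallel downstream modules."""
    y_local = x @ w_shard.t()            # [batch, out_features // world_size]
    if not gather_output:
        return y_local
    world = dist.get_world_size(group)
    shards = [torch.empty_like(y_local) for _ in range(world)]
    dist.all_gather(shards, y_local, group=group)  # the collective FP8 comm compresses
    return torch.cat(shards, dim=-1)     # [batch, out_features]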
@@ -170,10 +196,16 @@ class T5BasePolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="wi",
                     target_module=Linear1D_Col,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="wo",
                     target_module=Linear1D_Row,
+                    kwargs={
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
                 ),
                 SubModuleReplacementDescription(
                     suffix="dropout",
@@ -187,7 +219,14 @@ class T5BasePolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=T5Stack,
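The embedding `kwargs` become conditional because `embedding_cls` is only a tensor-parallel class when tensor parallelism is enabled; the non-TP fallback does no communication and would reject the extra kwarg. A sketch of the resulting selection logic (the class names in the comments are assumptions based on sibling policies):

def embedding_kwargs(shard_config) -> dict:
    """Mirror of the diff's conditional: only TP embeddings receive the FP8 flag."""
    base = {"make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by}
    if shard_config.enable_tensor_parallelism:
        # e.g. VocabParallelEmbedding1D: its collectives can run in FP8.
        return {**base, "fp8_communication": shard_config.fp8_communication}
    # e.g. a padding-only embedding wrapper: no collectives, no fp8 kwarg.
    return base

The same conditional pattern recurs in the `shared` embedding replacements of the three subclass policies below.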
@@ -407,7 +446,14 @@ class T5ModelPolicy(T5BasePolicy):
                 description=SubModuleReplacementDescription(
                     suffix="shared",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=T5Model,
@@ -451,7 +497,14 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
                 description=SubModuleReplacementDescription(
                     suffix="shared",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=T5ForConditionalGeneration,
@@ -465,6 +518,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
                     kwargs={
                         "gather_output": True,
                         "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 policy=policy,
@@ -539,7 +593,14 @@ class T5EncoderPolicy(T5BasePolicy):
                 description=SubModuleReplacementDescription(
                     suffix="shared",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=T5EncoderModel,