Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 20:46:00 +00:00)
Merge pull request #6012 from hpcaitech/feature/fp8_comm
[fp8] support fp8 communication and fp8 training for ColossalAI
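The diff below threads a new fp8_communication flag from ShardConfig into the Llama tensor-parallel submodule replacements (Linear1D_Col / Linear1D_Row and the embedding/LM-head replacements). For orientation, here is a minimal sketch of how the flag would reach these policies through the Shardformer API. The fp8_communication field on ShardConfig is taken from the diff itself; the surrounding setup (process group, ShardFormer usage, model name) is ordinary Shardformer boilerplate and the checkpoint name is a placeholder:

```python
import torch.distributed as dist
from transformers import LlamaForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer

# Assumes a distributed environment is already initialized
# (e.g. via colossalai.launch_from_torch under torchrun).
shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,
    enable_tensor_parallelism=True,
    fp8_communication=True,  # new field added by this PR; read by the policies below
)

shardformer = ShardFormer(shard_config=shard_config)
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint
sharded_model, shared_params = shardformer.optimize(model)
```

With the flag set, each replaced linear layer receives fp8_communication=True through its kwargs, exactly as the policy changes below wire it up.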
@@ -133,37 +133,37 @@ class LlamaPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.gate_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.up_proj",
                     target_module=Linear1D_Col,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.down_proj",
                     target_module=Linear1D_Row,
-                    kwargs=dict(seq_parallel_mode=sp_mode),
+                    kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                 ),
             ],
         )
@@ -173,7 +173,14 @@ class LlamaPolicy(Policy):
             description=SubModuleReplacementDescription(
                 suffix="embed_tokens",
                 target_module=embedding_cls,
-                kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                kwargs=(
+                    {
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    }
+                    if self.shard_config.enable_tensor_parallelism
+                    else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                ),
             ),
             policy=policy,
             target_key=LlamaModel,
@@ -316,6 +323,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
                     kwargs={
                         "gather_output": not self.shard_config.parallel_output,
                         "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 )
             ],
@@ -384,7 +392,12 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
             LlamaForSequenceClassification: ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
-                        suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        suffix="score",
+                        target_module=Linear1D_Col,
+                        kwargs=dict(
+                            gather_output=True,
+                            fp8_communication=self.shard_config.fp8_communication,
+                        ),
                     )
                 ]
             )
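End to end, the same flag is expected to surface on the booster plugins shipped by this feature branch. A hedged sketch, assuming HybridParallelPlugin exposes an fp8_communication argument and forwards it into the internal ShardConfig that the Llama policy above reads (the tp_size/pp_size/precision arguments are standard plugin parameters; the fp8_communication argument is the assumption here):

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Launch under torchrun so rank/world-size env vars are set.
colossalai.launch_from_torch()

# Assumption: the plugin forwards fp8_communication to
# ShardConfig.fp8_communication, which the policies in this diff pass
# into every Linear1D_Col / Linear1D_Row replacement, casting their
# tensor-parallel collective traffic to fp8.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    precision="bf16",
    fp8_communication=True,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, etc. are then wrapped via booster.boost(...)
```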