shardformer fp8

This commit is contained in:
GuangyaoZhang
2024-07-08 07:04:48 +00:00
parent 51f916b11d
commit 457a0de79f
16 changed files with 520 additions and 234 deletions

View File

@@ -110,14 +110,13 @@ class GPT2Policy(Policy):
"n_fused": 3,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel_mode": sp_mode,
},
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
@@ -127,14 +126,13 @@ class GPT2Policy(Policy):
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel_mode": sp_mode,
},
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",