[shardformer] fix linear 1d row and support uneven splits for fused qkv linear (#6084)

* [tp] hotfix linear row

* [tp] support uneven split for fused linear

* [tp] support sp for fused linear

* [tp] fix gpt2 mlp policy

* [tp] fix gather fused and add fused linear row
This commit is contained in:
Hongxin Liu
2024-10-10 14:34:45 +08:00
committed by GitHub
parent f4daf04270
commit 646b3c5a90
10 changed files with 399 additions and 157 deletions

View File

@@ -92,7 +92,7 @@ class GPT2Policy(Policy):
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
"split_sizes": [self.model.config.hidden_size] * 3,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
@@ -107,7 +107,7 @@ class GPT2Policy(Policy):
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
"split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,