[shardformer] fix linear 1d row and support uneven splits for fused qkv linear (#6084)

* [tp] hotfix linear row

* [tp] support uneven split for fused linear

* [tp] support sp for fused linear

* [tp] fix gpt2 mlp policy

* [tp] fix gather fused and add fused linear row
This commit is contained in:
Hongxin Liu
2024-10-10 14:34:45 +08:00
committed by GitHub
parent f4daf04270
commit 646b3c5a90
10 changed files with 399 additions and 157 deletions

View File

@@ -57,7 +57,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
suffix="self_attn.W_pack",
target_module=FusedLinear1D_Col,
kwargs={"split_sizes": [self.model.config.hidden_size] * 3},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",