mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 20:23:26 +00:00
[shardformer] fix linear 1d row and support uneven splits for fused qkv linear (#6084)
* [tp] hotfix linear row
* [tp] support uneven split for fused linear
* [tp] support sp for fused linear
* [tp] fix gpt2 mlp policy
* [tp] fix gather fused and add fused linear row
This commit is contained in:
@@ -92,7 +92,7 @@ class GPT2Policy(Policy):
|
||||
suffix="attn.c_attn",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 3,
|
||||
"split_sizes": [self.model.config.hidden_size] * 3,
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
@@ -107,7 +107,7 @@ class GPT2Policy(Policy):
|
||||
suffix="mlp.c_fc",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 1,
|
||||
"split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"overlap": overlap,
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
|
Reference in New Issue
Block a user