[shardformer] optimize seq parallelism (#6086)

* [shardformer] optimize seq parallelism

* [shardformer] fix gpt2 fused linear col

* [plugin] update gemini plugin

* [plugin] update moe hybrid plugin

* [test] update gpt2 fused linear test

* [shardformer] fix gpt2 fused linear reduce
Author: Hongxin Liu
Date: 2024-10-11 13:44:40 +08:00
Committed by: GitHub
Parent: 6b2c506fc5
Commit: dc2cdaf3e8

13 changed files with 111 additions and 278 deletions

@@ -57,7 +57,6 @@ class BloomPolicy(Policy):
             )
             sp_mode = "split_gather"
-        overlap = self.shard_config.enable_sequence_overlap
         sp_partial_derived = sp_mode == "split_gather"
         if self.shard_config.enable_tensor_parallelism:
@@ -78,7 +77,6 @@ class BloomPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@@ -99,7 +97,6 @@ class BloomPolicy(Policy):
                         target_module=col_nn.Linear1D_Col,
                         kwargs={
                             "seq_parallel_mode": sp_mode,
-                            "overlap": overlap,
                             "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
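
The common thread in these hunks is the removal of the standalone `overlap` flag from the `split_gather` sequence-parallel path. In `split_gather` mode, activations sharded along the sequence dimension are all-gathered before the column-parallel linear, and the matching gradient is reduce-scattered on the way back. The following is a minimal sketch of that communication pattern using plain `torch.distributed`; the autograd-function name and tensor layout are illustrative only, not ColossalAI's actual `Linear1D_Col` implementation, which fuses this with the matmul and adds fp8 paths.

import torch
import torch.distributed as dist


class _GatherForwardReduceScatterBackward(torch.autograd.Function):
    # Hypothetical sketch of the "split_gather" pattern: all-gather along the
    # sequence dim in forward, reduce-scatter the gradient in backward.

    @staticmethod
    def forward(ctx, x, group):
        # x: (seq_len / world_size, batch, hidden), one sequence shard per rank
        ctx.group = group
        world_size = dist.get_world_size(group)
        shards = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(shards, x.contiguous(), group=group)
        # Reassemble the full sequence: (seq_len, batch, hidden)
        return torch.cat(shards, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        # Each rank keeps only the gradient for its own sequence shard, summed
        # across ranks (assumes seq_len is divisible by world_size)
        grad_chunks = [g.contiguous() for g in grad_output.chunk(world_size, dim=0)]
        grad_input = torch.empty_like(grad_chunks[0])
        dist.reduce_scatter(grad_input, grad_chunks, group=ctx.group)
        return grad_input, None

Once the gather/scatter is expressed as a single autograd boundary like this, overlap with the subsequent matmul can be handled inside the layer itself, which is presumably why the per-layer `overlap` kwarg is dropped from every `Linear1D_Col` replacement above.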