[shardformer] optimize seq parallelism (#6086)

* [shardformer] optimize seq parallelism

* [shardformer] fix gpt2 fused linear col

* [plugin] update gemini plugin

* [plugin] update moe hybrid plugin

* [test] update gpt2 fused linear test

* [shardformer] fix gpt2 fused linear reduce
Author: Hongxin Liu
Date: 2024-10-11 13:44:40 +08:00
Committed by: GitHub
Parent: 6b2c506fc5
Commit: dc2cdaf3e8
13 changed files with 111 additions and 278 deletions
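
The hunks below drop the `overlap` flag from the split_gather sequence-parallel path of `GPT2Policy`. For context, here is a minimal sketch of the split_gather communication pattern: all-gather the sequence-sharded activations before the column-parallel linear, then reduce-scatter the partial outputs of the row-parallel linear back into sequence shards. This is illustrative only, not the ColossalAI implementation; the function names, tensor shapes, and process group are assumptions.

```python
# Sketch of the "split_gather" sequence-parallel pattern (illustrative only;
# function names, tensor shapes, and the process group are assumptions, not
# the ColossalAI shardformer API).
import torch
import torch.distributed as dist
import torch.nn.functional as F


def col_parallel_forward(x_shard, col_weight, group):
    # x_shard: [batch, seq / world_size, hidden] -- sharded along the sequence dim.
    world_size = dist.get_world_size(group)
    parts = [torch.empty_like(x_shard) for _ in range(world_size)]
    dist.all_gather(parts, x_shard, group=group)          # gather the full sequence
    full_seq = torch.cat(parts, dim=1)                    # [batch, seq, hidden]
    return F.linear(full_seq, col_weight)                 # output sharded along the last dim


def row_parallel_forward(h_shard, row_weight, group):
    # h_shard: [batch, seq, hidden / world_size] -- each rank holds a hidden-dim slice.
    world_size = dist.get_world_size(group)
    partial = F.linear(h_shard, row_weight)               # partial sums over hidden shards
    chunks = [c.contiguous() for c in partial.chunk(world_size, dim=1)]
    out_shard = torch.empty_like(chunks[0])
    dist.reduce_scatter(out_shard, chunks, group=group)   # sum partials, re-shard the sequence
    return out_shard                                      # [batch, seq / world_size, out_dim]
```

The removed `enable_sequence_overlap` knob presumably controlled overlapping that all-gather with computation; after this commit the `overlap` kwarg is simply no longer passed to the fused linear layers.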


@@ -65,7 +65,6 @@ class GPT2Policy(Policy):
f"For GPT2, sequence parallelism currently does not support mode {sp_mode}; it will be set to split_gather"
)
self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
- overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode in ["split_gather", "ring"]
use_flash_attention = self.shard_config.enable_flash_attention
if self.shard_config.enable_tensor_parallelism:
@@ -94,7 +93,6 @@ class GPT2Policy(Policy):
kwargs={
"split_sizes": [self.model.config.hidden_size] * 3,
"seq_parallel_mode": sp_mode,
- "overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@@ -109,7 +107,6 @@ class GPT2Policy(Policy):
kwargs={
"split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
"seq_parallel_mode": sp_mode,
- "overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
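
As an aside on the `split_sizes` kwarg above: GPT-2's Conv1D fuses the q, k, and v projections into one weight, so a column-parallel shard has to split each of the three blocks separately rather than slicing the fused dimension as a whole. A rough sketch of that splitting step follows; the helper name and shapes are hypothetical, not the actual fused-linear implementation.

```python
import torch


def shard_fused_weight(weight, split_sizes, rank, world_size):
    # weight: [hidden, sum(split_sizes)], e.g. GPT-2's fused qkv Conv1D weight
    # with split_sizes = [hidden_size] * 3. (Hypothetical helper for illustration.)
    shards = []
    for block in torch.split(weight, split_sizes, dim=-1):   # q, k, v blocks
        cols = block.shape[-1] // world_size
        shards.append(block[..., rank * cols:(rank + 1) * cols])
    # Slicing the fused dimension as one block would hand each rank a mix of
    # q/k/v columns that no longer lines up across ranks; splitting per block
    # keeps every rank's q, k and v slices consistent.
    return torch.cat(shards, dim=-1)    # [hidden, sum(split_sizes) // world_size]
```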