Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 04:33:04 +00:00)
[shardformer] optimize seq parallelism (#6086)
* [shardformer] optimize seq parallelism
* [shardformer] fix gpt2 fused linear col
* [plugin] update gemini plugin
* [plugin] update moe hybrid plugin
* [test] update gpt2 fused linear test
* [shardformer] fix gpt2 fused linear reduce
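The diff below (from the GPT2 policy) drops the overlap knob: the policy no longer reads `shard_config.enable_sequence_overlap` and no longer passes `"overlap"` to the fused column-parallel linears, so only `seq_parallel_mode` is forwarded. For context, `split_gather` sequence parallelism keeps activations sharded along the sequence dimension, all-gathers them right before each tensor-parallel linear, and reduce-scatters the gradient back to the local shard in backward. Below is a minimal PyTorch sketch of that gather-forward / reduce-scatter-backward pattern; it is an illustration only, not ColossalAI's `GPT2FusedLinearConv1D_Col` implementation, and the helper names are hypothetical.

```python
# Minimal sketch of "split_gather" sequence parallelism (illustrative only).
import torch
import torch.distributed as dist


class _GatherSeqForwardReduceScatterBackward(torch.autograd.Function):
    """All-gather the sequence dimension in forward; reduce-scatter the
    gradient back to each rank's local sequence shard in backward."""

    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        world_size = dist.get_world_size(group)
        chunks = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(chunks, x.contiguous(), group=group)
        return torch.cat(chunks, dim=1)  # (batch, local_seq, hidden) -> full sequence

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        # Assumes the full sequence length is divisible by the group size.
        grad_chunks = [g.contiguous() for g in grad_output.chunk(world_size, dim=1)]
        grad_local = torch.empty_like(grad_chunks[0])
        dist.reduce_scatter(grad_local, grad_chunks, group=ctx.group)
        return grad_local, None


def sequence_parallel_linear(x_local, weight, bias, group):
    # Each rank holds a slice of the sequence; gather the full sequence,
    # then apply the (column-parallel) linear projection.
    x_full = _GatherSeqForwardReduceScatterBackward.apply(x_local, group)
    return torch.nn.functional.linear(x_full, weight, bias)
```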
@@ -65,7 +65,6 @@ class GPT2Policy(Policy):
                 f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
             )
             self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
-        overlap = self.shard_config.enable_sequence_overlap
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         use_flash_attention = self.shard_config.enable_flash_attention
         if self.shard_config.enable_tensor_parallelism:
@@ -94,7 +93,6 @@ class GPT2Policy(Policy):
                     kwargs={
                         "split_sizes": [self.model.config.hidden_size] * 3,
                         "seq_parallel_mode": sp_mode,
-                        "overlap": overlap,
                         "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
@@ -109,7 +107,6 @@ class GPT2Policy(Policy):
                     kwargs={
                         "split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
                         "seq_parallel_mode": sp_mode,
-                        "overlap": overlap,
                         "skip_bias_add": self.enable_bias_gelu_fused,
                         "fp8_communication": self.shard_config.fp8_communication,
                     },
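With `overlap` removed from the policy kwargs, picking a sequence-parallelism mode is the only user-facing choice. A hedged usage sketch follows, assuming the `HybridParallelPlugin` arguments (`enable_sequence_parallelism`, `sequence_parallelism_mode`, etc.) available around this commit; verify the exact signatures against the installed version.

```python
# Usage sketch (assumed plugin arguments; not taken verbatim from this commit).
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes the process-group env vars set by torchrun

plugin = HybridParallelPlugin(
    tp_size=2,                                 # tensor-parallel degree
    pp_size=1,
    enable_sequence_parallelism=True,          # shard activations along the sequence dim
    sequence_parallelism_mode="split_gather",  # the mode the GPT2 policy falls back to above
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, ...)
```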