mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] optimize seq parallelism (#6086)
* [shardformer] optimize seq parallelism
* [shardformer] fix gpt2 fused linear col
* [plugin] update gemini plugin
* [plugin] update moe hybrid plugin
* [test] update gpt2 fused linear test
* [shardformer] fix gpt2 fused linear reduce
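As background for the change, here is a minimal sketch of the sequence-parallelism idea in plain PyTorch. It is illustrative only, not ColossalAI's implementation: the concat below stands in for the all-gather collective a real implementation would issue, and all sizes are made up. Token-wise layers can run independently on each rank's slice of the sequence, while attention needs the full sequence gathered first; hiding that communication behind computation is what an efficient sequence-parallel schedule has to manage.

import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, hidden, world_size = 8, 16, 4
x = torch.randn(1, seq_len, hidden)

# Sequence parallelism: each "rank" keeps a contiguous slice of the sequence.
shards = list(x.chunk(world_size, dim=1))

# Token-wise ops (LayerNorm, MLP, dropout) need no communication: running them
# shard-by-shard matches the unsharded computation.
mlp = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden))
sharded_out = torch.cat([mlp(s) for s in shards], dim=1)
print(torch.allclose(sharded_out, mlp(x), atol=1e-6))  # True

# Attention mixes tokens across the whole sequence, so the shards must be
# gathered first (a plain concat here stands in for the all-gather collective)
# and re-split afterwards.
full = torch.cat(shards, dim=1)
attn = nn.MultiheadAttention(hidden, num_heads=4, batch_first=True)
attn_out, _ = attn(full, full, full)
resharded = list(attn_out.chunk(world_size, dim=1))
print(resharded[0].shape)  # torch.Size([1, 2, 16])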
@@ -51,7 +51,6 @@ class GPTJPolicy(Policy):
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
-        overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -76,7 +75,6 @@ class GPTJPolicy(Policy):
                     suffix="attn.k_proj",
                     target_module=col_nn.Linear1D_Col,
                     kwargs={
-                        "overlap": overlap,
                         "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
@@ -84,7 +82,6 @@ class GPTJPolicy(Policy):
                     suffix="attn.q_proj",
                     target_module=col_nn.Linear1D_Col,
                     kwargs={
-                        "overlap": overlap,
                         "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
@@ -92,7 +89,6 @@ class GPTJPolicy(Policy):
                     suffix="attn.v_proj",
                     target_module=col_nn.Linear1D_Col,
                     kwargs={
-                        "overlap": overlap,
                         "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
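The kwargs in these hunks are forwarded to col_nn.Linear1D_Col, the column-parallel linear that shardformer swaps in for the attention projections. As a rough illustration of what column parallelism means here, the following is a toy PyTorch sketch under the assumption of splitting the weight along the output dimension; ColumnParallelLinearSketch is a made-up name, and unlike the real Linear1D_Col there is no process group and no communication.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ColumnParallelLinearSketch(nn.Module):
    """Toy stand-in for a column-parallel linear layer.

    The full (out_features, in_features) weight is split along the output
    dimension, so each shard stores and computes only its own slice of the
    output features. This mirrors the idea behind col_nn.Linear1D_Col, not
    its implementation.
    """

    def __init__(self, in_features: int, out_features: int, world_size: int):
        super().__init__()
        assert out_features % world_size == 0, "output dim must divide evenly across shards"
        local_out = out_features // world_size
        self.weight = nn.Parameter(torch.randn(local_out, in_features) * 0.02)
        self.bias = nn.Parameter(torch.zeros(local_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Produces only this shard's slice of the output; a following
        # row-parallel linear can consume it directly, or an all-gather
        # reassembles the full activation.
        return F.linear(x, self.weight, self.bias)

# Simulate 4-way tensor parallelism in one process: four shards of a
# 1024 -> 1024 projection, each owning 256 output features.
shards = [ColumnParallelLinearSketch(1024, 1024, world_size=4) for _ in range(4)]
x = torch.randn(2, 16, 1024)
full_out = torch.cat([shard(x) for shard in shards], dim=-1)
print(full_out.shape)  # torch.Size([2, 16, 1024])

Splitting the q/k/v projections column-wise keeps whole attention heads on a single shard, which is why the policy asserts that num_attention_heads is divisible by tensor_parallel_size.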