1
0
mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-05-04 06:28:05 +00:00
This commit is contained in:
flybird11111 2024-04-26 11:52:27 +08:00 committed by GitHub
parent 1b387ca9fe
commit 8b7d535977
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -54,7 +54,6 @@ class GPTJPolicy(Policy):
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
@ -78,7 +77,6 @@ class GPTJPolicy(Policy):
suffix="attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
@ -86,7 +84,6 @@ class GPTJPolicy(Policy):
suffix="attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
@ -94,24 +91,20 @@ class GPTJPolicy(Policy):
suffix="attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
SubModuleReplacementDescription(
suffix="attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="mlp.fc_in",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="mlp.fc_out",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",