diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index 1280efaec..25e5b66dc 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -54,7 +54,6 @@ class GPTJPolicy(Policy):
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
-        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
         overlap = self.shard_config.enable_sequence_overlap
 
         if self.shard_config.enable_tensor_parallelism:
@@ -78,7 +77,6 @@ class GPTJPolicy(Policy):
                         suffix="attn.k_proj",
                         target_module=col_nn.Linear1D_Col,
                         kwargs={
-                            "seq_parallel": use_sequence_parallel,
                             "overlap": overlap,
                         },
                     ),
@@ -86,7 +84,6 @@ class GPTJPolicy(Policy):
                         suffix="attn.q_proj",
                         target_module=col_nn.Linear1D_Col,
                         kwargs={
-                            "seq_parallel": use_sequence_parallel,
                             "overlap": overlap,
                         },
                     ),
@@ -94,24 +91,20 @@ class GPTJPolicy(Policy):
                         suffix="attn.v_proj",
                         target_module=col_nn.Linear1D_Col,
                         kwargs={
-                            "seq_parallel": use_sequence_parallel,
                             "overlap": overlap,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.out_proj",
                         target_module=col_nn.Linear1D_Row,
-                        kwargs={"seq_parallel": use_sequence_parallel},
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.fc_in",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel": use_sequence_parallel},
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.fc_out",
                         target_module=col_nn.Linear1D_Row,
-                        kwargs={"seq_parallel": use_sequence_parallel},
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.attn_dropout",