From e78c4560c6a5d14c29a09ca022f4a344b827d939 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Thu, 8 May 2025 16:22:08 +0800
Subject: [PATCH] fix

---
 colossalai/shardformer/policies/command.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py
index 7044f3be7..3cba8ded0 100644
--- a/colossalai/shardformer/policies/command.py
+++ b/colossalai/shardformer/policies/command.py
@@ -38,6 +38,7 @@ class CommandPolicy(Policy):
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel
 
+        # The eager, flash_attention_2 and sdpa implementations all map to CohereAttention in transformers v4.51.3.
         ATTN_IMPLEMENTATION = {
             "eager": CohereAttention,
             "flash_attention_2": CohereAttention,
@@ -53,10 +54,11 @@ class CommandPolicy(Policy):
         if self.tie_weight:
             embedding_cls = PaddingEmbedding
 
+        # CohereLayerNorm has no bias in transformers v4.51.3, so we don't replace it.
+
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
-        sp_mode in ["split_gather", "ring"]
 
         if sp_mode == "ring_attn" and not self.is_causal:
             raise ValueError("Ring attention is only meant for causal language modeling.")
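
Note on the hunks above (reviewer sketch, not part of the patch): the deleted line
sp_mode in ["split_gather", "ring"] was a bare comparison whose result was never used,
so dropping it does not change behavior. For the attention mapping, below is a minimal
sketch of how an ATTN_IMPLEMENTATION table like the one in this file can be consumed;
the resolve_attn_cls helper and the "eager" fallback are illustrative assumptions, not
code from colossalai/shardformer or from this patch.

# Illustrative sketch only: all three implementation keys resolve to the same
# CohereAttention class in transformers v4.51.3, which dispatches internally
# based on config._attn_implementation. Assumes transformers >= 4.39 is installed.
from transformers.models.cohere.modeling_cohere import CohereAttention

ATTN_IMPLEMENTATION = {
    "eager": CohereAttention,
    "flash_attention_2": CohereAttention,
    "sdpa": CohereAttention,
}

def resolve_attn_cls(config):
    # _attn_implementation is set on the config by transformers at model load time;
    # falling back to "eager" here is an assumption made for this sketch.
    impl = getattr(config, "_attn_implementation", "eager")
    return ATTN_IMPLEMENTATION.get(impl, CohereAttention)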