diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py
index 7044f3be7..3cba8ded0 100644
--- a/colossalai/shardformer/policies/command.py
+++ b/colossalai/shardformer/policies/command.py
@@ -38,6 +38,7 @@ class CommandPolicy(Policy):
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel
 
+        # The eager, flash_attention_2, sdpa will all be passed to CohereAttention in v4.51.3 transformers.
         ATTN_IMPLEMENTATION = {
             "eager": CohereAttention,
             "flash_attention_2": CohereAttention,
@@ -53,10 +54,11 @@ class CommandPolicy(Policy):
         if self.tie_weight:
             embedding_cls = PaddingEmbedding
 
+        # CohereLayerNorm has no bias in v4.51.3 transformers, so we don't replace it.
+
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
-        sp_mode in ["split_gather", "ring"]
         if sp_mode == "ring_attn" and not self.is_causal:
             raise ValueError("Ring attention is only meant for causal language modeling.")
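
Context for the first hunk: in recent transformers releases (including v4.51.3) the eager/sdpa/flash-attention variants are no longer separate Cohere attention subclasses; a single CohereAttention class selects the kernel at forward time from the model config's attention-implementation setting, which is why all three keys now map to the same class. The sketch below is a minimal stand-alone illustration of the lookup the policy performs, not ColossalAI or transformers code; the attribute name origin_attn_implement (used by other shardformer policies) and the dummy CohereAttention class are assumptions for illustration only.

# Minimal sketch of how the ATTN_IMPLEMENTATION table is consumed.
# Assumption: the policy reads the user's configured implementation from
# `self.origin_attn_implement`, as other shardformer policies do; a dummy
# class stands in for transformers' CohereAttention.

class CohereAttention:
    """Stand-in: in v4.51.3 the real class dispatches eager/sdpa/flash internally."""


ATTN_IMPLEMENTATION = {
    "eager": CohereAttention,
    "flash_attention_2": CohereAttention,
    "sdpa": CohereAttention,
}


def resolve_attn_cls(origin_attn_implement: str) -> type:
    # Whatever implementation the user configured, the policy targets the same
    # module class; the actual kernel choice happens inside transformers at runtime.
    return ATTN_IMPLEMENTATION[origin_attn_implement]


assert resolve_attn_cls("sdpa") is resolve_attn_cls("eager") is CohereAttention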