wangbluo 2025-05-08 16:22:08 +08:00
parent 06724492ca
commit e78c4560c6


@@ -38,6 +38,7 @@ class CommandPolicy(Policy):
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel
+        # In transformers v4.51.3, eager, flash_attention_2, and sdpa are all handled by CohereAttention.
         ATTN_IMPLEMENTATION = {
             "eager": CohereAttention,
             "flash_attention_2": CohereAttention,
@@ -53,10 +54,11 @@ class CommandPolicy(Policy):
         if self.tie_weight:
             embedding_cls = PaddingEmbedding
+        # CohereLayerNorm has no bias in transformers v4.51.3, so we don't replace it.
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_mode in ["split_gather", "ring"]
         if sp_mode == "ring_attn" and not self.is_causal:
             raise ValueError("Ring attention is only meant for causal language modeling.")
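Two details in this hunk are worth noting: the `sp_mode in ["split_gather", "ring"]` context line is a bare expression whose result is discarded (a no-op at runtime), and the `ring_attn` guard rejects non-causal models. A self-contained sketch of that logic follows, with illustrative stand-in values in place of a real `ShardConfig`; the name `sp_partial_derived` is an assumption for illustration, not taken from this commit:

```python
# Illustrative stand-ins; in ColossalAI these come from self.shard_config.
sp_mode = "ring_attn"   # e.g. None, "split_gather", "ring", "ring_attn"
is_causal = False       # whether the wrapped model is a causal LM

# Binding the membership test makes its result usable; as written in the
# diff, the expression's value is simply discarded.
sp_partial_derived = sp_mode in ["split_gather", "ring"]

# Ring attention splits the sequence across ranks assuming a causal mask,
# so non-causal models are rejected up front (this raises with the values above).
if sp_mode == "ring_attn" and not is_causal:
    raise ValueError("Ring attention is only meant for causal language modeling.")
```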