From 10bc6af2b1419050d8ffb4a086c011f00c7b1cd1 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Thu, 15 May 2025 17:55:24 +0800
Subject: [PATCH] fix: align Command model forwards with transformers v4.51.3

---
 colossalai/shardformer/modeling/command.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index 6c2dbb13a..fe494b996 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -134,6 +134,7 @@ class CommandPipelineForwards:
                 is_causal=True,
             )
         else:
+            # attention_mask calculation follows transformers v4.51.3
             attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values)
 
         if self.gradient_checkpointing and self.training and use_cache:
@@ -164,7 +165,7 @@ class CommandPipelineForwards:
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
-
+        # position_embeddings calculation follows transformers v4.51.3
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         start_idx, end_idx = stage_index[0], stage_index[1]
@@ -394,6 +395,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
             assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
             attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
         else:
+            # attn_weights and attn_output calculation follows transformers.models.cohere.modeling_cohere.CohereAttention.forward in transformers v4.51.3
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
             if attention_mask is not None:
                 causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
@@ -486,6 +488,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                 is_causal=True,
             )
         else:
+            # attention_mask calculation follows transformers v4.51.3
            attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
 
         if sp_mode in ["ring", "split_gather"]:
@@ -503,6 +506,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
+        # position_embeddings calculation follows transformers v4.51.3
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         for decoder_layer in self.layers:
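
Reviewer note: the comments added in the third hunk point at the eager attention path of CohereAttention.forward in transformers v4.51.3 (matmul with scaling, causal-mask slice, fp32 softmax, matmul with values). Below is a minimal, self-contained sketch of that pattern for sanity-checking the patched else-branch outside the repo; the function name eager_attention_sketch, the default scaling, the dropout argument, and the toy shapes are illustrative assumptions, not code from this patch or from transformers.

# Minimal sketch of the eager attention pattern referenced by the patch comments.
# Hypothetical helper; shapes are (batch, heads, seq, head_dim).
import torch
import torch.nn.functional as F


def eager_attention_sketch(query_states, key_states, value_states, attention_mask=None, scaling=None, dropout_p=0.0):
    # default scaling of 1/sqrt(head_dim), the usual convention (assumption)
    if scaling is None:
        scaling = query_states.shape[-1] ** -0.5

    # raw attention scores: Q @ K^T, scaled
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # slice the causal mask to the key length, as in the patched branch
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # softmax in fp32 for numerical stability, then cast back to the query dtype
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = F.dropout(attn_weights, p=dropout_p)
    attn_output = torch.matmul(attn_weights, value_states)
    return attn_output, attn_weights


if __name__ == "__main__":
    b, h, s, d = 2, 4, 8, 16
    q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
    # additive causal mask: -inf strictly above the diagonal, 0 elsewhere
    mask = torch.full((s, s), float("-inf")).triu(1)[None, None]
    out, w = eager_attention_sketch(q, k, v, attention_mask=mask)
    print(out.shape, w.shape)  # torch.Size([2, 4, 8, 16]) torch.Size([2, 4, 8, 8])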