diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 43f45ed3d..6c2dbb13a 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -1,4 +1,3 @@ -import math from typing import List, Optional, Tuple, Union import torch @@ -7,6 +6,7 @@ from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.cohere.modeling_cohere import ( + CohereAttention, CohereForCausalLM, CohereModel, StaticCache, @@ -346,7 +346,7 @@ class CommandPipelineForwards: def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( - self, + self: CohereAttention, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, @@ -381,25 +381,10 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -409,32 +394,16 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) else: - # to be fixed: - # precision issue - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + dropout = (0.0 if not self.training else self.attention_dropout,) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous() # sp: all-to-all comminucation when introducing sequence parallel