diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index ea811acdf..aa11b7043 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -1,19 +1,19 @@ -import math -import warnings from typing import List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.cohere.modeling_cohere import ( + CohereAttention, CohereForCausalLM, CohereModel, StaticCache, apply_rotary_pos_emb, repeat_kv, ) +from transformers.processing_utils import Unpack from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -27,6 +27,8 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"] _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] +logger = logging.get_logger(__name__) + class CommandPipelineForwards: """ @@ -37,22 +39,23 @@ class CommandPipelineForwards: @staticmethod def command_model_forward( self: CohereModel, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, force_sp_output_gather: bool = True, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ): + logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -67,8 +70,6 @@ class CommandPipelineForwards: ) use_cache = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: @@ -122,6 +123,7 @@ class CommandPipelineForwards: # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage + shard_config.enable_flash_attention = True if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length, seq_length_with_past) @@ -133,7 +135,8 @@ class CommandPipelineForwards: is_causal=True, ) else: - attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) + # v4.51.3 transformers attention_mask calculation + attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values) if self.gradient_checkpointing and self.training and use_cache: if use_cache: @@ -163,6 +166,8 @@ class CommandPipelineForwards: all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + # v4.51.3 transformers position_embeddings calculation + position_embeddings = self.rotary_emb(hidden_states, position_ids) start_idx, end_idx = stage_index[0], stage_index[1] num_ckpt_layers = 0 @@ -193,6 +198,7 @@ class CommandPipelineForwards: output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -203,6 +209,7 @@ class CommandPipelineForwards: output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] @@ -224,17 +231,6 @@ class CommandPipelineForwards: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None if stage_manager.is_last_stage(): - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_cache, - all_hidden_states, - all_self_attns, - ] - if v is not None - ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -352,27 +348,22 @@ class CommandPipelineForwards: def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( - self, + self: CohereAttention, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if sp_mode is not None: assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet" assert (sp_size is not None) and ( sp_group is not None ), "Must specify sp_size and sp_group for sequence parallel" - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) bsz, q_len, _ = hidden_states.size() + # sp: modify sp_len when sequence parallel mode is ring if sp_mode in ["split_gather", "ring"]: q_len *= sp_size @@ -388,60 +379,36 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) + cos, sin = position_embeddings - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_weights = None + + shard_config.enable_flash_attention = True if shard_config.enable_flash_attention: assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) else: - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - + # attn_weights and attn_output calculation is modified on the v4.51.3 of transformers.models.cohere.modeling_cohere.CohereAttention.forward. + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + dropout = 0.0 if not self.training else self.attention_dropout + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous() # sp: all-to-all comminucation when introducing sequence parallel @@ -451,13 +418,11 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication ) else: - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights return forward @@ -467,24 +432,23 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, force_sp_output_gather: bool = True, - ) -> Union[Tuple, BaseModelOutputWithPast]: + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): @@ -516,6 +480,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode if position_ids is None: position_ids = cache_position.unsqueeze(0) + shard_config.enable_flash_attention = True + # in this case, attention_mask is a dict rather than a tensor if shard_config.enable_flash_attention: mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) @@ -527,7 +493,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode is_causal=True, ) else: - attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + # v4.51.3 transformers attention_mask calculation + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) if sp_mode in ["ring", "split_gather"]: inputs_embeds = split_forward_gather_backward( @@ -544,6 +511,9 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode all_self_attns = () if output_attentions else None next_decoder_cache = None + # v4.51.3 transformers position_embeddings calculation + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -557,8 +527,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode output_attentions, use_cache, cache_position, + position_embeddings, ) - else: layer_outputs = decoder_layer( hidden_states, @@ -568,16 +538,11 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) # Cases that don't support parallelizing cross entropy computation along sequence @@ -594,8 +559,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode next_cache = ( next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache ) - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index e6e741d34..8a510307e 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -6,8 +6,6 @@ from torch import Tensor from torch.nn import Module from colossalai.shardformer.layer import ( - FusedLayerNorm, - LayerNorm, Linear1D_Col, Linear1D_Row, LinearWithGradAccum, @@ -38,18 +36,13 @@ class CommandPolicy(Policy): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.cohere.modeling_cohere import ( - CohereAttention, - CohereDecoderLayer, - CohereFlashAttention2, - CohereModel, - CohereSdpaAttention, - ) + from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel + # The eager, flash_attention_2, sdpa will all be passed to CohereAttention in v4.51.3 transformers. ATTN_IMPLEMENTATION = { "eager": CohereAttention, - "flash_attention_2": CohereFlashAttention2, - "sdpa": CohereSdpaAttention, + "flash_attention_2": CohereAttention, + "sdpa": CohereAttention, } policy = {} @@ -61,15 +54,11 @@ class CommandPolicy(Policy): if self.tie_weight: embedding_cls = PaddingEmbedding - if self.shard_config.enable_fused_normalization: - norm_cls = FusedLayerNorm - else: - norm_cls = LayerNorm + # CohereLayerNorm has no bias in v4.51.3 transformers, so we don't replace it. sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None - sp_partial_derived = sp_mode in ["split_gather", "ring"] if sp_mode == "ring_attn" and not self.is_causal: raise ValueError("Ring attention is only meant for causal language modeling.") @@ -86,14 +75,15 @@ class CommandPolicy(Policy): policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) + + self.append_or_create_method_replacement( + description={ + "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=attn_cls, + ) if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: - self.append_or_create_method_replacement( - description={ - "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), - }, - policy=policy, - target_key=attn_cls, - ) if self.pipeline_stage_manager is None: self.append_or_create_method_replacement( description={ @@ -280,29 +270,6 @@ class CommandPolicy(Policy): target_key=CohereModel, ) - # optimization configuration - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="input_layernorm", - target_module=norm_cls, - kwargs={"sp_partial_derived": sp_partial_derived}, - ), - ], - policy=policy, - target_key=CohereDecoderLayer, - ) - - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="norm", - target_module=norm_cls, - kwargs={"sp_partial_derived": sp_partial_derived}, - ), - policy=policy, - target_key=CohereModel, - ) - return policy def postprocess(self): @@ -349,6 +316,7 @@ class CommandPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] + held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.layers)) diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 9435ef84b..4d156d84d 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -218,7 +218,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 1, "num_microbatches": 1, "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", + "sequence_parallelism_mode": "all_to_all", "enable_flash_attention": False, "use_lazy_init": True, "precision": "fp16",