diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index ea811acdf..be4efbd94 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -1,9 +1,8 @@ import math import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Callable import torch -import torch.utils.checkpoint from torch import nn from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -15,6 +14,12 @@ from transformers.models.cohere.modeling_cohere import ( repeat_kv, ) from transformers.utils import logging +from transformers.processing_utils import Unpack +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.models.cohere.modeling_cohere import eager_attention_forward +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + +from functools import partial from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward @@ -27,6 +32,7 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"] _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] +logger = logging.get_logger(__name__) class CommandPipelineForwards: """ @@ -37,22 +43,23 @@ class CommandPipelineForwards: @staticmethod def command_model_forward( self: CohereModel, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, force_sp_output_gather: bool = True, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ): + logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -67,8 +74,6 @@ class CommandPipelineForwards: ) use_cache = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: @@ -133,7 +138,7 @@ class CommandPipelineForwards: is_causal=True, ) else: - attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) + attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values) if self.gradient_checkpointing and self.training and use_cache: if use_cache: @@ -163,6 +168,8 @@ class CommandPipelineForwards: all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + + position_embeddings = self.rotary_emb(hidden_states, position_ids) start_idx, end_idx = stage_index[0], stage_index[1] num_ckpt_layers = 0 @@ -193,6 +200,7 @@ class CommandPipelineForwards: output_attentions, use_cache, cache_position, + position_embeddings ) else: layer_outputs = decoder_layer( @@ -203,6 +211,7 @@ class CommandPipelineForwards: output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings ) hidden_states = layer_outputs[0] @@ -224,17 +233,6 @@ class CommandPipelineForwards: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None if stage_manager.is_last_stage(): - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_cache, - all_hidden_states, - all_self_attns, - ] - if v is not None - ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -350,48 +348,44 @@ class CommandPipelineForwards: return {"hidden_states": hidden_states} + def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if sp_mode is not None: assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet" assert (sp_size is not None) and ( sp_group is not None ), "Must specify sp_size and sp_group for sequence parallel" - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) bsz, q_len, _ = hidden_states.size() + # sp: modify sp_len when sequence parallel mode is ring if sp_mode in ["split_gather", "ring"]: q_len *= sp_size - + query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - + # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: @@ -403,7 +397,8 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -413,11 +408,14 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - + attn_weights = None + if shard_config.enable_flash_attention: assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) else: + # to be fixed: + # precision issue attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): @@ -451,40 +449,36 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication ) else: - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value + + return attn_output, attn_weights return forward - def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, force_sp_output_gather: bool = True, - ) -> Union[Tuple, BaseModelOutputWithPast]: + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): @@ -527,7 +521,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode is_causal=True, ) else: - attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) if sp_mode in ["ring", "split_gather"]: inputs_embeds = split_forward_gather_backward( @@ -543,6 +537,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + + position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers: if output_hidden_states: @@ -557,6 +553,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode output_attentions, use_cache, cache_position, + position_embeddings ) else: @@ -568,16 +565,11 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) # Cases that don't support parallelizing cross entropy computation along sequence @@ -594,8 +586,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode next_cache = ( next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache ) - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index e6e741d34..5bbaa0b38 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -41,15 +41,13 @@ class CommandPolicy(Policy): from transformers.models.cohere.modeling_cohere import ( CohereAttention, CohereDecoderLayer, - CohereFlashAttention2, CohereModel, - CohereSdpaAttention, ) ATTN_IMPLEMENTATION = { "eager": CohereAttention, - "flash_attention_2": CohereFlashAttention2, - "sdpa": CohereSdpaAttention, + "flash_attention_2": CohereAttention, + "sdpa": CohereAttention, } policy = {} @@ -60,12 +58,7 @@ class CommandPolicy(Policy): else: if self.tie_weight: embedding_cls = PaddingEmbedding - - if self.shard_config.enable_fused_normalization: - norm_cls = FusedLayerNorm - else: - norm_cls = LayerNorm - + sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None @@ -280,29 +273,6 @@ class CommandPolicy(Policy): target_key=CohereModel, ) - # optimization configuration - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="input_layernorm", - target_module=norm_cls, - kwargs={"sp_partial_derived": sp_partial_derived}, - ), - ], - policy=policy, - target_key=CohereDecoderLayer, - ) - - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="norm", - target_module=norm_cls, - kwargs={"sp_partial_derived": sp_partial_derived}, - ), - policy=policy, - target_key=CohereModel, - ) - return policy def postprocess(self): @@ -349,6 +319,7 @@ class CommandPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] + held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.layers)) diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 9435ef84b..51595948e 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -213,17 +213,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": False, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, + { "tp_size": 1, "pp_size": 1, @@ -231,6 +221,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 2, "precision": "fp16", @@ -241,6 +232,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 2, "enable_all_optimization": True, + "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, @@ -252,6 +244,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 4, "use_lazy_init": False, + "enable_flash_attention": True, "precision": "fp32", "enable_gradient_checkpointing": True, "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), @@ -260,6 +253,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 1, "enable_all_optimization": True, + "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 2, "precision": "fp16", @@ -270,6 +264,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 2, "enable_all_optimization": True, + "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 1, "precision": "fp16", @@ -370,4 +365,4 @@ def test_command_3d(): if __name__ == "__main__": test_command() - test_command_3d() + test_command_3d() \ No newline at end of file