From a9bb7cb94372f829a37d30c98ab809ee2e09508b Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 8 May 2025 16:06:05 +0800 Subject: [PATCH 1/6] upgrade command --- colossalai/shardformer/modeling/command.py | 104 ++++++++---------- colossalai/shardformer/policies/command.py | 37 +------ .../test_model/test_shard_command.py | 19 ++-- 3 files changed, 58 insertions(+), 102 deletions(-) diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index ea811acdf..be4efbd94 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -1,9 +1,8 @@ import math import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Callable import torch -import torch.utils.checkpoint from torch import nn from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -15,6 +14,12 @@ from transformers.models.cohere.modeling_cohere import ( repeat_kv, ) from transformers.utils import logging +from transformers.processing_utils import Unpack +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.models.cohere.modeling_cohere import eager_attention_forward +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + +from functools import partial from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward @@ -27,6 +32,7 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring"] _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] +logger = logging.get_logger(__name__) class CommandPipelineForwards: """ @@ -37,22 +43,23 @@ class CommandPipelineForwards: @staticmethod def command_model_forward( self: CohereModel, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, force_sp_output_gather: bool = True, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ): + logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -67,8 +74,6 @@ class CommandPipelineForwards: ) use_cache = False - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: @@ -133,7 +138,7 @@ class CommandPipelineForwards: is_causal=True, ) else: - attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) + attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values) if self.gradient_checkpointing and self.training and use_cache: if use_cache: @@ -163,6 +168,8 @@ class CommandPipelineForwards: all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + + position_embeddings = self.rotary_emb(hidden_states, position_ids) start_idx, end_idx = stage_index[0], stage_index[1] num_ckpt_layers = 0 @@ -193,6 +200,7 @@ class CommandPipelineForwards: output_attentions, use_cache, cache_position, + position_embeddings ) else: layer_outputs = decoder_layer( @@ -203,6 +211,7 @@ class CommandPipelineForwards: output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings ) hidden_states = layer_outputs[0] @@ -224,17 +233,6 @@ class CommandPipelineForwards: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None if stage_manager.is_last_stage(): - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_cache, - all_hidden_states, - all_self_attns, - ] - if v is not None - ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -350,48 +348,44 @@ class CommandPipelineForwards: return {"hidden_states": hidden_states} + def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if sp_mode is not None: assert sp_mode in _SUPPORTED_SP_MODE, f"SP mode {sp_mode} is not supported by {type(self)} yet" assert (sp_size is not None) and ( sp_group is not None ), "Must specify sp_size and sp_group for sequence parallel" - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) bsz, q_len, _ = hidden_states.size() + # sp: modify sp_len when sequence parallel mode is ring if sp_mode in ["split_gather", "ring"]: q_len *= sp_size - + query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - + # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: @@ -403,7 +397,8 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -413,11 +408,14 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - + attn_weights = None + if shard_config.enable_flash_attention: assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) else: + # to be fixed: + # precision issue attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): @@ -451,40 +449,36 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication ) else: - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value + + return attn_output, attn_weights return forward - def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, force_sp_output_gather: bool = True, - ) -> Union[Tuple, BaseModelOutputWithPast]: + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): @@ -527,7 +521,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode is_causal=True, ) else: - attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) if sp_mode in ["ring", "split_gather"]: inputs_embeds = split_forward_gather_backward( @@ -543,6 +537,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + + position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers: if output_hidden_states: @@ -557,6 +553,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode output_attentions, use_cache, cache_position, + position_embeddings ) else: @@ -568,16 +565,11 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) # Cases that don't support parallelizing cross entropy computation along sequence @@ -594,8 +586,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode next_cache = ( next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache ) - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index e6e741d34..5bbaa0b38 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -41,15 +41,13 @@ class CommandPolicy(Policy): from transformers.models.cohere.modeling_cohere import ( CohereAttention, CohereDecoderLayer, - CohereFlashAttention2, CohereModel, - CohereSdpaAttention, ) ATTN_IMPLEMENTATION = { "eager": CohereAttention, - "flash_attention_2": CohereFlashAttention2, - "sdpa": CohereSdpaAttention, + "flash_attention_2": CohereAttention, + "sdpa": CohereAttention, } policy = {} @@ -60,12 +58,7 @@ class CommandPolicy(Policy): else: if self.tie_weight: embedding_cls = PaddingEmbedding - - if self.shard_config.enable_fused_normalization: - norm_cls = FusedLayerNorm - else: - norm_cls = LayerNorm - + sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None @@ -280,29 +273,6 @@ class CommandPolicy(Policy): target_key=CohereModel, ) - # optimization configuration - self.append_or_create_submodule_replacement( - description=[ - SubModuleReplacementDescription( - suffix="input_layernorm", - target_module=norm_cls, - kwargs={"sp_partial_derived": sp_partial_derived}, - ), - ], - policy=policy, - target_key=CohereDecoderLayer, - ) - - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="norm", - target_module=norm_cls, - kwargs={"sp_partial_derived": sp_partial_derived}, - ), - policy=policy, - target_key=CohereModel, - ) - return policy def postprocess(self): @@ -349,6 +319,7 @@ class CommandPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] + held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.layers)) diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 9435ef84b..51595948e 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -213,17 +213,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - { - "tp_size": 4, - "pp_size": 1, - "num_microbatches": 1, - "enable_sequence_parallelism": True, - "sequence_parallelism_mode": "split_gather", - "enable_flash_attention": False, - "use_lazy_init": True, - "precision": "fp16", - "initial_scale": 1, - }, + { "tp_size": 1, "pp_size": 1, @@ -231,6 +221,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 2, "precision": "fp16", @@ -241,6 +232,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 2, "enable_all_optimization": True, + "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, @@ -252,6 +244,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 4, "use_lazy_init": False, + "enable_flash_attention": True, "precision": "fp32", "enable_gradient_checkpointing": True, "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), @@ -260,6 +253,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 1, "enable_all_optimization": True, + "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 2, "precision": "fp16", @@ -270,6 +264,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 2, "enable_all_optimization": True, + "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 1, "precision": "fp16", @@ -370,4 +365,4 @@ def test_command_3d(): if __name__ == "__main__": test_command() - test_command_3d() + test_command_3d() \ No newline at end of file From 06724492ca1b73c1935d31fc30b06ea8ef62aa19 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 May 2025 08:13:33 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/modeling/command.py | 38 +++++++++---------- colossalai/shardformer/policies/command.py | 12 ++---- .../test_model/test_shard_command.py | 3 +- 3 files changed, 21 insertions(+), 32 deletions(-) diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index be4efbd94..43f45ed3d 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -1,10 +1,10 @@ import math -import warnings -from typing import List, Optional, Tuple, Union, Callable +from typing import List, Optional, Tuple, Union import torch from torch import nn from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.cohere.modeling_cohere import ( CohereForCausalLM, @@ -13,13 +13,8 @@ from transformers.models.cohere.modeling_cohere import ( apply_rotary_pos_emb, repeat_kv, ) -from transformers.utils import logging from transformers.processing_utils import Unpack -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.models.cohere.modeling_cohere import eager_attention_forward -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - -from functools import partial +from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward @@ -34,6 +29,7 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] logger = logging.get_logger(__name__) + class CommandPipelineForwards: """ This class serves as a micro library for forward function substitution of Command models @@ -168,7 +164,7 @@ class CommandPipelineForwards: all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None - + position_embeddings = self.rotary_emb(hidden_states, position_ids) start_idx, end_idx = stage_index[0], stage_index[1] @@ -200,7 +196,7 @@ class CommandPipelineForwards: output_attentions, use_cache, cache_position, - position_embeddings + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -211,7 +207,7 @@ class CommandPipelineForwards: output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] @@ -348,7 +344,6 @@ class CommandPipelineForwards: return {"hidden_states": hidden_states} - def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( self, @@ -370,22 +365,22 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, # sp: modify sp_len when sequence parallel mode is ring if sp_mode in ["split_gather", "ring"]: q_len *= sp_size - + query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - + # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() - + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - + kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: @@ -409,7 +404,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = None - + if shard_config.enable_flash_attention: assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) @@ -452,11 +447,12 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - + return attn_output, attn_weights return forward + def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) @@ -537,7 +533,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None - + position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers: @@ -553,7 +549,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode output_attentions, use_cache, cache_position, - position_embeddings + position_embeddings, ) else: @@ -565,7 +561,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - position_embeddings=position_embeddings + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 5bbaa0b38..7044f3be7 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -6,8 +6,6 @@ from torch import Tensor from torch.nn import Module from colossalai.shardformer.layer import ( - FusedLayerNorm, - LayerNorm, Linear1D_Col, Linear1D_Row, LinearWithGradAccum, @@ -38,11 +36,7 @@ class CommandPolicy(Policy): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.cohere.modeling_cohere import ( - CohereAttention, - CohereDecoderLayer, - CohereModel, - ) + from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel ATTN_IMPLEMENTATION = { "eager": CohereAttention, @@ -58,11 +52,11 @@ class CommandPolicy(Policy): else: if self.tie_weight: embedding_cls = PaddingEmbedding - + sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None - sp_partial_derived = sp_mode in ["split_gather", "ring"] + sp_mode in ["split_gather", "ring"] if sp_mode == "ring_attn" and not self.is_causal: raise ValueError("Ring attention is only meant for causal language modeling.") diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 51595948e..384943675 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -213,7 +213,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - { "tp_size": 1, "pp_size": 1, @@ -365,4 +364,4 @@ def test_command_3d(): if __name__ == "__main__": test_command() - test_command_3d() \ No newline at end of file + test_command_3d() From e78c4560c6a5d14c29a09ca022f4a344b827d939 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 8 May 2025 16:22:08 +0800 Subject: [PATCH 3/6] fix --- colossalai/shardformer/policies/command.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 7044f3be7..3cba8ded0 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -38,6 +38,7 @@ class CommandPolicy(Policy): def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel + # The eager, flash_attention_2, sdpa will all be passed to CohereAttention in v4.51.3 transformers. ATTN_IMPLEMENTATION = { "eager": CohereAttention, "flash_attention_2": CohereAttention, @@ -53,10 +54,11 @@ class CommandPolicy(Policy): if self.tie_weight: embedding_cls = PaddingEmbedding + # CohereLayerNorm has no bias in v4.51.3 transformers, so we don't replace it. + sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None - sp_mode in ["split_gather", "ring"] if sp_mode == "ring_attn" and not self.is_causal: raise ValueError("Ring attention is only meant for causal language modeling.") From ba9fb549d599ab81437f3fce85eee8c0c6146e8f Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 15 May 2025 17:47:21 +0800 Subject: [PATCH 4/6] fix --- colossalai/shardformer/modeling/command.py | 47 ++++------------------ 1 file changed, 8 insertions(+), 39 deletions(-) diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 43f45ed3d..6c2dbb13a 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -1,4 +1,3 @@ -import math from typing import List, Optional, Tuple, Union import torch @@ -7,6 +6,7 @@ from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.cohere.modeling_cohere import ( + CohereAttention, CohereForCausalLM, CohereModel, StaticCache, @@ -346,7 +346,7 @@ class CommandPipelineForwards: def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( - self, + self: CohereAttention, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, @@ -381,25 +381,10 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -409,32 +394,16 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) else: - # to be fixed: - # precision issue - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + dropout = (0.0 if not self.training else self.attention_dropout,) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous() # sp: all-to-all comminucation when introducing sequence parallel From 10bc6af2b1419050d8ffb4a086c011f00c7b1cd1 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 15 May 2025 17:55:24 +0800 Subject: [PATCH 5/6] fix --- colossalai/shardformer/modeling/command.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 6c2dbb13a..fe494b996 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -134,6 +134,7 @@ class CommandPipelineForwards: is_causal=True, ) else: + # v4.51.3 transformers attention_mask calculation attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_values) if self.gradient_checkpointing and self.training and use_cache: @@ -164,7 +165,7 @@ class CommandPipelineForwards: all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None - + # v4.51.3 transformers position_embeddings calculation position_embeddings = self.rotary_emb(hidden_states, position_ids) start_idx, end_idx = stage_index[0], stage_index[1] @@ -394,6 +395,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) else: + # attn_weights and attn_output calculation is modified on the v4.51.3 of transformers.models.cohere.modeling_cohere.CohereAttention.forward. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] @@ -486,6 +488,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode is_causal=True, ) else: + # v4.51.3 transformers attention_mask calculation attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) if sp_mode in ["ring", "split_gather"]: @@ -503,6 +506,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode all_self_attns = () if output_attentions else None next_decoder_cache = None + # v4.51.3 transformers position_embeddings calculation position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers: From ced6b5e1c317cbf26914b1551145c19776e0589b Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 16 May 2025 11:39:50 +0800 Subject: [PATCH 6/6] fix --- colossalai/shardformer/modeling/command.py | 8 ++++++-- colossalai/shardformer/policies/command.py | 15 ++++++++------- .../test_model/test_shard_command.py | 16 +++++++++++----- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index fe494b996..aa11b7043 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -123,6 +123,7 @@ class CommandPipelineForwards: # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage + shard_config.enable_flash_attention = True if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length, seq_length_with_past) @@ -391,6 +392,8 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = None + shard_config.enable_flash_attention = True + if shard_config.enable_flash_attention: assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) @@ -402,7 +405,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - dropout = (0.0 if not self.training else self.attention_dropout,) + dropout = 0.0 if not self.training else self.attention_dropout attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -477,6 +480,8 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode if position_ids is None: position_ids = cache_position.unsqueeze(0) + shard_config.enable_flash_attention = True + # in this case, attention_mask is a dict rather than a tensor if shard_config.enable_flash_attention: mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) @@ -524,7 +529,6 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode cache_position, position_embeddings, ) - else: layer_outputs = decoder_layer( hidden_states, diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 3cba8ded0..8a510307e 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -75,14 +75,15 @@ class CommandPolicy(Policy): policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) + + self.append_or_create_method_replacement( + description={ + "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=attn_cls, + ) if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: - self.append_or_create_method_replacement( - description={ - "forward": get_command_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), - }, - policy=policy, - target_key=attn_cls, - ) if self.pipeline_stage_manager is None: self.append_or_create_method_replacement( description={ diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 384943675..4d156d84d 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -213,6 +213,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 1, "pp_size": 1, @@ -220,7 +231,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", - "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 2, "precision": "fp16", @@ -231,7 +241,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 2, "enable_all_optimization": True, - "enable_flash_attention": True, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, @@ -243,7 +252,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 4, "use_lazy_init": False, - "enable_flash_attention": True, "precision": "fp32", "enable_gradient_checkpointing": True, "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), @@ -252,7 +260,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 1, "enable_all_optimization": True, - "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 2, "precision": "fp16", @@ -263,7 +270,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "pp_size": 2, "num_microbatches": 2, "enable_all_optimization": True, - "enable_flash_attention": True, "use_lazy_init": True, "zero_stage": 1, "precision": "fp16",