From 06724492ca1b73c1935d31fc30b06ea8ef62aa19 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 8 May 2025 08:13:33 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/shardformer/modeling/command.py | 38 +++++++++----------
 colossalai/shardformer/policies/command.py | 12 ++----
 .../test_model/test_shard_command.py        |  3 +-
 3 files changed, 21 insertions(+), 32 deletions(-)

diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index be4efbd94..43f45ed3d 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -1,10 +1,10 @@
 import math
-import warnings
-from typing import List, Optional, Tuple, Union, Callable
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.cohere.modeling_cohere import (
     CohereForCausalLM,
@@ -13,13 +13,8 @@ from transformers.models.cohere.modeling_cohere import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from transformers.utils import logging
 from transformers.processing_utils import Unpack
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.models.cohere.modeling_cohere import eager_attention_forward
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-
-from functools import partial
+from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
@@ -34,6 +29,7 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
 
 logger = logging.get_logger(__name__)
 
+
 class CommandPipelineForwards:
     """
     This class serves as a micro library for forward function substitution of Command models
@@ -168,7 +164,7 @@ class CommandPipelineForwards:
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
-
+
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         start_idx, end_idx = stage_index[0], stage_index[1]
@@ -200,7 +196,7 @@ class CommandPipelineForwards:
                     output_attentions,
                     use_cache,
                     cache_position,
-                    position_embeddings
+                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -211,7 +207,7 @@ class CommandPipelineForwards:
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    position_embeddings=position_embeddings
+                    position_embeddings=position_embeddings,
                 )
             hidden_states = layer_outputs[0]
 
@@ -348,7 +344,6 @@ class CommandPipelineForwards:
 
         return {"hidden_states": hidden_states}
 
-
 def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
         self,
@@ -370,22 +365,22 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
         # sp: modify sp_len when sequence parallel mode is ring
         if sp_mode in ["split_gather", "ring"]:
             q_len *= sp_size
-
+
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
-
+
         # sp: all-to-all communication when introducing sequence parallel
         if sp_mode == "all_to_all":
             query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
             key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
             value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
-
+
         query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
-
+
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
@@ -409,7 +404,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_weights = None
-
+
         if shard_config.enable_flash_attention:
             assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
             attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
@@ -452,11 +447,12 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
 
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)
-
+
         return attn_output, attn_weights
 
     return forward
 
+
 def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)
 
@@ -537,7 +533,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
-
+
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         for decoder_layer in self.layers:
@@ -553,7 +549,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                     output_attentions,
                     use_cache,
                     cache_position,
-                    position_embeddings
+                    position_embeddings,
                 )
 
             else:
@@ -565,7 +561,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    position_embeddings=position_embeddings
+                    position_embeddings=position_embeddings,
                 )
             hidden_states = layer_outputs[0]
 
diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py
index 5bbaa0b38..7044f3be7 100644
--- a/colossalai/shardformer/policies/command.py
+++ b/colossalai/shardformer/policies/command.py
@@ -6,8 +6,6 @@ from torch import Tensor
 from torch.nn import Module
 
 from colossalai.shardformer.layer import (
-    FusedLayerNorm,
-    LayerNorm,
     Linear1D_Col,
     Linear1D_Row,
     LinearWithGradAccum,
@@ -38,11 +36,7 @@ class CommandPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.cohere.modeling_cohere import (
-            CohereAttention,
-            CohereDecoderLayer,
-            CohereModel,
-        )
+        from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel
 
         ATTN_IMPLEMENTATION = {
             "eager": CohereAttention,
@@ -58,11 +52,11 @@ class CommandPolicy(Policy):
         else:
             if self.tie_weight:
                 embedding_cls = PaddingEmbedding
-
+
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
-        sp_partial_derived = sp_mode in ["split_gather", "ring"]
+        sp_mode in ["split_gather", "ring"]
 
         if sp_mode == "ring_attn" and not self.is_causal:
             raise ValueError("Ring attention is only meant for causal language modeling.")
diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py
index 51595948e..384943675 100644
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -213,7 +213,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-
         {
             "tp_size": 1,
             "pp_size": 1,
@@ -365,4 +364,4 @@ def test_command_3d():
 
 if __name__ == "__main__":
     test_command()
-    test_command_3d()
\ No newline at end of file
+    test_command_3d()
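
Reviewer note on the all_to_all branch patched above: each rank enters the
attention forward with a sequence shard of shape [bsz, q_len/sp_size, hidden];
after all_to_all_comm it holds the full sequence but only hidden/sp_size
channels (i.e. num_heads/sp_size heads), which is why the hunk re-reads
bsz, q_len from query_states.size() after the exchange. Below is a minimal
single-process sketch of that shape bookkeeping, assuming all_to_all_comm's
default scatter_dim=2 / gather_dim=1; fake_all_to_all is a hypothetical
stand-in that reproduces only the shapes, not the real cross-rank exchange.

import torch


def fake_all_to_all(x, sp_size, scatter_dim=2, gather_dim=1):
    # Hypothetical single-process stand-in: split into sp_size chunks along
    # scatter_dim and concatenate along gather_dim. Real ranks would swap
    # chunks with their peers; the resulting shapes are identical.
    chunks = torch.chunk(x, sp_size, dim=scatter_dim)
    return torch.cat(chunks, dim=gather_dim)


bsz, shard_len, num_heads, head_dim, sp_size = 2, 8, 8, 16, 4
hidden = num_heads * head_dim

# One rank's projected states: a shard of the sequence, all hidden channels.
q_local = torch.randn(bsz, shard_len, hidden)

# After the exchange: the full sequence, 1/sp_size of the hidden channels.
q_full = fake_all_to_all(q_local, sp_size)
assert q_full.shape == (bsz, shard_len * sp_size, hidden // sp_size)

# The view/transpose from the patch then yields [bsz, heads/sp, seq, head_dim].
q = q_full.view(bsz, shard_len * sp_size, -1, head_dim).transpose(1, 2)
assert q.shape == (bsz, num_heads // sp_size, shard_len * sp_size, head_dim)

The sketch calls nothing in ColossalAI; it only documents the tensor shapes
the patched forward assumes before apply_rotary_pos_emb and repeat_kv run.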