[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-05-08 08:13:33 +00:00
parent a9bb7cb943
commit 06724492ca
3 changed files with 21 additions and 32 deletions

View File

@@ -1,10 +1,10 @@
import math
-import warnings
-from typing import List, Optional, Tuple, Union, Callable
+from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
CohereForCausalLM,
@@ -13,13 +13,8 @@ from transformers.models.cohere.modeling_cohere import (
apply_rotary_pos_emb,
repeat_kv,
)
-from transformers.utils import logging
from transformers.processing_utils import Unpack
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.models.cohere.modeling_cohere import eager_attention_forward
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-from functools import partial
+from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
@@ -34,6 +29,7 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
logger = logging.get_logger(__name__)
class CommandPipelineForwards:
"""
This class serves as a micro library for forward function substitution of Command models
@@ -168,7 +164,7 @@ class CommandPipelineForwards:
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
position_embeddings = self.rotary_emb(hidden_states, position_ids)
start_idx, end_idx = stage_index[0], stage_index[1]
@@ -200,7 +196,7 @@ class CommandPipelineForwards:
output_attentions,
use_cache,
cache_position,
-position_embeddings
+position_embeddings,
)
else:
layer_outputs = decoder_layer(
@@ -211,7 +207,7 @@ class CommandPipelineForwards:
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
-position_embeddings=position_embeddings
+position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
@@ -348,7 +344,6 @@ class CommandPipelineForwards:
return {"hidden_states": hidden_states}
def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward(
self,
@@ -370,22 +365,22 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
# sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
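Note on the hunk above: the all-to-all exchange is the heart of the "all_to_all" sequence-parallel mode. Each rank projects only its local slice of the sequence, and all_to_all_comm then trades the hidden (head) dimension for the sequence dimension, which is why bsz and q_len are re-read right afterwards and the head count is inferred with -1. The following single-process sketch is illustrative only (plain PyTorch with made-up sizes, not ColossalAI's all_to_all_comm) and assumes sp_size divides both the sequence length and the hidden size.

import torch

# Simulate sp_size sequence-parallel ranks, each holding a projected tensor
# of shape [bsz, q_len // sp_size, hidden] before the all-to-all.
bsz, seq_per_rank, hidden, sp_size = 2, 8, 32, 4
per_rank = [torch.randn(bsz, seq_per_rank, hidden) for _ in range(sp_size)]

def simulated_all_to_all(tensors, sp_size):
    # After the exchange, rank dst holds the dst-th hidden chunk of every
    # source rank, concatenated along the sequence dimension:
    # [bsz, q_len, hidden // sp_size].
    outputs = []
    for dst in range(sp_size):
        chunks = [t.chunk(sp_size, dim=-1)[dst] for t in tensors]
        outputs.append(torch.cat(chunks, dim=1))
    return outputs

out = simulated_all_to_all(per_rank, sp_size)
print(out[0].shape)  # torch.Size([2, 32, 8]): full sequence, 1/sp_size of the hidden size

Each output can then be viewed as [bsz, q_len, num_heads // sp_size, head_dim] and transposed, matching the view(..., -1, self.head_dim).transpose(1, 2) calls in the hunk.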
@@ -409,7 +404,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = None
if shard_config.enable_flash_attention:
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
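For context on the repeat_kv calls above: they expand grouped-query KV heads so that the key/value tensors match the number of query heads before attention. A hedged sketch of that behaviour (mirroring the Hugging Face helper's semantics; the shapes are illustrative, not taken from this commit):

import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: [bsz, num_kv_heads, seq_len, head_dim]; duplicate each KV head n_rep
    # times to return [bsz, num_kv_heads * n_rep, seq_len, head_dim].
    bsz, num_kv_heads, seq_len, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return x.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

kv = torch.randn(1, 2, 16, 64)        # 2 KV heads
print(repeat_kv_sketch(kv, 4).shape)  # torch.Size([1, 8, 16, 64]): aligned with 8 query heads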
@@ -452,11 +447,12 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
return forward
def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
logger = logging.get_logger(__name__)
@@ -537,7 +533,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for decoder_layer in self.layers:
@@ -553,7 +549,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
output_attentions,
use_cache,
cache_position,
-position_embeddings
+position_embeddings,
)
else:
@@ -565,7 +561,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
-position_embeddings=position_embeddings
+position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]

View File

@@ -6,8 +6,6 @@ from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import (
-FusedLayerNorm,
-LayerNorm,
Linear1D_Col,
Linear1D_Row,
LinearWithGradAccum,
@@ -38,11 +36,7 @@ class CommandPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-from transformers.models.cohere.modeling_cohere import (
-CohereAttention,
-CohereDecoderLayer,
-CohereModel,
-)
+from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel
ATTN_IMPLEMENTATION = {
"eager": CohereAttention,
@@ -58,11 +52,11 @@ class CommandPolicy(Policy):
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
-sp_partial_derived = sp_mode in ["split_gather", "ring"]
+sp_mode in ["split_gather", "ring"]
if sp_mode == "ring_attn" and not self.is_causal:
raise ValueError("Ring attention is only meant for causal language modeling.")

View File

@@ -213,7 +213,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
@@ -365,4 +364,4 @@ def test_command_3d():
if __name__ == "__main__":
test_command()
test_command_3d()