[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-05-08 08:13:33 +00:00
parent a9bb7cb943
commit 06724492ca
3 changed files with 21 additions and 32 deletions

View File

@@ -1,10 +1,10 @@
 import math
-import warnings
-from typing import List, Optional, Tuple, Union, Callable
+from typing import List, Optional, Tuple, Union

 import torch
 from torch import nn
 from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.cohere.modeling_cohere import (
     CohereForCausalLM,
@@ -13,13 +13,8 @@ from transformers.models.cohere.modeling_cohere import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from transformers.utils import logging
 from transformers.processing_utils import Unpack
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.models.cohere.modeling_cohere import eager_attention_forward
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-from functools import partial
+from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
@@ -34,6 +29,7 @@ _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
 logger = logging.get_logger(__name__)

+
 class CommandPipelineForwards:
     """
     This class serves as a micro library for forward function substitution of Command models
@@ -200,7 +196,7 @@ class CommandPipelineForwards:
                     output_attentions,
                     use_cache,
                     cache_position,
-                    position_embeddings
+                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -211,7 +207,7 @@ class CommandPipelineForwards:
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    position_embeddings=position_embeddings
+                    position_embeddings=position_embeddings,
                 )

             hidden_states = layer_outputs[0]
@@ -348,7 +344,6 @@ class CommandPipelineForwards:
         return {"hidden_states": hidden_states}


-
 def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
         self,
@@ -457,6 +452,7 @@ def get_command_flash_attention_forward(shard_config: ShardConfig, sp_mode=None,
     return forward

+
 def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)
@@ -553,7 +549,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                     output_attentions,
                     use_cache,
                     cache_position,
-                    position_embeddings
+                    position_embeddings,
                 )
             else:
@@ -565,7 +561,7 @@ def get_command_flash_attention_model_forward(shard_config: ShardConfig, sp_mode
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    position_embeddings=position_embeddings
+                    position_embeddings=position_embeddings,
                 )

             hidden_states = layer_outputs[0]
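
For reference, the import block of the first file reads as follows after these fixes. This is a sketch reconstructed only from the hunks shown above; import lines outside the displayed context are assumed unchanged.

import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
    CohereForCausalLM,
    # (entries between CohereForCausalLM and apply_rotary_pos_emb fall outside the displayed hunks)
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.processing_utils import Unpack
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward

Dropping the unused imports (warnings, Callable, eager_attention_forward, ALL_ATTENTION_FUNCTIONS, partial) and re-sorting the remainder is the kind of change that import-sorting and unused-import hooks typically produce.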

View File

@@ -6,8 +6,6 @@ from torch import Tensor
 from torch.nn import Module

 from colossalai.shardformer.layer import (
-    FusedLayerNorm,
-    LayerNorm,
     Linear1D_Col,
     Linear1D_Row,
     LinearWithGradAccum,
@@ -38,11 +36,7 @@ class CommandPolicy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.cohere.modeling_cohere import (
-            CohereAttention,
-            CohereDecoderLayer,
-            CohereModel,
-        )
+        from transformers.models.cohere.modeling_cohere import CohereAttention, CohereDecoderLayer, CohereModel

         ATTN_IMPLEMENTATION = {
             "eager": CohereAttention,
@@ -62,7 +56,7 @@ class CommandPolicy(Policy):
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
-        sp_partial_derived = sp_mode in ["split_gather", "ring"]
+        sp_mode in ["split_gather", "ring"]

         if sp_mode == "ring_attn" and not self.is_causal:
             raise ValueError("Ring attention is only meant for causal language modeling.")

View File

@@ -213,7 +213,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-
         {
             "tp_size": 1,
             "pp_size": 1,