[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-05-08 09:58:04 +00:00
parent 5480b811c5
commit 4eced5cf8a


@@ -1,17 +1,9 @@
-import math
-import warnings
 from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_attn_mask_utils import (
-    AttentionMaskConverter,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
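Note: this first hunk is the standard unused-import cleanup. `math`, `warnings`, the `transformers.modeling_attn_mask_utils` helpers, and the `cache_utils` imports are dropped, presumably because nothing in the module references them anymore. A minimal sketch of how this class of check works (an illustration using the standard library, not the actual pre-commit hook):

```python
# Illustration only: flag top-level imported names that never appear in the
# module body -- the class of fix pre-commit applied to this file.
import ast

source = "import math\nimport warnings\nx = 1 + 2\n"
tree = ast.parse(source)
imported = {
    alias.asname or alias.name.split(".")[0]
    for node in ast.walk(tree)
    if isinstance(node, (ast.Import, ast.ImportFrom))
    for alias in node.names
}
used = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
print(sorted(imported - used))  # ['math', 'warnings'] -> candidates for removal
```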
@@ -159,7 +151,7 @@ def get_tp_falcon_decoder_layer_forward():
             and self.config.num_ln_in_parallel_attn == 1
         ):
             mlp_layernorm_out = attention_layernorm_out
 
         outputs = attn_outputs[1:]
 
         # MLP.
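For context on the code around this hunk: when Falcon's decoder runs attention and the MLP in parallel with a single layernorm (`num_ln_in_parallel_attn == 1`), the MLP reuses the attention branch's normalized input instead of re-normalizing. A toy sketch of that sharing (shapes and the config value are assumptions for illustration):

```python
# Toy sketch: with one layernorm in the parallel attention/MLP block, the MLP
# consumes the same normalized tensor as attention. Shapes are made up.
import torch
import torch.nn as nn

hidden_states = torch.randn(2, 4, 16)  # (batch, seq, hidden), assumed shape
input_layernorm = nn.LayerNorm(16)

attention_layernorm_out = input_layernorm(hidden_states)
num_ln_in_parallel_attn = 1            # assumed config value, as in the diff
if num_ln_in_parallel_attn == 1:
    mlp_layernorm_out = attention_layernorm_out  # share, don't re-normalize
print(mlp_layernorm_out is attention_layernorm_out)  # True
```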
@@ -215,7 +207,6 @@ class FalconPipelineForwards:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
-
             logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
             past_key_values = None
@@ -251,7 +242,7 @@ class FalconPipelineForwards:
         # Compute alibi tensor: check build_alibi_tensor documentation
         alibi = None
         past_key_values_length = 0
 
         batch_size, seq_length, _ = hidden_states.shape
         if self.use_alibi:
             mask = (
@@ -262,7 +253,7 @@ class FalconPipelineForwards:
                 else attention_mask
             )
             alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
 
         if cache_position is None:
             cache_position = torch.arange(
                 past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
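The `cache_position` fallback here enumerates the absolute positions of the current tokens, offset by the number of tokens already in the KV cache. A minimal sketch (the two lengths are assumed values):

```python
# Minimal sketch: cache_position holds the absolute indices of the current
# tokens, offset by how many tokens the KV cache already contains.
import torch

past_key_values_length = 4   # assumed: tokens already cached
seq_length = 3               # assumed: tokens in this forward pass
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length)
print(cache_position)  # tensor([4, 5, 6])
```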
@@ -280,7 +271,7 @@ class FalconPipelineForwards:
         # attention_probs has shape batch_size x num_heads x N x N
         # head_mask has shape n_layer x batch x num_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
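The `position_embeddings` built here are the rotary cos/sin tables, computed once and shared by every decoder layer. A toy sketch of how such tables are typically derived (dimensions and base are made up; this is not Falcon's exact module):

```python
# Toy sketch of rotary cos/sin tables computed once and shared across layers.
# head_dim, seq_len, and base are made-up values, not Falcon's config.
import torch

head_dim, seq_len, base = 8, 6, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
cos = torch.cat([angles.cos(), angles.cos()], dim=-1)          # (seq_len, head_dim)
sin = torch.cat([angles.sin(), angles.sin()], dim=-1)
print(cos.shape, sin.shape)  # torch.Size([6, 8]) torch.Size([6, 8])
```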
@@ -319,7 +310,7 @@ class FalconPipelineForwards:
             hidden_states = outputs[0]
             if use_cache is True:
-                next_decoder_cache = outputs[1]
+                outputs[1]
 
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
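This hunk carries the one substantive edit in the file: the unused binding `next_decoder_cache` is removed, but the hook keeps the bare subscript, so `outputs[1]` is still evaluated and then discarded. A small illustration of what that leaves behind:

```python
# Before: next_decoder_cache = outputs[1]   (binding never read -> flagged)
# After:  outputs[1]                        (expression statement, result discarded)
outputs = ("hidden", "cache")
outputs[1]       # still evaluated (an IndexError would surface), otherwise a no-op
_ = outputs[1]   # a common way to make the intentional discard explicit
```

Whether the leftover expression statement should be deleted outright is worth a human follow-up, since it now does nothing useful.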
@@ -332,7 +323,7 @@ class FalconPipelineForwards:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
         if stage_manager.is_last_stage():
             if not return_dict:
                 return tuple(
                     v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
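The tuple comprehension at the last stage implements the `return_dict=False` contract: positional outputs with the unpopulated (`None`) entries filtered out. A minimal sketch with dummy values:

```python
# Minimal sketch of the return_dict=False path: keep positional outputs,
# drop the ones that were never populated. Values are dummies.
hidden_states, presents = "h", None
all_hidden_states, all_self_attentions = ("h0", "h1"), None
out = tuple(
    v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
)
print(out)  # ('h', ('h0', 'h1'))
```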