[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-05-14 04:24:23 +00:00
parent 0dede489d6
commit 89917e247b

@@ -1,3 +1,4 @@
+import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -21,7 +22,6 @@ from transformers.models.falcon.modeling_falcon import (
     build_alibi_tensor,
 )
 from transformers.utils import logging
-import warnings
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
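Taken together, the two import hunks above simply hoist import warnings out of the middle of the file and into the top-level import block. A sketch of the resulting header, assembled only from the lines visible in this diff (imports that the hunks elide are marked, not filled in):

import warnings
from typing import List, Optional, Tuple, Union

import torch

from transformers.models.falcon.modeling_falcon import (
    # ... other names elided in the hunk ...
    build_alibi_tensor,
)
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig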
@@ -134,12 +134,12 @@ def get_tp_falcon_decoder_layer_forward():
             attention_mask=attention_mask,
             position_ids=position_ids,
             alibi=alibi,
-            head_mask=head_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            cache_position=cache_position,
-            position_embeddings=position_embeddings,
-        )
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+        )
         attention_output = attn_outputs[0]
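The hunk above appears to be a whitespace-only re-indent of the attention call's keyword arguments inside the tensor-parallel decoder-layer forward; the arguments themselves are unchanged. For context, the alibi value threaded through this call is the ALiBi bias produced by the build_alibi_tensor helper imported above; a minimal, hypothetical usage sketch (shapes are illustrative only, not taken from this file):

import torch
from transformers.models.falcon.modeling_falcon import build_alibi_tensor

# Hypothetical shapes, for illustration only: batch of 2, sequence length 16.
attention_mask = torch.ones(2, 16, dtype=torch.long)
alibi = build_alibi_tensor(attention_mask, num_heads=8, dtype=torch.float32)
print(alibi.shape)  # torch.Size([16, 1, 16]) -> (batch_size * num_heads, 1, seq_len)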
@@ -294,35 +294,35 @@ class FalconPipelineForwards:
             if self.gradient_checkpointing and self.training:
                 outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    alibi,
-                    causal_mask,
-                    position_ids,
-                    head_mask[i],
-                    past_key_values,
-                    use_cache,
-                    output_attentions,
-                    cache_position,
-                    position_embeddings,
-                )
+                    block.__call__,
+                    hidden_states,
+                    alibi,
+                    causal_mask,
+                    position_ids,
+                    head_mask[i],
+                    past_key_values,
+                    use_cache,
+                    output_attentions,
+                    cache_position,
+                    position_embeddings,
+                )
             else:
                 outputs = block(
-                    hidden_states,
-                    layer_past=past_key_values,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    alibi=alibi,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                )
+                    hidden_states,
+                    layer_past=past_key_values,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    head_mask=head_mask[i],
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    alibi=alibi,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                )
             hidden_states = outputs[0]
             if use_cache is True:
-                next_decoder_cache = outputs[1]
+                outputs[1]
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
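The final hunk is likewise a formatting pass over the two call branches, plus the removal of the next_decoder_cache binding (presumably flagged as unused by a hook). The surrounding code follows the usual transformers pattern: with gradient checkpointing enabled, the layer is invoked through self._gradient_checkpointing_func (typically a partial of torch.utils.checkpoint.checkpoint) with positional arguments; otherwise it is called directly with keyword arguments, and outputs[0] becomes the next hidden_states. A self-contained toy sketch of that pattern, with ToyBlock as a hypothetical stand-in for the real FalconDecoderLayer:

import torch
from torch.utils.checkpoint import checkpoint


class ToyBlock(torch.nn.Module):
    """Stand-in for one decoder block; returns a tuple like FalconDecoderLayer."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states, use_cache=False):
        out = self.proj(hidden_states)
        return (out, "toy-cache") if use_cache else (out,)


block = ToyBlock()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)
use_cache = True

# Checkpointed path: callable plus positional arguments, mirroring
# outputs = self._gradient_checkpointing_func(block.__call__, hidden_states, ...).
outputs = checkpoint(block.__call__, hidden_states, use_cache, use_reentrant=False)

# Direct path: keyword arguments, mirroring outputs = block(hidden_states, use_cache=...).
# outputs = block(hidden_states, use_cache=use_cache)

hidden_states = outputs[0]       # becomes the input of the next block
next_decoder_cache = outputs[1]  # present key/values when use_cache is True
hidden_states.sum().backward()   # the block's forward is re-run here instead of storing activations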