Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-01 07:46:55 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent 0dede489d6
commit 89917e247b
@@ -1,3 +1,4 @@
+import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -21,7 +22,6 @@ from transformers.models.falcon.modeling_falcon import (
     build_alibi_tensor,
 )
 from transformers.utils import logging
-import warnings
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
@@ -134,12 +134,12 @@ def get_tp_falcon_decoder_layer_forward():
             attention_mask=attention_mask,
             position_ids=position_ids,
             alibi=alibi,
-            head_mask=head_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            cache_position=cache_position,
-            position_embeddings=position_embeddings,
-        )
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+        )
 
         attention_output = attn_outputs[0]
 
@@ -294,35 +294,35 @@ class FalconPipelineForwards:
 
             if self.gradient_checkpointing and self.training:
                 outputs = self._gradient_checkpointing_func(
-                    block.__call__,
-                    hidden_states,
-                    alibi,
-                    causal_mask,
-                    position_ids,
-                    head_mask[i],
-                    past_key_values,
-                    use_cache,
-                    output_attentions,
-                    cache_position,
-                    position_embeddings,
-                )
+                    block.__call__,
+                    hidden_states,
+                    alibi,
+                    causal_mask,
+                    position_ids,
+                    head_mask[i],
+                    past_key_values,
+                    use_cache,
+                    output_attentions,
+                    cache_position,
+                    position_embeddings,
+                )
             else:
                 outputs = block(
-                    hidden_states,
-                    layer_past=past_key_values,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    alibi=alibi,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                )
+                    hidden_states,
+                    layer_past=past_key_values,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    head_mask=head_mask[i],
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    alibi=alibi,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                )
 
             hidden_states = outputs[0]
             if use_cache is True:
-                next_decoder_cache = outputs[1]
+                outputs[1]
 
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
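For context on the hunk above: the re-indented branch follows the usual PyTorch activation-checkpointing pattern, where a training run with gradient checkpointing enabled invokes the block through a checkpoint function with positional arguments, and every other case calls the block directly with keyword arguments. Below is a minimal, hedged sketch of that pattern; it assumes only plain torch, and TinyBlock plus its argument values are illustrative stand-ins, not ColossalAI's actual FalconPipelineForwards code.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    """Stand-in for a transformer decoder block (illustrative only)."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None, use_cache=False):
        out = torch.relu(self.proj(hidden_states))
        # Mimic the tuple layout indexed as outputs[0] / outputs[1] in the diff.
        return (out, out.detach() if use_cache else None)


block = TinyBlock()
hidden_states = torch.randn(2, 4, 16, requires_grad=True)
gradient_checkpointing, training, use_cache = True, True, False

if gradient_checkpointing and training:
    # Checkpointing recomputes the block during backward instead of storing
    # activations; arguments are passed positionally, as in the first branch above.
    outputs = checkpoint(block.__call__, hidden_states, None, use_cache, use_reentrant=False)
else:
    # Direct call with keyword arguments, as in the else branch above.
    outputs = block(hidden_states, attention_mask=None, use_cache=use_cache)

hidden_states = outputs[0]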