diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py
index 870c9272b..7a8aec37d 100644
--- a/colossalai/shardformer/modeling/falcon.py
+++ b/colossalai/shardformer/modeling/falcon.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -21,7 +22,6 @@ from transformers.models.falcon.modeling_falcon import (
     build_alibi_tensor,
 )
 from transformers.utils import logging
-import warnings
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
@@ -134,12 +134,12 @@ def get_tp_falcon_decoder_layer_forward():
             attention_mask=attention_mask,
             position_ids=position_ids,
             alibi=alibi,
-                head_mask=head_mask,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                cache_position=cache_position,
-                position_embeddings=position_embeddings,
-            )
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+        )
 
         attention_output = attn_outputs[0]
 
@@ -294,35 +294,35 @@ class FalconPipelineForwards:
 
             if self.gradient_checkpointing and self.training:
                 outputs = self._gradient_checkpointing_func(
-                        block.__call__,
-                        hidden_states,
-                        alibi,
-                        causal_mask,
-                        position_ids,
-                        head_mask[i],
-                        past_key_values,
-                        use_cache,
-                        output_attentions,
-                        cache_position,
-                        position_embeddings,
-                    )
+                    block.__call__,
+                    hidden_states,
+                    alibi,
+                    causal_mask,
+                    position_ids,
+                    head_mask[i],
+                    past_key_values,
+                    use_cache,
+                    output_attentions,
+                    cache_position,
+                    position_embeddings,
+                )
             else:
                 outputs = block(
-                        hidden_states,
-                        layer_past=past_key_values,
-                        attention_mask=causal_mask,
-                        position_ids=position_ids,
-                        head_mask=head_mask[i],
-                        use_cache=use_cache,
-                        output_attentions=output_attentions,
-                        alibi=alibi,
-                        cache_position=cache_position,
-                        position_embeddings=position_embeddings,
-                    )
+                    hidden_states,
+                    layer_past=past_key_values,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    head_mask=head_mask[i],
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    alibi=alibi,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                )
 
             hidden_states = outputs[0]
             if use_cache is True:
-                next_decoder_cache = outputs[1]
+                outputs[1]
 
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
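
Note on the last hunk: the only semantic change in this diff is dropping the `next_decoder_cache = outputs[1]` binding (the bare `outputs[1]` expression that remains is a no-op). A plausible reading, given that the surrounding code already passes `cache_position` and `position_embeddings` from the newer transformers API, is that `past_key_values` is now a mutable `Cache` object that each block updates in place, so the returned cache no longer needs to be captured locally. Below is a minimal sketch of that in-place behavior, not ColossalAI code, using only the public `DynamicCache.update` API; tensor shapes are illustrative:

    # Sketch (assumption: transformers >= 4.36 Cache API) of why the local
    # `next_decoder_cache` rebinding became redundant.
    import torch
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()
    # Illustrative shapes: (batch, num_heads, seq_len, head_dim).
    key = torch.randn(1, 4, 8, 32)
    value = torch.randn(1, 4, 8, 32)

    # Attention layers call `update`, which appends the new key/value states
    # to the cache in place and returns the accumulated states for the layer.
    k, v = cache.update(key, value, layer_idx=0)

    # The object the caller passed in already reflects the update, so the
    # decoder loop does not need to rebind a `next_decoder_cache` variable.
    assert cache.get_seq_length(0) == 8
    assert k.shape[-2] == 8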