diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py
index c2802063f..27461be04 100644
--- a/colossalai/shardformer/modeling/falcon.py
+++ b/colossalai/shardformer/modeling/falcon.py
@@ -1,17 +1,9 @@
-import math
-import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_attn_mask_utils import (
-    AttentionMaskConverter,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -159,7 +151,7 @@ def get_tp_falcon_decoder_layer_forward():
             and self.config.num_ln_in_parallel_attn == 1
         ):
             mlp_layernorm_out = attention_layernorm_out
-
+
         outputs = attn_outputs[1:]
 
         # MLP.
@@ -215,7 +207,6 @@ class FalconPipelineForwards:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
 
-        logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
         past_key_values = None
 
@@ -251,7 +242,7 @@ class FalconPipelineForwards:
         # Compute alibi tensor: check build_alibi_tensor documentation
         alibi = None
         past_key_values_length = 0
-
+
         batch_size, seq_length, _ = hidden_states.shape
         if self.use_alibi:
             mask = (
@@ -262,7 +253,7 @@ class FalconPipelineForwards:
                 else attention_mask
             )
             alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
-
+
         if cache_position is None:
             cache_position = torch.arange(
                 past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
@@ -280,7 +271,7 @@ class FalconPipelineForwards:
         # attention_probs has shape batch_size x num_heads x N x N
         # head_mask has shape n_layer x batch x num_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-
+
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
@@ -319,7 +310,7 @@ class FalconPipelineForwards:
 
             hidden_states = outputs[0]
             if use_cache is True:
-                next_decoder_cache = outputs[1]
+                outputs[1]
 
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@@ -332,7 +323,7 @@ class FalconPipelineForwards:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
         if stage_manager.is_last_stage():
-
+
             if not return_dict:
                 return tuple(
                     v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None