From 07349e00146d06ece045aa82baa9b5335452b966 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Wed, 14 May 2025 10:09:35 +0800
Subject: [PATCH] fix

---
 colossalai/shardformer/modeling/falcon.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py
index 4f1d0ccd8..d06f8db2c 100644
--- a/colossalai/shardformer/modeling/falcon.py
+++ b/colossalai/shardformer/modeling/falcon.py
@@ -108,11 +108,15 @@ def get_tp_falcon_decoder_layer_forward():
         use_cache: bool = False,
         output_attentions: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        position_embeddings: Optional[
+            Tuple[torch.Tensor, torch.Tensor]
+        ] = None,  # Add cache_position and position_embeddings args for v4.51.3 transformers
         **kwargs,
     ):
+
         residual = hidden_states
 
+        # same as v4.51.3 transformers
         if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
             attention_layernorm_out = self.ln_attn(hidden_states)
             mlp_layernorm_out = self.ln_mlp(hidden_states)
@@ -143,7 +147,7 @@ def get_tp_falcon_decoder_layer_forward():
                     attention_output, residual, self.config.attention_dropout, training=self.training
                 )
                 mlp_layernorm_out = self.post_attention_layernorm(residual)
-
+        # v4.51.3 transformers mlp
         if (
             self.config.new_decoder_architecture
             and self.config.parallel_attn
@@ -241,6 +245,7 @@ class FalconPipelineForwards:
         all_hidden_states = () if output_hidden_states else None
 
         # Compute alibi tensor: check build_alibi_tensor documentation
+        # alibi calculation is same as v4.51.3 transformers.
         alibi = None
         past_key_values_length = 0
 
@@ -274,10 +279,11 @@ class FalconPipelineForwards:
         # head_mask has shape n_layer x batch x num_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
-        # create position embeddings to be shared across the decoder layers
+        # v4.51.3 create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         start_idx, end_idx = stage_index[0], stage_index[1]
+        # keep past_key_values arg same with v4.51.3 transformers
         for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
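
Reviewer note (not part of the patch): a minimal sketch of how the patched forward is
exercised, assuming the transformers v4.51.x classes referenced above (FalconDecoderLayer,
FalconRotaryEmbedding, FalconConfig). The tiny config values and the manual MethodType
binding are illustrative only; in practice the ShardFormer Falcon policy installs the
forward returned by get_tp_falcon_decoder_layer_forward().

    from types import MethodType

    import torch
    from transformers.models.falcon.configuration_falcon import FalconConfig
    from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconRotaryEmbedding

    from colossalai.shardformer.modeling.falcon import get_tp_falcon_decoder_layer_forward

    # A deliberately tiny config; real checkpoints use much larger values.
    config = FalconConfig(hidden_size=64, num_attention_heads=4, num_hidden_layers=1)
    layer = FalconDecoderLayer(config, layer_idx=0)

    # Bind the patched forward onto the layer (normally done by the Falcon shardformer policy).
    layer.forward = MethodType(get_tp_falcon_decoder_layer_forward(), layer)

    batch_size, seq_len = 2, 8
    hidden_states = torch.randn(batch_size, seq_len, config.hidden_size)
    cache_position = torch.arange(seq_len)
    position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)

    # As in the patched pipeline forward: cos/sin are computed once and shared by every block.
    rotary_emb = FalconRotaryEmbedding(config)
    position_embeddings = rotary_emb(hidden_states, position_ids)

    # Each decoder block is then called with the two new keyword arguments, e.g.
    #   layer(
    #       hidden_states,
    #       alibi=None,
    #       attention_mask=attention_mask,  # additive mask prepared by the model forward
    #       position_ids=position_ids,
    #       cache_position=cache_position,
    #       position_embeddings=position_embeddings,
    #   )

Computing the (cos, sin) pair once at the model level and passing it into every block mirrors
the upstream v4.51.3 FalconModel.forward, so the TP and pipeline forwards stay drop-in
compatible with the stock decoder layers.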