Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-21 13:11:27 +00:00)
fix

commit 07349e0014 (parent cefdfc4125)
@@ -108,11 +108,15 @@ def get_tp_falcon_decoder_layer_forward():
        use_cache: bool = False,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # Add cache_position and position_embeddings args for v4.51.3 transformers
        **kwargs,
    ):

        residual = hidden_states

        # same as v4.51.3 transformers
        if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
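The hunk above extends the tensor-parallel Falcon decoder layer's forward signature with cache_position and position_embeddings so it matches transformers v4.51.3, while keeping the dual layer-norm path of Falcon's new decoder architecture. Below is a minimal, self-contained sketch of that layer-norm selection; the module and flag names are illustrative stand-ins for the config attributes referenced in the diff, not ColossalAI's actual objects.

    import torch
    from torch import nn

    # Illustrative stand-ins for the config flags and layer norms in the hunk above.
    hidden_size = 64
    new_decoder_architecture = True
    num_ln_in_parallel_attn = 2

    ln_attn = nn.LayerNorm(hidden_size)          # feeds the attention branch
    ln_mlp = nn.LayerNorm(hidden_size)           # feeds the MLP branch
    input_layernorm = nn.LayerNorm(hidden_size)  # single-norm fallback

    hidden_states = torch.randn(2, 8, hidden_size)
    residual = hidden_states

    if new_decoder_architecture and num_ln_in_parallel_attn == 2:
        # Two separate norms over the same hidden states, as in the hunk above.
        attention_layernorm_out = ln_attn(hidden_states)
        mlp_layernorm_out = ln_mlp(hidden_states)
    else:
        attention_layernorm_out = input_layernorm(hidden_states)
        mlp_layernorm_out = attention_layernorm_out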
@@ -143,7 +147,7 @@ def get_tp_falcon_decoder_layer_forward():
                    attention_output, residual, self.config.attention_dropout, training=self.training
                )
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        # v4.51.3 transformers mlp
        if (
            self.config.new_decoder_architecture
            and self.config.parallel_attn
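This hunk keeps the dropout-add residual update in line with v4.51.3: the attention output goes through dropout and is added back onto the residual stream before the post-attention layer norm. A standalone sketch of that pattern, assuming a plain functional dropout (the helper name below is illustrative, not ColossalAI's or transformers' actual helper):

    import torch
    import torch.nn.functional as F

    def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
        # Apply dropout to the branch output, then add it onto the residual stream.
        return F.dropout(x, p=prob, training=training) + residual

    attention_output = torch.randn(2, 8, 64)
    residual = torch.randn(2, 8, 64)
    # Mirrors the call shape in the hunk:
    # dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
    residual = dropout_add(attention_output, residual, prob=0.1, training=True)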
@@ -241,6 +245,7 @@ class FalconPipelineForwards:
        all_hidden_states = () if output_hidden_states else None

        # Compute alibi tensor: check build_alibi_tensor documentation
        # alibi calculation is same as v4.51.3 transformers.
        alibi = None
        past_key_values_length = 0
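The hunk above only notes that the ALiBi handling matches v4.51.3; alibi stays None unless the model uses ALiBi positional biases, in which case it is built from the attention mask. As a rough, hedged illustration of what such a bias looks like (simplified to a power-of-two head count; this is not the exact build_alibi_tensor implementation from transformers):

    import math
    import torch

    def simple_alibi(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
        # Simplified ALiBi sketch: geometric per-head slopes times key positions.
        batch_size, seq_len = attention_mask.shape
        start = 2.0 ** (-(2.0 ** -(math.log2(num_heads) - 3)))
        slopes = torch.tensor([start ** (i + 1) for i in range(num_heads)], dtype=torch.float32)
        # Positions of attended tokens; padding positions contribute nothing.
        positions = (attention_mask.cumsum(dim=-1) - 1) * attention_mask      # (batch, seq_len)
        alibi = slopes[None, :, None] * positions[:, None, :]                 # (batch, num_heads, seq_len)
        return alibi.reshape(batch_size * num_heads, 1, seq_len).to(dtype)

    mask = torch.ones(1, 6)
    alibi = simple_alibi(mask, num_heads=8, dtype=torch.bfloat16)  # shape (8, 1, 6)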
@@ -274,10 +279,11 @@ class FalconPipelineForwards:
        # head_mask has shape n_layer x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # create position embeddings to be shared across the decoder layers
        # v4.51.3 create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        start_idx, end_idx = stage_index[0], stage_index[1]
        # keep past_key_values arg same with v4.51.3 transformers
        for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)