Merge pull request #6283 from wangbluo/upgrade_falcon

[shardformer] Upgrade transformers version: falcon model
Hanks 2025-05-14 15:05:31 +08:00 committed by GitHub
commit 5374601741


@@ -1,4 +1,3 @@
-import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -6,11 +5,6 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_attn_mask_utils import (
-    AttentionMaskConverter,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -114,6 +108,10 @@ def get_tp_falcon_decoder_layer_forward():
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[
+            Tuple[torch.Tensor, torch.Tensor]
+        ] = None,  # add cache_position and position_embeddings args for v4.51.3 transformers
         **kwargs,
     ):
         if "padding_mask" in kwargs:
@@ -122,7 +120,8 @@ def get_tp_falcon_decoder_layer_forward():
             )
         residual = hidden_states
-        if self.config.new_decoder_architecture:
+        # same as v4.51.3 transformers
+        if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
             attention_layernorm_out = self.ln_attn(hidden_states)
             mlp_layernorm_out = self.ln_mlp(hidden_states)
         else:
@@ -138,7 +137,8 @@ def get_tp_falcon_decoder_layer_forward():
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
-            **kwargs,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
         )
         attention_output = attn_outputs[0]
@@ -151,6 +151,13 @@ def get_tp_falcon_decoder_layer_forward():
                 attention_output, residual, self.config.attention_dropout, training=self.training
             )
             mlp_layernorm_out = self.post_attention_layernorm(residual)
+        # v4.51.3 transformers mlp
+        if (
+            self.config.new_decoder_architecture
+            and self.config.parallel_attn
+            and self.config.num_ln_in_parallel_attn == 1
+        ):
+            mlp_layernorm_out = attention_layernorm_out
         outputs = attn_outputs[1:]
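As a reading aid for the branch added above, a purely illustrative helper (not part of the PR) summarizing which tensor ends up feeding the MLP for the config combinations this code exercises; the attribute names follow the FalconConfig fields used in the diff.

def mlp_input_source(config) -> str:
    """Illustration only: which tensor feeds the Falcon MLP for common config combinations."""
    if config.new_decoder_architecture:
        if config.num_ln_in_parallel_attn == 2:
            return "ln_mlp(hidden_states)"          # dedicated MLP layernorm
        if config.parallel_attn and config.num_ln_in_parallel_attn == 1:
            return "attention_layernorm_out"        # single shared layernorm (branch added above)
        return "not exercised by the code in this diff"
    if config.parallel_attn:
        return "attention_layernorm_out"            # legacy parallel-attention path
    return "post_attention_layernorm(residual)"     # legacy sequential path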
@@ -190,11 +197,14 @@ class FalconPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
     ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+        # Add cache_position and position_embeddings args for v4.51.3 transformers
         logger = logging.get_logger(__name__)
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -206,9 +216,8 @@ class FalconPipelineForwards:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
-        if past_key_values is not None:
-            logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
-            past_key_values = None
+        logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
+        past_key_values = None
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -229,9 +238,6 @@ class FalconPipelineForwards:
             input_shape = hidden_states.shape[:-1]
             batch_size, seq_length = input_shape
-        if past_key_values is None:
-            past_key_values = tuple([None] * len(self.h))
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning(
@@ -243,10 +249,11 @@ class FalconPipelineForwards:
         all_hidden_states = () if output_hidden_states else None
         # Compute alibi tensor: check build_alibi_tensor documentation
+        # alibi calculation is the same as in v4.51.3 transformers
+        alibi = None
         past_key_values_length = 0
-        if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[-2]
+        batch_size, seq_length, _ = hidden_states.shape
         if self.use_alibi:
             mask = (
                 torch.ones(
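For reference, a hedged, self-contained sketch (not part of this PR) of what build_alibi_tensor, used just below, returns; the import path is the transformers Falcon module and the sizes are illustrative assumptions.

import torch
from transformers.models.falcon.modeling_falcon import build_alibi_tensor

batch_size, seq_length, num_heads = 2, 8, 4          # assumed toy sizes
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
alibi = build_alibi_tensor(attention_mask, num_heads, dtype=torch.float32)
print(alibi.shape)  # (batch_size * num_heads, 1, seq_length): a per-head linear positional bias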
@@ -256,73 +263,32 @@ class FalconPipelineForwards:
                 else attention_mask
             )
             alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
-        else:
-            alibi = None
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0)
-        if self._use_flash_attention_2:
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and not output_attentions:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            if alibi is None:
-                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                )
-            elif head_mask is None:
-                alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
-                attention_mask_2d = attention_mask
-                # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-                )
-                # We take care to integrate alibi bias in the attention_mask here.
-                if attention_mask_2d is None:
-                    attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
-                else:
-                    min_dtype = torch.finfo(alibi.dtype).min
-                    attention_mask = torch.masked_fill(
-                        alibi / math.sqrt(self.config.hidden_size // self.num_heads),
-                        attention_mask < -1,
-                        min_dtype,
-                    )
-                    # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
-                    # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
-                    if seq_length > 1 and attention_mask.device.type == "cuda":
-                        attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
-            else:
-                # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-                )
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-            )
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+        # use the new causal-mask construction:
+        # in v4.51.3 transformers, the sdpa, eager and flash-attention paths are merged into one class.
+        causal_mask = self._update_causal_mask(
+            attention_mask, hidden_states, cache_position, past_key_values, output_attentions, head_mask, alibi
+        )
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape batch_size x num_heads x N x N
         # head_mask has shape n_layer x batch x num_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+        # v4.51.3: create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
         start_idx, end_idx = stage_index[0], stage_index[1]
-        for i, (block, layer_past) in enumerate(
-            zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
-        ):
+        # keep the past_key_values arg consistent with v4.51.3 transformers
+        for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
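To make the new position bookkeeping concrete, here is a small runnable toy (assumed sizes, not code from this PR) of how cache_position and position_ids relate in the flow above; the (cos, sin) pair returned by rotary_emb is then shared by every decoder layer on the stage.

import torch

past_key_values_length, seq_length = 4, 3   # assumed toy sizes
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length)
position_ids = cache_position.unsqueeze(0)  # shape (1, seq_length), broadcast over the batch
assert position_ids.tolist() == [[4, 5, 6]]
# position_embeddings = self.rotary_emb(hidden_states, position_ids) would then yield the
# (cos, sin) tensors reused by each decoder block below.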
@@ -331,28 +297,32 @@ class FalconPipelineForwards:
                     block.__call__,
                     hidden_states,
                     alibi,
-                    attention_mask,
+                    causal_mask,
                     position_ids,
                     head_mask[i],
-                    layer_past,
+                    past_key_values,
                     use_cache,
                     output_attentions,
+                    cache_position,
+                    position_embeddings,
                 )
             else:
                 outputs = block(
                     hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
+                    layer_past=past_key_values,
+                    attention_mask=causal_mask,
                     position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
                     alibi=alibi,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                 )
             hidden_states = outputs[0]
             if use_cache is True:
-                presents = presents + (outputs[1],)
+                outputs[1]
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
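For intuition about the per-stage loop above, a toy stand-in with plain nn.Linear blocks and invented sizes (not the PR's code): each pipeline stage only runs the slice of decoder blocks selected by stage_index.

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(6)])  # stand-in for self.h
stage_index = [2, 4]                                          # this stage owns blocks 2 and 3
hidden_states = torch.randn(1, 8)
for i, block in enumerate(layers[stage_index[0] : stage_index[1]], start=stage_index[0]):
    hidden_states = block(hidden_states)  # the real loop also passes causal_mask, cache_position, etc.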
@@ -365,6 +335,7 @@ class FalconPipelineForwards:
                 all_hidden_states = all_hidden_states + (hidden_states,)
         if stage_manager.is_last_stage():
             if not return_dict:
                 return tuple(
                     v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None