commit c34d2bc4e2
Author: Wang Binluo
Date: 2025-04-28 10:19:28 +00:00 (committed by GitHub)
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
2 changed files with 47 additions and 87 deletions

View File

@@ -1,16 +1,9 @@
-import math
-import warnings
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_attn_mask_utils import (
-    AttentionMaskConverter,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -110,19 +103,18 @@ def get_tp_falcon_decoder_layer_forward():
         alibi: Optional[torch.Tensor],
         attention_mask: torch.Tensor,
         position_ids: Optional[torch.LongTensor] = None,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        layer_past: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ):
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-
         residual = hidden_states

-        if self.config.new_decoder_architecture:
+        if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
             attention_layernorm_out = self.ln_attn(hidden_states)
             mlp_layernorm_out = self.ln_mlp(hidden_states)
         else:
@@ -138,7 +130,8 @@ def get_tp_falcon_decoder_layer_forward():
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
-            **kwargs,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
         )

         attention_output = attn_outputs[0]
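Note on the two arguments now forwarded to self_attention: in recent transformers versions the Falcon attention no longer rebuilds rotary angles internally; the model computes one shared (cos, sin) pair per forward and every layer consumes it, alongside a cache_position index into the KV cache. A rough, self-contained illustration of that rotary scheme follows (plain torch; rotary_cos_sin, rotate_half and apply_rotary are hypothetical helpers, not the transformers implementation):

import torch

def rotary_cos_sin(position_ids: torch.Tensor, head_dim: int, base: float = 10000.0):
    # inverse frequencies for each channel pair, shape (head_dim // 2,)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # angles: (batch, seq_len, head_dim // 2), duplicated to head_dim
    angles = position_ids[..., None].float() * inv_freq
    emb = torch.cat((angles, angles), dim=-1)
    return emb.cos(), emb.sin()

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin):
    # cos/sin broadcast over the head dimension
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

# usage: one (cos, sin) pair shared by every decoder layer, like position_embeddings above
position_ids = torch.arange(8).unsqueeze(0)        # (batch=1, seq_len=8)
cos, sin = rotary_cos_sin(position_ids, head_dim=64)
q = torch.randn(1, 4, 8, 64)                       # (batch, heads, seq, head_dim)
k = torch.randn(1, 4, 8, 64)
q, k = apply_rotary(q, k, cos, sin)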
@@ -152,6 +145,13 @@ def get_tp_falcon_decoder_layer_forward():
                 )
                 mlp_layernorm_out = self.post_attention_layernorm(residual)

+        if (
+            self.config.new_decoder_architecture
+            and self.config.parallel_attn
+            and self.config.num_ln_in_parallel_attn == 1
+        ):
+            mlp_layernorm_out = attention_layernorm_out
+
         outputs = attn_outputs[1:]

         # MLP.
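The block added here, together with the num_ln_in_parallel_attn == 2 condition earlier in the hunk, mirrors the upstream Falcon layer-norm routing. A simplified, hypothetical sketch of that routing (not the real FalconDecoderLayer) for readers who want the configurations side by side:

import torch
from torch import nn

class TinyFalconBlockNorms(nn.Module):
    # Simplified sketch of the routing the diff above reproduces:
    # new_decoder_architecture with num_ln_in_parallel_attn == 2 -> separate ln_attn / ln_mlp;
    # num_ln_in_parallel_attn == 1 with parallel_attn -> one layernorm output feeds both
    # the attention and the MLP.
    def __init__(self, hidden, new_arch: bool, parallel_attn: bool, num_ln: int):
        super().__init__()
        self.new_arch, self.parallel_attn, self.num_ln = new_arch, parallel_attn, num_ln
        if new_arch and num_ln == 2:
            self.ln_attn, self.ln_mlp = nn.LayerNorm(hidden), nn.LayerNorm(hidden)
        else:
            self.input_layernorm = nn.LayerNorm(hidden)

    def forward(self, hidden_states):
        if self.new_arch and self.num_ln == 2:
            attn_in, mlp_in = self.ln_attn(hidden_states), self.ln_mlp(hidden_states)
        else:
            attn_in = self.input_layernorm(hidden_states)
            mlp_in = attn_in if self.parallel_attn else None  # single-LN parallel path
        return attn_in, mlp_in

x = torch.randn(2, 5, 16)
attn_in, mlp_in = TinyFalconBlockNorms(16, new_arch=True, parallel_attn=True, num_ln=1)(x)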
@@ -167,7 +167,7 @@ def get_tp_falcon_decoder_layer_forward():
         else:
             outputs = (output,) + outputs[1:]

-        return outputs  # hidden_states, present, attentions
+        return outputs  # hidden_states, past_kv, attentions

     return forward
@@ -190,6 +190,7 @@ class FalconPipelineForwards:
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
@@ -206,9 +207,8 @@ class FalconPipelineForwards:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False

-        if past_key_values is not None:
-            logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
-            past_key_values = None
+        logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
+        past_key_values = None

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -221,7 +221,7 @@ class FalconPipelineForwards:
             elif inputs_embeds is not None:
                 batch_size, seq_length, _ = inputs_embeds.shape
             else:
-                raise ValueError("You have to specify either input_ids or inputs_embeds")
+                raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
             if inputs_embeds is None:
                 inputs_embeds = self.word_embeddings(input_ids)
             hidden_states = inputs_embeds
@@ -229,12 +229,9 @@ class FalconPipelineForwards:
             input_shape = hidden_states.shape[:-1]
             batch_size, seq_length = input_shape

-        if past_key_values is None:
-            past_key_values = tuple([None] * len(self.h))
-
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
@@ -243,10 +240,10 @@ class FalconPipelineForwards:
         all_hidden_states = () if output_hidden_states else None

         # Compute alibi tensor: check build_alibi_tensor documentation
+        alibi = None
         past_key_values_length = 0
-        if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[-2]
+        batch_size, seq_length, _ = hidden_states.shape

         if self.use_alibi:
             mask = (
                 torch.ones(
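In the new version alibi starts out as None and is only populated when self.use_alibi is set. For readers unfamiliar with the bias that build_alibi_tensor produces, here is a hypothetical stand-in sketching its shape and semantics (simple_alibi is not the transformers helper, and it assumes a power-of-two head count):

import math
import torch

def simple_alibi(attention_mask: torch.Tensor, num_heads: int, dtype=torch.float32):
    # attention_mask: (batch, seq_len) of 0/1. Returns (batch * num_heads, 1, seq_len),
    # the same shape Falcon attention expects for the ALiBi bias.
    batch, seq_len = attention_mask.shape
    # geometric slope schedule from the ALiBi paper (power-of-two head counts only)
    start = 2 ** (-(2 ** -(math.log2(num_heads) - 3)))
    slopes = torch.tensor([start ** (i + 1) for i in range(num_heads)], dtype=torch.float32)
    # position of each key token, counted only over non-padding positions
    positions = (attention_mask.cumsum(dim=-1) - 1) * attention_mask   # (batch, seq_len)
    alibi = slopes[None, :, None] * positions[:, None, :]              # (batch, heads, seq_len)
    return alibi.reshape(batch * num_heads, 1, seq_len).to(dtype)

mask = torch.ones(2, 6, dtype=torch.long)
bias = simple_alibi(mask, num_heads=8)
print(bias.shape)   # torch.Size([16, 1, 6])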
@@ -256,73 +253,30 @@ class FalconPipelineForwards:
                 else attention_mask
             )
             alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
-        else:
-            alibi = None
-
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0)

-        if self._use_flash_attention_2:
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and not output_attentions:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            if alibi is None:
-                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                )
-            elif head_mask is None:
-                alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
-
-                attention_mask_2d = attention_mask
-                # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-                )
-
-                # We take care to integrate alibi bias in the attention_mask here.
-                if attention_mask_2d is None:
-                    attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
-                else:
-                    min_dtype = torch.finfo(alibi.dtype).min
-                    attention_mask = torch.masked_fill(
-                        alibi / math.sqrt(self.config.hidden_size // self.num_heads),
-                        attention_mask < -1,
-                        min_dtype,
-                    )
-
-                    # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
-                    # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
-                    if seq_length > 1 and attention_mask.device.type == "cuda":
-                        attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
-            else:
-                # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-                )
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
             )

+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, hidden_states, cache_position, past_key_values, output_attentions, head_mask, alibi
+        )
+
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape batch_size x num_heads x N x N
         # head_mask has shape n_layer x batch x num_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
         start_idx, end_idx = stage_index[0], stage_index[1]
-        for i, (block, layer_past) in enumerate(
-            zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
-        ):
+        for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx):

             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
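The removed branches built 4D causal masks by hand for the eager, SDPA, and flash paths; the new code defers that to self._update_causal_mask and instead tracks cache_position, the absolute positions of the current tokens inside the full cached sequence. A small standalone sketch of that bookkeeping (plain torch, hypothetical values, not the transformers implementation):

import torch

past_len, seq_len = 4, 3                                       # 4 cached tokens, 3 new tokens this step
cache_position = torch.arange(past_len, past_len + seq_len)    # tensor([4, 5, 6])
position_ids = cache_position.unsqueeze(0)                     # (1, seq_len), also fed to rotary_emb

# Boolean causal mask for the new tokens against all past_len + seq_len key positions:
# new token i may attend to every key position <= cache_position[i].
key_positions = torch.arange(past_len + seq_len)               # (7,)
causal_mask = key_positions[None, :] <= cache_position[:, None]  # (seq_len, past_len + seq_len)
print(causal_mask.int())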
@@ -331,28 +285,32 @@ class FalconPipelineForwards:
                     block.__call__,
                     hidden_states,
                     alibi,
-                    attention_mask,
+                    causal_mask,
                     position_ids,
                     head_mask[i],
-                    layer_past,
+                    past_key_values,
                     use_cache,
                     output_attentions,
+                    cache_position,
+                    position_embeddings,
                 )
             else:
                 outputs = block(
                     hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
+                    layer_past=past_key_values,
+                    attention_mask=causal_mask,
                     position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
                     alibi=alibi,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                 )

             hidden_states = outputs[0]
             if use_cache is True:
-                presents = presents + (outputs[1],)
+                outputs[1]
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
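Each block now receives the whole past_key_values object (and the precomputed causal_mask) rather than a per-layer layer_past slice, matching the newer cache API in which a single cache object is shared by all layers and indexed by layer. A toy illustration of why one object is enough (ToyLayerCache is hypothetical, not transformers' cache class):

import torch

class ToyLayerCache:
    # Toy stand-in for the single cache handed to every block above,
    # instead of zipping a per-layer `layer_past` tuple per stage.
    def __init__(self):
        self.keys, self.values = {}, {}

    def update(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int):
        # append along the sequence dimension and return the full K/V for this layer
        if layer_idx in self.keys:
            key = torch.cat([self.keys[layer_idx], key], dim=-2)
            value = torch.cat([self.values[layer_idx], value], dim=-2)
        self.keys[layer_idx], self.values[layer_idx] = key, value
        return key, value

cache = ToyLayerCache()
k = v = torch.randn(1, 4, 3, 16)      # (batch, heads, seq, head_dim)
for layer_idx in range(2):            # every layer shares the same cache object
    full_k, full_v = cache.update(k, v, layer_idx)
print(full_k.shape)                   # torch.Size([1, 4, 3, 16])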
@@ -365,6 +323,7 @@ class FalconPipelineForwards:
                 all_hidden_states = all_hidden_states + (hidden_states,)

         if stage_manager.is_last_stage():
             if not return_dict:
                 return tuple(
                     v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None

View File

@@ -246,6 +246,7 @@ class FalconPolicy(Policy):
         module = self.model.transformer
         stage_manager = self.pipeline_stage_manager
         held_layers = []
+        held_layers.append(module.rotary_emb)
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = stage_manager.distribute_layers(len(module.h))
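Appending module.rotary_emb to held_layers keeps the shared rotary-embedding module on every pipeline stage, since each stage's falcon_model_forward now calls self.rotary_emb to build the (cos, sin) pair for its own slice of decoder blocks. A hypothetical sketch of the idea (held_layers_for_stage is illustrative, not the ColossalAI policy code):

def held_layers_for_stage(blocks, rotary_emb, stage, num_stages):
    # evenly slice the decoder blocks across stages; the rotary module is
    # replicated on every stage, like module.rotary_emb above
    per_stage = len(blocks) // num_stages
    start, end = stage * per_stage, (stage + 1) * per_stage
    held = [rotary_emb]
    held.extend(blocks[start:end])
    return held

blocks = [f"block_{i}" for i in range(8)]
print(held_layers_for_stage(blocks, "rotary_emb", stage=1, num_stages=4))
# ['rotary_emb', 'block_2', 'block_3']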