commit 885210dc27
parent 5d167f2148
Author: wangbluo
Date:   2025-04-28 18:17:12 +08:00


@@ -1,17 +1,9 @@
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
@@ -159,7 +151,7 @@ def get_tp_falcon_decoder_layer_forward():
            and self.config.num_ln_in_parallel_attn == 1
        ):
            mlp_layernorm_out = attention_layernorm_out
        outputs = attn_outputs[1:]
        # MLP.
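[Editor's note] The hunk above keeps the single-layernorm path of Falcon's parallel attention: when the new decoder architecture runs attention and the MLP in parallel with one layernorm, the MLP reuses the attention layernorm output. Below is a minimal, self-contained sketch of that gate; the names _Cfg, ln_attn, and ln_mlp are stand-ins for illustration, not identifiers from this file.

from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class _Cfg:
    # stand-ins for the FalconConfig fields the gate reads
    new_decoder_architecture: bool = True
    parallel_attn: bool = True
    num_ln_in_parallel_attn: int = 1


config = _Cfg()
hidden = torch.randn(2, 4, 8)
ln_attn, ln_mlp = nn.LayerNorm(8), nn.LayerNorm(8)

attention_layernorm_out = ln_attn(hidden)
if (
    config.new_decoder_architecture
    and config.parallel_attn
    and config.num_ln_in_parallel_attn == 1
):
    # with a single parallel layernorm, the MLP reuses the attention LN output
    mlp_layernorm_out = attention_layernorm_out
else:
    # otherwise the MLP gets its own layernorm over the same input
    mlp_layernorm_out = ln_mlp(hidden)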
@@ -215,7 +207,6 @@ class FalconPipelineForwards:
            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
            use_cache = False
            logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
            past_key_values = None
@@ -251,7 +242,7 @@ class FalconPipelineForwards:
        # Compute alibi tensor: check build_alibi_tensor documentation
        alibi = None
        past_key_values_length = 0
        batch_size, seq_length, _ = hidden_states.shape
        if self.use_alibi:
            mask = (
@@ -262,7 +253,7 @@ class FalconPipelineForwards:
                else attention_mask
            )
            alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
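[Editor's note] For context, a minimal sketch of what the two pieces above compute, assuming transformers' public build_alibi_tensor helper from transformers.models.falcon.modeling_falcon; the example sizes are illustrative only.

import torch
from transformers.models.falcon.modeling_falcon import build_alibi_tensor

batch_size, seq_length, num_heads, hidden_size = 2, 16, 8, 64
hidden_states = torch.randn(batch_size, seq_length, hidden_size)
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

# ALiBi bias: one slope per attention head, broadcast over query positions;
# build_alibi_tensor returns shape (batch_size * num_heads, 1, seq_length)
alibi = build_alibi_tensor(attention_mask, num_heads, dtype=hidden_states.dtype)

# cache_position enumerates the absolute positions of the current tokens,
# offset by the number of tokens already held in the KV cache
past_key_values_length = 0
cache_position = torch.arange(
    past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
)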
@@ -280,7 +271,7 @@ class FalconPipelineForwards:
        # attention_probs has shape batch_size x num_heads x N x N
        # head_mask has shape n_layer x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
@@ -319,7 +310,7 @@ class FalconPipelineForwards:
            hidden_states = outputs[0]
            if use_cache is True:
                next_decoder_cache = outputs[1]
                outputs[1]
            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
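[Editor's note] The index outputs[2 if use_cache else 1] works because each decoder layer returns a tuple whose layout shifts by one position when a cache entry is present. A toy illustration with a hypothetical fake_layer_outputs helper (not the real Falcon layer contract):

from typing import Tuple

import torch


def fake_layer_outputs(hidden: torch.Tensor, use_cache: bool, output_attentions: bool) -> Tuple[torch.Tensor, ...]:
    # mimic the positional layout: hidden state first, then (optionally) a cache entry, then attentions
    outputs: Tuple[torch.Tensor, ...] = (hidden,)
    if use_cache:
        outputs = outputs + (torch.zeros(1),)  # stand-in for the KV-cache entry
    if output_attentions:
        outputs = outputs + (torch.zeros(1, 1),)  # stand-in for the attention probabilities
    return outputs


use_cache = True
outs = fake_layer_outputs(torch.randn(2, 4, 8), use_cache=use_cache, output_attentions=True)
attn_weights = outs[2 if use_cache else 1]  # the cache occupies index 1, so attentions move to index 2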
@@ -332,7 +323,7 @@
            all_hidden_states = all_hidden_states + (hidden_states,)
        if stage_manager.is_last_stage():
            if not return_dict:
                return tuple(
                    v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None