wangbluo 2025-05-08 18:03:53 +08:00
parent 4eced5cf8a
commit fe94d73f6b


@@ -3,10 +3,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_attention_mask_for_sdpa,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPoolingAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -63,11 +59,12 @@ class BertPipelineForwards:
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
     ):
+        # TODO(jianghai): add explaination of the output here.
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
             the model is configured as a decoder.
-        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
             the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
@@ -133,43 +130,13 @@ class BertPipelineForwards:
         # past_key_values_length
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            past_key_values_length=past_key_values_length,
-        )
-
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
 
-        use_sdpa_attention_masks = (
-            self.attn_implementation == "sdpa"
-            and self.position_embedding_type == "absolute"
-            and head_mask is None
-            and not output_attentions
-        )
-
-        # Expand the attention mask
-        if use_sdpa_attention_masks and attention_mask.dim() == 2:
-            # Expand the attention mask for SDPA.
-            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
-            if self.config.is_decoder:
-                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    input_shape,
-                    embedding_output,
-                    past_key_values_length,
-                )
-            else:
-                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    attention_mask, embedding_output.dtype, tgt_len=seq_length
-                )
-        else:
-            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-            # ourselves in which case we just need to make it broadcastable to all heads.
-            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+        attention_mask = extended_attention_mask
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if self.config.is_decoder and encoder_hidden_states is not None:
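
Note: the restored branch relies on `self.get_extended_attention_mask` rather than the SDPA helpers that are removed above. As a rough illustration of what that expansion does (a minimal sketch following the Hugging Face convention of an additive mask; the helper name `expand_padding_mask` and the shapes are illustrative, not part of this diff):

```python
import torch


def expand_padding_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # [batch_size, seq_len] with 1 = attend, 0 = padding
    # -> [batch_size, 1, 1, seq_len] additive mask: 0.0 where attended, a large
    # negative value where masked, broadcastable over (num_heads, query_len).
    extended = attention_mask[:, None, None, :].to(dtype)
    return (1.0 - extended) * torch.finfo(dtype).min


mask = torch.tensor([[1, 1, 1, 0]])  # one sequence with one padding token
print(expand_padding_mask(mask, torch.float32).shape)  # torch.Size([1, 1, 1, 4])
```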
@@ -177,13 +144,6 @@ class BertPipelineForwards:
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
-            if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
-                # Expand the attention mask for SDPA.
-                # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
-                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
-                )
-            else:
-                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
@@ -204,8 +164,7 @@ class BertPipelineForwards:
                 inputs_embeds=inputs_embeds,
                 past_key_values_length=past_key_values_length,
             )
-            print("hidden_states:", hidden_states.shape)
-            print("bert_model_forward hidden_states:", hidden_states.shape)
+
         # inherit from bert_layer,this should be changed when we add the feature to record hidden_states
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -252,25 +211,30 @@ class BertPipelineForwards:
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.encoder.gradient_checkpointing and self.encoder.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, past_key_value, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(encoder_layer),
                     hidden_states,
-                    attention_mask=extended_attention_mask,
-                    head_mask=layer_head_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
                 )
             else:
                 layer_outputs = encoder_layer(
                     hidden_states,
-                    attention_mask=extended_attention_mask,
-                    head_mask=layer_head_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
 
             hidden_states = layer_outputs[0]
             if use_cache:
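
Note: the hunk above swaps `self._gradient_checkpointing_func` with keyword arguments for the older closure-plus-`torch.utils.checkpoint.checkpoint` pattern, where keyword-style arguments are captured by the closure and only positional tensors flow through the checkpoint. A self-contained sketch of that pattern, using a made-up `TinyLayer` stand-in rather than the real BERT encoder layer:

```python
import torch
import torch.utils.checkpoint


class TinyLayer(torch.nn.Module):
    # Stand-in for a transformer encoder layer; illustrative only.
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        return (torch.relu(self.proj(hidden_states)),)


layer = TinyLayer()
hidden = torch.randn(2, 4, 8, requires_grad=True)


def create_custom_forward(module):
    # Close over the non-tensor / keyword-style arguments so that only tensors
    # are passed positionally through torch.utils.checkpoint.checkpoint.
    def custom_forward(*inputs):
        return module(*inputs, output_attentions=False)

    return custom_forward


outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer),
    hidden,
    None,  # attention_mask, passed positionally
    use_reentrant=True,  # matches the older reentrant behaviour this pattern targets
)
outputs[0].sum().backward()  # activations are recomputed here instead of being stored
```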
@@ -1176,7 +1140,7 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
             inputs_embeds=inputs_embeds,
             past_key_values_length=past_key_values_length,
         )
-        print("embedding_output:", embedding_output.shape)
+
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         embedding_output = split_forward_gather_backward(
@@ -1185,7 +1149,6 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
             process_group=shard_config.tensor_parallel_process_group,
             fp8_communication=shard_config.fp8_communication,
         )
-        print("after split_forward_gather_backward embedding_output:", embedding_output.shape)
         if encoder_hidden_states is not None:
             encoder_hidden_states = split_forward_gather_backward(
                 encoder_hidden_states,
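
Note: in the sequence-parallel path touched by the last two hunks, `split_forward_gather_backward` slices activations along the sequence dimension across the tensor-parallel group (and gathers the gradient along that same dimension in the backward pass, hence the name). The sketch below only mimics the forward-pass shape change on a single process; `split_along_seq`, the shapes, and the rank/world-size values are illustrative and not ColossalAI's API:

```python
import torch


def split_along_seq(x: torch.Tensor, tp_size: int, tp_rank: int, dim: int = 1) -> torch.Tensor:
    # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len // tp_size, hidden_size]:
    # each tensor-parallel rank keeps only its own slice of the sequence dimension.
    return x.chunk(tp_size, dim=dim)[tp_rank].contiguous()


embedding_output = torch.randn(2, 512, 768)  # illustrative shapes
local = split_along_seq(embedding_output, tp_size=4, tp_rank=0)
print(local.shape)  # torch.Size([2, 128, 768])
```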