upgrade_t

This commit is contained in:
wangbluo 2025-05-15 14:31:24 +08:00
parent 46ed5d856b
commit 2223b64931

View File

@ -17,7 +17,7 @@ from transformers.models.t5.modeling_t5 import (
T5Model, T5Model,
T5Stack, T5Stack,
) )
from transformers.utils import logging from transformers.utils import is_torchdynamo_compiling, logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
@ -43,6 +43,7 @@ class T5PipelineForwards:
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False, output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position=None,
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
position_bias: Optional[torch.Tensor] = None, position_bias: Optional[torch.Tensor] = None,
@ -68,15 +69,6 @@ class T5PipelineForwards:
if use_cache: if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False use_cache = False
if use_cache is True:
if not in_decoder:
raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
stage = stage_manager.stage stage = stage_manager.stage
in_decoder = self.is_decoder in_decoder = self.is_decoder
@ -121,19 +113,30 @@ class T5PipelineForwards:
batch_size, seq_length = input_shape[0], input_shape[1] batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device device = hidden_states.device
# required mask seq length can be calculated via length of past # v4.51.3 transformers past_key_values_length calculation
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
# initialize past_key_values with `None` if past does not exist if attention_mask is None and not is_torchdynamo_compiling():
if past_key_values is None: # required mask seq length can be calculated via length of past cache
past_key_values = [None] * len(self.block) mask_seq_length = past_key_values_length + seq_length
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=device) attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] if self.config.is_decoder:
# ourselves in which case we just need to make it broadcastable to all heads. causal_mask = self._update_causal_mask(
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
elif attention_mask is not None:
causal_mask = attention_mask[:, None, None, :]
causal_mask = causal_mask.to(dtype=hidden_states.dtype)
causal_mask = (1.0 - causal_mask) * torch.finfo(hidden_states.dtype).min
else:
causal_mask = None
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@ -149,16 +152,16 @@ class T5PipelineForwards:
# Prepare head mask if needed # Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers) head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.is_decoder) else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None
position_bias = None
encoder_decoder_position_bias = None
# Going through held blocks. # Going through held blocks.
start_idx, end_idx = stage_index[0], stage_index[1] start_idx, end_idx = stage_index[0], stage_index[1]
for i in range(start_idx, end_idx): for i in range(start_idx, end_idx):
past_key_value = past_key_values[i]
layer_module = self.block[i] layer_module = self.block[i]
layer_head_mask = head_mask[i] layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i]
@ -168,7 +171,7 @@ class T5PipelineForwards:
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.forward, layer_module.forward,
hidden_states, hidden_states,
extended_attention_mask, causal_mask,
position_bias, position_bias,
encoder_hidden_states, encoder_hidden_states,
encoder_extended_attention_mask, encoder_extended_attention_mask,
@ -178,20 +181,24 @@ class T5PipelineForwards:
None, # past_key_value is always None with gradient checkpointing None, # past_key_value is always None with gradient checkpointing
use_cache, use_cache,
output_attentions, output_attentions,
return_dict,
cache_position,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=causal_mask,
position_bias=position_bias, position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
return_dict=return_dict,
cache_position=cache_position,
) )
# layer_outputs is a tuple with: # layer_outputs is a tuple with:
@ -199,30 +206,31 @@ class T5PipelineForwards:
if use_cache is False or use_cache is None: if use_cache is False or use_cache is None:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2] hidden_states, next_decoder_cache = layer_outputs[:2]
# We share the position biases between the layers - the first layer store them # We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
# (cross-attention position bias), (cross-attention weights) # (cross-attention position bias), (cross-attention weights)
position_bias = layer_outputs[2] position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
if in_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache: if output_attentions:
present_key_value_states = present_key_value_states + (present_key_value_state,) all_attentions = all_attentions + (layer_outputs[3],)
if self.is_decoder:
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
# last layer # last layer
if at_last_stage: if at_last_stage:
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
next_cache = None
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [ for v in [
hidden_states, hidden_states,
present_key_value_states, next_cache,
all_hidden_states, all_hidden_states,
all_attentions, all_attentions,
all_cross_attentions, all_cross_attentions,
@ -231,7 +239,7 @@ class T5PipelineForwards:
) )
return BaseModelOutputWithPastAndCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_value_states, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
@ -805,6 +813,7 @@ def get_T5_layer_self_attention_forward():
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False, use_cache: bool = False,
output_attentions: bool = False, output_attentions: bool = False,
cache_position=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention( attention_output = self.SelfAttention(
@ -815,6 +824,7 @@ def get_T5_layer_self_attention_forward():
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them