diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index 66e4184ab..1b5c03ce4 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -17,7 +17,7 @@ from transformers.models.t5.modeling_t5 import (
     T5Model,
     T5Stack,
 )
-from transformers.utils import is_torchdynamo_compiling, logging
+from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
@@ -43,7 +43,6 @@ class T5PipelineForwards:
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
         return_dict: Optional[bool] = None,
-        cache_position=None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         position_bias: Optional[torch.Tensor] = None,
@@ -69,6 +68,15 @@ class T5PipelineForwards:
         if use_cache:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
+        if use_cache is True:
+            if not self.is_decoder:
+                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
 
         stage = stage_manager.stage
         in_decoder = self.is_decoder
@@ -113,30 +121,19 @@ class T5PipelineForwards:
             batch_size, seq_length = input_shape[0], input_shape[1]
             device = hidden_states.device
 
-        # v4.51.3 transformers past_key_values_length calculation
-        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
-        if cache_position is None:
-            cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
+        # required mask seq length can be calculated via length of past
+        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
 
-        if attention_mask is None and not is_torchdynamo_compiling():
-            # required mask seq length can be calculated via length of past cache
-            mask_seq_length = past_key_values_length + seq_length
+        # initialize past_key_values with `None` if past does not exist
+        if past_key_values is None:
+            past_key_values = [None] * len(self.block)
+
+        if attention_mask is None:
             attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
 
-        if self.config.is_decoder:
-            causal_mask = self._update_causal_mask(
-                attention_mask,
-                inputs_embeds,
-                cache_position,
-                past_key_values.self_attention_cache if past_key_values is not None else None,
-                output_attentions,
-            )
-        elif attention_mask is not None:
-            causal_mask = attention_mask[:, None, None, :]
-            causal_mask = causal_mask.to(dtype=hidden_states.dtype)
-            causal_mask = (1.0 - causal_mask) * torch.finfo(hidden_states.dtype).min
-        else:
-            causal_mask = None
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -152,16 +149,16 @@ class T5PipelineForwards:
         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
         cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+        present_key_value_states = () if use_cache else None
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
         all_cross_attentions = () if (output_attentions and self.is_decoder) else None
-        position_bias = None
-        encoder_decoder_position_bias = None
 
         # Going through held blocks.
         start_idx, end_idx = stage_index[0], stage_index[1]
         for i in range(start_idx, end_idx):
+            past_key_value = past_key_values[i]
             layer_module = self.block[i]
             layer_head_mask = head_mask[i]
             cross_attn_layer_head_mask = cross_attn_head_mask[i]
@@ -171,7 +168,7 @@ class T5PipelineForwards:
                 layer_outputs = self._gradient_checkpointing_func(
                     layer_module.forward,
                     hidden_states,
-                    causal_mask,
+                    extended_attention_mask,
                     position_bias,
                     encoder_hidden_states,
                     encoder_extended_attention_mask,
@@ -181,24 +178,20 @@ class T5PipelineForwards:
                     None,  # past_key_value is always None with gradient checkpointing
                     use_cache,
                     output_attentions,
-                    return_dict,
-                    cache_position,
                 )
             else:
                 layer_outputs = layer_module(
                     hidden_states,
-                    attention_mask=causal_mask,
+                    attention_mask=extended_attention_mask,
                     position_bias=position_bias,
                     encoder_hidden_states=encoder_hidden_states,
                     encoder_attention_mask=encoder_extended_attention_mask,
                     encoder_decoder_position_bias=encoder_decoder_position_bias,
                     layer_head_mask=layer_head_mask,
                     cross_attn_layer_head_mask=cross_attn_layer_head_mask,
-                    past_key_value=past_key_values,
+                    past_key_value=past_key_value,
                     use_cache=use_cache,
                     output_attentions=output_attentions,
-                    return_dict=return_dict,
-                    cache_position=cache_position,
                 )
 
             # layer_outputs is a tuple with:
@@ -206,31 +199,30 @@ class T5PipelineForwards:
             if use_cache is False or use_cache is None:
                 layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
 
-            hidden_states, next_decoder_cache = layer_outputs[:2]
+            hidden_states, present_key_value_state = layer_outputs[:2]
 
             # We share the position biases between the layers - the first layer store them
             # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
             # (cross-attention position bias), (cross-attention weights)
             position_bias = layer_outputs[2]
-            if self.is_decoder and encoder_hidden_states is not None:
-                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
-            if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[3],)
-                if self.is_decoder:
-                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+            if in_decoder and encoder_hidden_states is not None:
+                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
+            # append next layer key value states
+            if use_cache:
+                present_key_value_states = present_key_value_states + (present_key_value_state,)
 
         # last layer
         if at_last_stage:
             hidden_states = self.final_layer_norm(hidden_states)
             hidden_states = self.dropout(hidden_states)
-            next_cache = None
+
             if not return_dict:
                 return tuple(
                     v
                     for v in [
                         hidden_states,
-                        next_cache,
+                        present_key_value_states,
                         all_hidden_states,
                         all_attentions,
                         all_cross_attentions,
@@ -239,7 +231,7 @@ class T5PipelineForwards:
                 )
             return BaseModelOutputWithPastAndCrossAttentions(
                 last_hidden_state=hidden_states,
-                past_key_values=next_cache,
+                past_key_values=present_key_value_states,
                 hidden_states=all_hidden_states,
                 attentions=all_attentions,
                 cross_attentions=all_cross_attentions,
@@ -813,7 +805,6 @@ def get_T5_layer_self_attention_forward():
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
-        cache_position=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.SelfAttention(
@@ -824,7 +815,6 @@ def get_T5_layer_self_attention_forward():
             past_key_value=past_key_value,
             use_cache=use_cache,
             output_attentions=output_attentions,
-            cache_position=cache_position,
        )
         hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
         outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
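
For context, a minimal sketch (not part of the patch) of the legacy tuple-based `past_key_values` layout that the restored `t5_stack_forward` expects. The indexing pattern `past_key_values[0][0].shape[2]` and the `get_seq_length()` call are taken from the added and removed lines above; the tensor shapes and layer count here are made-up example values.

# Illustrative sketch: the legacy tuple-based cache layout assumed by the
# reverted code. Shapes and layer count are invented for the example.
import torch

batch_size, num_heads, past_len, head_dim = 2, 8, 5, 64
num_layers = 6
seq_length = 1  # current decoding step

# Legacy format: one tuple per decoder layer holding
# (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value).
past_key_values = [
    tuple(torch.zeros(batch_size, num_heads, past_len, head_dim) for _ in range(4))
    for _ in range(num_layers)
]

# The restored code reads the past length off the first layer's key tensor:
mask_seq_length = past_key_values[0][0].shape[2] + seq_length
assert mask_seq_length == past_len + seq_length

# The removed v4.51.3-style code instead queried a Cache object:
#     past_key_values_length = past_key_values.get_seq_length()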