From 2223b649313e2292678d8a08e926a8e1fb29656c Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 15 May 2025 14:31:24 +0800 Subject: [PATCH 1/3] upgrade_t --- colossalai/shardformer/modeling/t5.py | 78 +++++++++++++++------------ 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 1b5c03ce4..66e4184ab 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -17,7 +17,7 @@ from transformers.models.t5.modeling_t5 import ( T5Model, T5Stack, ) -from transformers.utils import logging +from transformers.utils import is_torchdynamo_compiling, logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -43,6 +43,7 @@ class T5PipelineForwards: output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = None, + cache_position=None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, position_bias: Optional[torch.Tensor] = None, @@ -68,15 +69,6 @@ class T5PipelineForwards: if use_cache: logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False - if use_cache is True: - if not in_decoder: - raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False stage = stage_manager.stage in_decoder = self.is_decoder @@ -121,19 +113,30 @@ class T5PipelineForwards: batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + # v4.51.3 transformers past_key_values_length calculation + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) - - if attention_mask is None: + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
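
# Illustrative sketch, not part of the patch: with the v4.51.3 Cache API targeted by
# the hunk above, the past length comes from Cache.get_seq_length() rather than
# past_key_values[0][0].shape[2], and cache_position indexes the freshly added tokens.
# A DynamicCache is used here purely for demonstration.
import torch
from transformers.cache_utils import DynamicCache

seq_length = 3
past_key_values = DynamicCache()  # empty cache -> get_seq_length() == 0
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length)
# cache_position == tensor([0, 1, 2]); with 4 cached tokens it would be tensor([4, 5, 6])
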
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=hidden_states.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(hidden_states.dtype).min + else: + causal_mask = None # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -149,16 +152,16 @@ class T5PipelineForwards: # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None # Going through held blocks. start_idx, end_idx = stage_index[0], stage_index[1] for i in range(start_idx, end_idx): - past_key_value = past_key_values[i] layer_module = self.block[i] layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] @@ -168,7 +171,7 @@ class T5PipelineForwards: layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -178,20 +181,24 @@ class T5PipelineForwards: None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -199,30 +206,31 @@ class T5PipelineForwards: if use_cache is False or use_cache is None: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # (cross-attention position bias), (cross-attention weights) position_bias = layer_outputs[2] - - if in_decoder and encoder_hidden_states is not None: + if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if 
self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) # last layer if at_last_stage: hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) - + next_cache = None if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -231,7 +239,7 @@ class T5PipelineForwards: ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, @@ -805,6 +813,7 @@ def get_T5_layer_self_attention_forward(): past_key_value: Optional[Tuple[torch.Tensor]] = None, use_cache: bool = False, output_attentions: bool = False, + cache_position=None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -815,6 +824,7 @@ def get_T5_layer_self_attention_forward(): past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them From e1925b36c40ae69122fb36c741c2d80a965f3405 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 16 May 2025 15:28:04 +0800 Subject: [PATCH 2/3] upgrade_gptj --- colossalai/shardformer/modeling/gptj.py | 125 ++++++++++++------------ 1 file changed, 63 insertions(+), 62 deletions(-) diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index 51b228712..168e554f9 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.cache_utils import Cache from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -79,7 +80,7 @@ class GPTJPipelineForwards: def gptj_model_forward( self: GPTJModel, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -89,12 +90,13 @@ class GPTJPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, ) -> Union[Dict, Tuple, BaseModelOutputWithPast]: - # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJModel.forward. + # This function is modified on the v4.51.3 transformers.models.gptj.modeling_gptj.GPTJModel.forward. # Please refer to original code of transformers for more details. 
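
# Illustrative sketch, not part of the patch: the widened annotation
# Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] above means the forward accepts
# either a v4.51.3 Cache object or the legacy tuple-of-tuples format; the legacy form
# can be wrapped with DynamicCache.from_legacy_cache when a Cache object is required.
import torch
from transformers.cache_utils import DynamicCache

# one layer of (key, value), each of shape (batch, num_heads, past_len, head_dim)
legacy = ((torch.zeros(1, 4, 5, 16), torch.zeros(1, 4, 5, 16)),)
cache = DynamicCache.from_legacy_cache(legacy)
assert cache.get_seq_length() == 5  # past length recovered from the cached keys
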
# GPTJ has no cross attention in comparison to GPT2 @@ -118,8 +120,8 @@ class GPTJPipelineForwards: use_cache = False if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") elif input_ids is not None: batch_size, seq_length = input_ids.shape input_shape = input_ids.size() @@ -130,17 +132,34 @@ class GPTJPipelineForwards: batch_size = inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, seq_length) else: if hidden_states is None: raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape[0], input_shape[1] - device = hidden_states.device + + if stage_manager.is_first_stage(): + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + hidden_states = inputs_embeds + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, seq_length) + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + hidden_states = self.drop(hidden_states) + + seq_length = hidden_states.shape[1] + if cache_position is None: + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + causal_mask = self._update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_values, output_attentions + ) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -148,33 +167,9 @@ class GPTJPipelineForwards: # head_mask has shape n_layer x batch x num_attention_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - # position id to be assigned not just for the first stage for attn input - if position_ids is None: - position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) - if stage_manager.is_first_stage(): - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - hidden_states = inputs_embeds - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - hidden_states = self.drop(hidden_states) + output_shape = (-1, seq_length, hidden_states.size(-1)) - output_shape = input_shape + (hidden_states.size(-1),) - - attention_mask = _get_attention_mask( - self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2 - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - presents = () if use_cache else None + next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None @@ -207,29 +202,26 @@ class GPTJPipelineForwards: block.__call__, hidden_states, None, - attention_mask, + causal_mask, position_ids, head_mask[i], use_cache, output_attentions, + cache_position, ) else: outputs = block( hidden_states=hidden_states, - layer_past=None, - attention_mask=attention_mask, + layer_past=past_key_values, + attention_mask=causal_mask, position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config.enable_sequence_parallelism: @@ -248,22 +240,17 @@ class GPTJPipelineForwards: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if stage_manager.is_last_stage(): if not return_dict: return tuple( - v - for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - ] - if v is not None + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, ) @@ -275,7 +262,7 @@ class GPTJPipelineForwards: def gptj_causallm_model_forward( self: GPTJForCausalLM, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -286,6 +273,7 @@ class GPTJPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -315,6 +303,7 @@ class GPTJPipelineForwards: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, @@ -326,18 +315,28 @@ class GPTJPipelineForwards: return {"hidden_states": transformer_outputs["hidden_states"]} hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + # v4.51.3 tranformers loss calculation + # make sure sampling in fp16 works correctly and + # compute loss in fp32 to match with mesh-tf version + # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 + lm_logits = self.lm_head(hidden_states).to(torch.float32) loss = None if labels is not None: # move labels to correct 
device to enable model parallelism labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss = self.loss_function( + lm_logits, + labels, + vocab_size=self.config.vocab_size, + ) loss = loss.to(hidden_states.dtype) @@ -357,7 +356,7 @@ class GPTJPipelineForwards: def gptj_for_sequence_classification_forward( self: GPTJForSequenceClassification, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -379,7 +378,7 @@ class GPTJPipelineForwards: config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification.forward. + # This function is modified on the v4.51.3 transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification.forward. # Please refer to original code of transformers for more details. """ logger = logging.get_logger(__name__) @@ -581,6 +580,8 @@ def get_gptj_flash_attention_forward(): Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], ]: + # This function is modified on the v4.51.3 transformers.models.gptj.modeling_gptj.GPTJAttention.forward. + # Please refer to original code of transformers for more details. 
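
# Illustrative sketch, not part of the patch: the attention kernel inside the
# flash-attention forward below computes exact attention, so
# torch.nn.functional.scaled_dot_product_attention is a convenient numerical baseline
# when checking the sharded forward. Shapes and tolerances here are illustrative only.
import torch
import torch.nn.functional as F

b, h, q_len, d = 2, 4, 8, 16
query, key, value = (torch.randn(b, h, q_len, d) for _ in range(3))
out = F.scaled_dot_product_attention(query, key, value, is_causal=True)

# reference: softmax(Q K^T / sqrt(d)) V under a causal mask
scores = query @ key.transpose(-2, -1) / d**0.5
scores = scores.masked_fill(torch.triu(torch.ones(q_len, q_len, dtype=torch.bool), 1), float("-inf"))
assert torch.allclose(out, scores.softmax(-1) @ value, atol=1e-4)
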
assert head_mask is None, "head_mask is not supported for FlashAttention" query = self.q_proj(hidden_states) key = self.k_proj(hidden_states) From 4e49f056d048ab84570163e6cdb370939d860f22 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 16 May 2025 15:32:16 +0800 Subject: [PATCH 3/3] fix --- colossalai/shardformer/modeling/t5.py | 78 ++++++++++++--------------- 1 file changed, 34 insertions(+), 44 deletions(-) diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 66e4184ab..1b5c03ce4 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -17,7 +17,7 @@ from transformers.models.t5.modeling_t5 import ( T5Model, T5Stack, ) -from transformers.utils import is_torchdynamo_compiling, logging +from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -43,7 +43,6 @@ class T5PipelineForwards: output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = None, - cache_position=None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, position_bias: Optional[torch.Tensor] = None, @@ -69,6 +68,15 @@ class T5PipelineForwards: if use_cache: logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False + if use_cache is True: + if not in_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False stage = stage_manager.stage in_decoder = self.is_decoder @@ -113,30 +121,19 @@ class T5PipelineForwards: batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device - # v4.51.3 transformers past_key_values_length calculation - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 - if cache_position is None: - cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device) + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if attention_mask is None and not is_torchdynamo_compiling(): - # required mask seq length can be calculated via length of past cache - mask_seq_length = past_key_values_length + seq_length + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length, device=device) - if self.config.is_decoder: - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - past_key_values.self_attention_cache if past_key_values is not None else None, - output_attentions, - ) - elif attention_mask is not None: - causal_mask = attention_mask[:, None, None, :] - causal_mask = causal_mask.to(dtype=hidden_states.dtype) - causal_mask = (1.0 - causal_mask) * torch.finfo(hidden_states.dtype).min - else: - causal_mask = None + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
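
# Illustrative sketch, not part of the patch: get_extended_attention_mask (restored on
# the next line) turns a (batch, seq_len) padding mask into an additive mask that is
# broadcastable to all heads - 0.0 where attention is allowed, a large negative value
# where it is masked. For the non-causal case it is roughly:
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])                  # (batch, seq_len), 0 marks padding
extended = attention_mask[:, None, None, :].to(torch.float32)  # (batch, 1, 1, seq_len)
extended = (1.0 - extended) * torch.finfo(torch.float32).min
# extended[0, 0, 0] -> tensor([0., 0., 0., -3.4028e+38])
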
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -152,16 +149,16 @@ class T5PipelineForwards: # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None - position_bias = None - encoder_decoder_position_bias = None # Going through held blocks. start_idx, end_idx = stage_index[0], stage_index[1] for i in range(start_idx, end_idx): + past_key_value = past_key_values[i] layer_module = self.block[i] layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] @@ -171,7 +168,7 @@ class T5PipelineForwards: layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - causal_mask, + extended_attention_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -181,24 +178,20 @@ class T5PipelineForwards: None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, - return_dict, - cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=causal_mask, + attention_mask=extended_attention_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -206,31 +199,30 @@ class T5PipelineForwards: if use_cache is False or use_cache is None: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, next_decoder_cache = layer_outputs[:2] + hidden_states, present_key_value_state = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # (cross-attention position bias), (cross-attention weights) position_bias = layer_outputs[2] - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) - if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + if in_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) # last layer if at_last_stage: hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) - next_cache = None + if not return_dict: return tuple( v for v in [ hidden_states, - next_cache, + present_key_value_states, all_hidden_states, all_attentions, all_cross_attentions, @@ -239,7 +231,7 
@@ class T5PipelineForwards: ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=present_key_value_states, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, @@ -813,7 +805,6 @@ def get_T5_layer_self_attention_forward(): past_key_value: Optional[Tuple[torch.Tensor]] = None, use_cache: bool = False, output_attentions: bool = False, - cache_position=None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -824,7 +815,6 @@ def get_T5_layer_self_attention_forward(): past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, - cache_position=cache_position, ) hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
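
# Illustrative sketch, not part of the patch: dropout_add in the restored T5
# self-attention wrapper above follows the usual "dropout, then residual add"
# pattern, roughly equivalent to:
import torch
import torch.nn.functional as F

def dropout_add(x: torch.Tensor, residual: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    # drop activations from the attention output, then add back the residual stream
    return residual + F.dropout(x, p=p, training=training)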