diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index 1b5c03ce4..5e119d6fc 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -43,6 +43,7 @@ class T5PipelineForwards:
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
         return_dict: Optional[bool] = None,
+        cache_position=None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         position_bias: Optional[torch.Tensor] = None,
@@ -68,15 +69,6 @@ class T5PipelineForwards:
         if use_cache:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
-        if use_cache is True:
-            if not in_decoder:
-                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
 
         stage = stage_manager.stage
         in_decoder = self.is_decoder
@@ -122,11 +114,17 @@ class T5PipelineForwards:
             device = hidden_states.device
 
         # required mask seq length can be calculated via length of past
-        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
+        mask_seq_length = seq_length
 
         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
             past_key_values = [None] * len(self.block)
+
+        past_key_values_length = 0
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
+            )
 
         if attention_mask is None:
             attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
@@ -145,6 +143,22 @@ class T5PipelineForwards:
             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
+
+        if self.config.is_decoder:
+            causal_mask = self._update_causal_mask(
+                attention_mask,
+                inputs_embeds,
+                cache_position,
+                None,
+                output_attentions,
+            )
+        elif attention_mask is not None:
+            causal_mask = attention_mask[:, None, None, :]
+            causal_mask = causal_mask.to(dtype=hidden_states.dtype)
+            causal_mask = (1.0 - causal_mask) * torch.finfo(hidden_states.dtype).min
+        else:
+            causal_mask = None
+
         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
@@ -158,7 +172,6 @@ class T5PipelineForwards:
         start_idx, end_idx = stage_index[0], stage_index[1]
 
         for i in range(start_idx, end_idx):
-            past_key_value = past_key_values[i]
             layer_module = self.block[i]
             layer_head_mask = head_mask[i]
             cross_attn_layer_head_mask = cross_attn_head_mask[i]
@@ -168,7 +181,7 @@ class T5PipelineForwards:
                 layer_outputs = self._gradient_checkpointing_func(
                     layer_module.forward,
                     hidden_states,
-                    extended_attention_mask,
+                    causal_mask,
                     position_bias,
                     encoder_hidden_states,
                     encoder_extended_attention_mask,
@@ -178,20 +191,24 @@ class T5PipelineForwards:
                     None,  # past_key_value is always None with gradient checkpointing
                     use_cache,
                     output_attentions,
+                    return_dict,
+                    cache_position,
                 )
             else:
                 layer_outputs = layer_module(
                     hidden_states,
-                    attention_mask=extended_attention_mask,
+                    attention_mask=causal_mask,
                     position_bias=position_bias,
                     encoder_hidden_states=encoder_hidden_states,
                     encoder_attention_mask=encoder_extended_attention_mask,
                     encoder_decoder_position_bias=encoder_decoder_position_bias,
                     layer_head_mask=layer_head_mask,
                     cross_attn_layer_head_mask=cross_attn_layer_head_mask,
-                    past_key_value=past_key_value,
+                    past_key_value=None,
                     use_cache=use_cache,
                     output_attentions=output_attentions,
+                    return_dict=return_dict,
+                    cache_position=cache_position,
                 )
 
             # layer_outputs is a tuple with:
@@ -669,6 +686,7 @@ def get_t5_flash_attention_forward():
         query_length: Optional[int] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        cache_position=None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
         Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
@@ -805,6 +823,7 @@ def get_T5_layer_self_attention_forward():
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        cache_position=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.SelfAttention(
@@ -815,6 +834,7 @@ def get_T5_layer_self_attention_forward():
             past_key_value=past_key_value,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            cache_position=cache_position,
         )
         hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
         outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
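Illustrative note (not part of the patch): the two pieces of logic the diff adds to the pipeline forward, the additive padding mask built in the non-decoder branch and the cache_position fallback, can be exercised in isolation as below. The helper names are hypothetical; only the arithmetic mirrors the added lines.

import torch


def build_additive_padding_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Mirrors the non-decoder branch above: broadcast a [batch, seq] padding mask
    # to [batch, 1, 1, seq], then turn it into an additive mask (0.0 where tokens
    # may be attended, dtype minimum where they are padded) for the softmax.
    mask = attention_mask[:, None, None, :].to(dtype=dtype)
    return (1.0 - mask) * torch.finfo(dtype).min


def default_cache_position(seq_length: int, past_key_values_length: int = 0, device=None) -> torch.Tensor:
    # Mirrors the fallback used when cache_position is not supplied: absolute
    # positions of the current tokens; with no cache this is simply arange(seq_length).
    return torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)


if __name__ == "__main__":
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # two right-padded sequences
    print(build_additive_padding_mask(attention_mask, torch.float32).shape)  # torch.Size([2, 1, 1, 4])
    print(default_cache_position(4))  # tensor([0, 1, 2, 3])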