Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-21 13:11:27 +00:00)

commit 07fa048895 (parent 46ed5d856b)

    fix

This commit threads a new `cache_position` argument through the T5 pipeline forward and the patched attention forwards, builds the decoder attention mask explicitly as `causal_mask` in place of the old `extended_attention_mask`, and drops the per-layer `past_key_value` handling, since KV caching is not supported for pipeline models.
@@ -43,6 +43,7 @@ class T5PipelineForwards:
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
         return_dict: Optional[bool] = None,
+        cache_position=None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         position_bias: Optional[torch.Tensor] = None,
@@ -68,15 +69,6 @@ class T5PipelineForwards:
         if use_cache:
             logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
-        if use_cache is True:
-            if not in_decoder:
-                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False

         stage = stage_manager.stage
         in_decoder = self.is_decoder
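After this hunk, the guard kept at the top is the only cache handling left in the pipeline path: rather than validating decoder status or gradient checkpointing, it warns once and forces caching off. A minimal standalone sketch of that behavior (the helper name `disable_cache_for_pipeline` is hypothetical; `warning_once` is the transformers-style logger method the surrounding code already uses):

    from transformers.utils import logging

    logger = logging.get_logger(__name__)

    def disable_cache_for_pipeline(use_cache: bool) -> bool:
        # KV caching is not supported for pipeline models at the moment,
        # so warn once and force the flag off instead of raising.
        if use_cache:
            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
            use_cache = False
        return use_cache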
@@ -122,12 +114,18 @@ class T5PipelineForwards:
         device = hidden_states.device

         # required mask seq length can be calculated via length of past
-        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
+        mask_seq_length = seq_length

         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
             past_key_values = [None] * len(self.block)

+        past_key_values_length = 0
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
+            )
+
         if attention_mask is None:
             attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
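With caching disabled, `past_key_values_length` is always 0, so the new `cache_position` is simply the index of each token in the current sequence. A tiny runnable sketch of what the added lines compute (the tensor sizes are made up for illustration):

    import torch

    hidden_states = torch.randn(2, 5, 64)  # (batch, seq_length, d_model), toy sizes
    seq_length = hidden_states.shape[1]

    past_key_values_length = 0  # no KV cache in the pipeline path
    cache_position = torch.arange(
        past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
    )
    print(cache_position)  # tensor([0, 1, 2, 3, 4])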
@@ -146,6 +144,22 @@ class T5PipelineForwards:
         else:
             encoder_extended_attention_mask = None

+        if self.config.is_decoder:
+            causal_mask = self._update_causal_mask(
+                attention_mask,
+                inputs_embeds,
+                cache_position,
+                None,
+                output_attentions,
+            )
+        elif attention_mask is not None:
+            causal_mask = attention_mask[:, None, None, :]
+            causal_mask = causal_mask.to(dtype=hidden_states.dtype)
+            causal_mask = (1.0 - causal_mask) * torch.finfo(hidden_states.dtype).min
+        else:
+            causal_mask = None
+
         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
         cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
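In the new `elif` branch, a 2-D padding mask (1 = attend, 0 = masked) is broadcast to shape `(batch, 1, 1, seq)` and turned into an additive bias: kept positions contribute 0.0 to the attention scores and masked positions the most negative representable value, which effectively zeroes them out after softmax. A self-contained sketch of that conversion (float32 and toy sizes assumed):

    import torch

    dtype = torch.float32
    attention_mask = torch.tensor([[1, 1, 1, 0]])  # (batch=1, seq=4); the last token is padding

    causal_mask = attention_mask[:, None, None, :]  # (1, 1, 1, 4), broadcastable over heads and queries
    causal_mask = causal_mask.to(dtype=dtype)
    causal_mask = (1.0 - causal_mask) * torch.finfo(dtype).min

    print(causal_mask)  # kept positions ~0.0, padded position -3.4028e+38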
@@ -158,7 +172,6 @@ class T5PipelineForwards:
         start_idx, end_idx = stage_index[0], stage_index[1]

         for i in range(start_idx, end_idx):
-            past_key_value = past_key_values[i]
             layer_module = self.block[i]
             layer_head_mask = head_mask[i]
             cross_attn_layer_head_mask = cross_attn_head_mask[i]
@@ -168,7 +181,7 @@ class T5PipelineForwards:
                 layer_outputs = self._gradient_checkpointing_func(
                     layer_module.forward,
                     hidden_states,
-                    extended_attention_mask,
+                    causal_mask,
                     position_bias,
                     encoder_hidden_states,
                     encoder_extended_attention_mask,
@@ -178,20 +191,24 @@ class T5PipelineForwards:
                     None,  # past_key_value is always None with gradient checkpointing
                     use_cache,
                     output_attentions,
+                    return_dict,
+                    cache_position,
                 )
             else:
                 layer_outputs = layer_module(
                     hidden_states,
-                    attention_mask=extended_attention_mask,
+                    attention_mask=causal_mask,
                     position_bias=position_bias,
                     encoder_hidden_states=encoder_hidden_states,
                     encoder_attention_mask=encoder_extended_attention_mask,
                     encoder_decoder_position_bias=encoder_decoder_position_bias,
                     layer_head_mask=layer_head_mask,
                     cross_attn_layer_head_mask=cross_attn_layer_head_mask,
-                    past_key_value=past_key_value,
+                    past_key_value=None,
                     use_cache=use_cache,
                     output_attentions=output_attentions,
+                    return_dict=return_dict,
+                    cache_position=cache_position,
                 )

             # layer_outputs is a tuple with:
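The two call sites must stay in sync: `self._gradient_checkpointing_func` forwards its arguments positionally, so `return_dict` and `cache_position` have to be appended in exactly the order `layer_module.forward` declares them, while the non-checkpointed branch can use keywords. A minimal sketch of the same pattern with the public `torch.utils.checkpoint` API (the toy `Layer` module is hypothetical):

    import torch
    from torch.utils.checkpoint import checkpoint

    class Layer(torch.nn.Module):
        def forward(self, hidden_states, attention_mask=None, cache_position=None):
            # Positional callers must respect this declaration order.
            if attention_mask is not None:
                hidden_states = hidden_states + attention_mask
            return torch.relu(hidden_states)

    layer = Layer()
    x = torch.randn(2, 4, 8, requires_grad=True)
    mask = torch.zeros(2, 4, 8)
    cache_position = torch.arange(4)

    # Arguments are passed positionally, mirroring how the diff appends
    # return_dict and cache_position to the checkpointed call.
    out = checkpoint(layer.forward, x, mask, cache_position, use_reentrant=False)
    out.sum().backward()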
@@ -669,6 +686,7 @@ def get_t5_flash_attention_forward():
         query_length: Optional[int] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        cache_position=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
         Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
@@ -805,6 +823,7 @@ def get_T5_layer_self_attention_forward():
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        cache_position=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.SelfAttention(
@@ -815,6 +834,7 @@ def get_T5_layer_self_attention_forward():
             past_key_value=past_key_value,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            cache_position=cache_position,
         )
         hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
         outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
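Unchanged context, but worth noting: `dropout_add` fuses dropout on the attention output with the residual add. A plausible reference implementation under that assumption (this is a sketch, not ColossalAI's actual kernel):

    import torch
    import torch.nn.functional as F

    def dropout_add(x: torch.Tensor, residual: torch.Tensor, p: float, training: bool) -> torch.Tensor:
        # Dropout on the sublayer output, then the residual connection.
        return residual + F.dropout(x, p=p, training=training)

    attn_out = torch.randn(2, 4, 8)
    hidden_states = torch.randn(2, 4, 8)
    hidden_states = dropout_add(attn_out, hidden_states, p=0.1, training=True)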