[shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests
This commit is contained in:
Hongxin Liu
2024-03-27 11:19:32 +08:00
committed by GitHub
parent 9a3321e9f4
commit 19e1a5cf16
45 changed files with 2543 additions and 1170 deletions

View File

@@ -21,12 +21,82 @@ from transformers.models.gpt2.modeling_gpt2 import (
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
from ..layer._operation import gather_forward_split_backward
logger = logging.get_logger(__name__)
def _get_attention_mask(
self: GPT2Model,
shard_config: ShardConfig,
hidden_states: torch.Tensor,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
attention_mask: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.Tensor],
encoder_attention_mask: Optional[torch.FloatTensor],
) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
batch_size, seq_len = hidden_states.shape[:2]
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
if shard_config.enable_flash_attention:
encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
(encoder_batch_size, 1, seq_len, encoder_sequence_length),
dtype=hidden_states.dtype,
dtype2=encoder_hidden_states.dtype,
q_padding_mask=attention_mask,
kv_padding_mask=encoder_attention_mask,
)
else:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
if shard_config.enable_flash_attention:
encoder_attention_mask = {"attention_mask": None}
else:
encoder_attention_mask = None
# GPT2Attention mask.
past_key_values_length = 0
if past_key_values is not None and past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
if shard_config.enable_flash_attention:
if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = ColoAttention.prepare_attn_kwargs(
(batch_size, 1, seq_len, seq_len + past_key_values_length),
hidden_states.dtype,
hidden_states.device,
attention_mask,
is_causal=True,
)
elif attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
return attention_mask, encoder_attention_mask
class GPT2PipelineForwards:
"""
@@ -83,10 +153,10 @@ class GPT2PipelineForwards:
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
@@ -99,38 +169,7 @@ class GPT2PipelineForwards:
input_shape = hidden_states.size()[:-1]
device = hidden_states.device
hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
batch_size = hidden_states.shape[0]
# GPT2Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
hidden_states.shape[0]
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -156,6 +195,16 @@ class GPT2PipelineForwards:
output_shape = input_shape + (hidden_states.size(-1),)
attention_mask, encoder_attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
past_key_values,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
@@ -171,7 +220,9 @@ class GPT2PipelineForwards:
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
)
# Going through held blocks.
@@ -180,7 +231,7 @@ class GPT2PipelineForwards:
block = self.h[i]
torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
@@ -229,7 +280,9 @@ class GPT2PipelineForwards:
# When sequence parallelism done, gather the output tensor in forward and split it in backward
if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
)
if stage_manager.is_last_stage():
@@ -245,7 +298,13 @@ class GPT2PipelineForwards:
if not return_dict:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
@@ -333,7 +392,9 @@ class GPT2PipelineForwards:
shift_labels = shift_labels.view(-1)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
)
else:
loss = loss_fct(shift_logits, shift_labels)
@@ -733,27 +794,18 @@ class GPT2PipelineForwards:
def get_gpt2_flash_attention_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def split_heads(tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor
def forward(
self: GPT2Attention,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
attention_mask: Optional[dict] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[dict] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
assert head_mask is None, "FlashAttention does not support head_mask"
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
@@ -766,10 +818,9 @@ def get_gpt2_flash_attention_forward():
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query = split_heads(query, self.num_heads, self.head_dim)
key = split_heads(key, self.num_heads, self.head_dim)
value = split_heads(value, self.num_heads, self.head_dim)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past
@@ -781,29 +832,14 @@ def get_gpt2_flash_attention_forward():
else:
present = None
if not self.is_cross_attention:
attn_mask_type = AttnMaskType.causal
flash_attention_mask = None
if attention_mask != None:
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
if not torch.all(flash_attention_mask):
if attn_mask_type == AttnMaskType.causal:
attn_mask_type == AttnMaskType.paddedcausal
else:
attn_mask_type = AttnMaskType.padding
scale = value.size(-1) ** -0.5
scale = 1.0
if self.scale_attn_weights:
scale /= value.size(-1) ** 0.5
if self.scale_attn_by_inverse_layer_idx:
scale = scale * (1 / float(self.layer_idx + 1))
# use coloattention
if not hasattr(self, "attention"):
self.attention = ColoAttention(
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
)
attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
scale /= float(self.layer_idx + 1)
dropout_p = self.attn_dropout.p if self.training else 0.0
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present, None)
@@ -813,6 +849,195 @@ def get_gpt2_flash_attention_forward():
return forward
def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig):
def forward(
self: GPT2Model,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
attention_mask, encoder_attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
past_key_values,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
return forward
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
@@ -842,10 +1067,10 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
@@ -862,40 +1087,14 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# GPT2Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
@@ -914,6 +1113,15 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
attention_mask, encoder_attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
past_key_values,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
if self.gradient_checkpointing and self.training:
if use_cache:
@@ -931,7 +1139,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -942,7 +1152,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
@@ -996,7 +1206,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
# When sequence parallelism done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
)
hidden_states = self.ln_f(hidden_states)
@@ -1008,7 +1220,13 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
if not return_dict:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)