mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462) * [extension] update api * [feature] add colo attention * [feature] update sdpa * [feature] update npu attention * [feature] update flash-attn * [test] add flash attn test * [test] update flash attn test * [shardformer] update modeling to fit colo attention (#5465) * [misc] refactor folder structure * [shardformer] update llama flash-attn * [shardformer] fix llama policy * [devops] update tensornvme install * [test] update llama test * [shardformer] update colo attn kernel dispatch * [shardformer] update blip2 * [shardformer] update chatglm * [shardformer] update gpt2 * [shardformer] update gptj * [shardformer] update opt * [shardformer] update vit * [shardformer] update colo attention mask prep * [shardformer] update whisper * [test] fix shardformer tests (#5514) * [test] fix shardformer tests * [test] fix shardformer tests
This commit is contained in:
@@ -21,12 +21,82 @@ from transformers.models.gpt2.modeling_gpt2 import (
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import ColoAttention
|
||||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import cross_entropy_1d
|
||||
from ..layer._operation import gather_forward_split_backward
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _get_attention_mask(
|
||||
self: GPT2Model,
|
||||
shard_config: ShardConfig,
|
||||
hidden_states: torch.Tensor,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
|
||||
attention_mask: Optional[torch.FloatTensor],
|
||||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
encoder_attention_mask: Optional[torch.FloatTensor],
|
||||
) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
|
||||
batch_size, seq_len = hidden_states.shape[:2]
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
if shard_config.enable_flash_attention:
|
||||
encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
(encoder_batch_size, 1, seq_len, encoder_sequence_length),
|
||||
dtype=hidden_states.dtype,
|
||||
dtype2=encoder_hidden_states.dtype,
|
||||
q_padding_mask=attention_mask,
|
||||
kv_padding_mask=encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device)
|
||||
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
if shard_config.enable_flash_attention:
|
||||
encoder_attention_mask = {"attention_mask": None}
|
||||
else:
|
||||
encoder_attention_mask = None
|
||||
# GPT2Attention mask.
|
||||
past_key_values_length = 0
|
||||
if past_key_values is not None and past_key_values[0] is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
if shard_config.enable_flash_attention:
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
(batch_size, 1, seq_len, seq_len + past_key_values_length),
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
attention_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
elif attention_mask is not None:
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size has to be defined and > 0")
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
return attention_mask, encoder_attention_mask
|
||||
|
||||
|
||||
class GPT2PipelineForwards:
|
||||
"""
|
||||
@@ -83,10 +153,10 @@ class GPT2PipelineForwards:
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
batch_size = input_ids.shape[0]
|
||||
input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
@@ -99,38 +169,7 @@ class GPT2PipelineForwards:
|
||||
input_shape = hidden_states.size()[:-1]
|
||||
device = hidden_states.device
|
||||
hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# GPT2Attention mask.
|
||||
if attention_mask is not None:
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size has to be defined and > 0")
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_attention_mask = None
|
||||
hidden_states.shape[0]
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
@@ -156,6 +195,16 @@ class GPT2PipelineForwards:
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
attention_mask, encoder_attention_mask = _get_attention_mask(
|
||||
self,
|
||||
shard_config,
|
||||
hidden_states,
|
||||
past_key_values,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
@@ -171,7 +220,9 @@ class GPT2PipelineForwards:
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
|
||||
# Going through held blocks.
|
||||
@@ -180,7 +231,7 @@ class GPT2PipelineForwards:
|
||||
block = self.h[i]
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
if torch.is_tensor(attention_mask):
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
@@ -229,7 +280,9 @@ class GPT2PipelineForwards:
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
@@ -245,7 +298,13 @@ class GPT2PipelineForwards:
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
@@ -333,7 +392,9 @@ class GPT2PipelineForwards:
|
||||
shift_labels = shift_labels.view(-1)
|
||||
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
else:
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
@@ -733,27 +794,18 @@ class GPT2PipelineForwards:
|
||||
def get_gpt2_flash_attention_forward():
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
|
||||
|
||||
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
|
||||
|
||||
def split_heads(tensor, num_heads, attn_head_size):
|
||||
"""
|
||||
Splits hidden_size dim into attn_head_size and num_heads
|
||||
"""
|
||||
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
||||
tensor = tensor.view(new_shape)
|
||||
return tensor
|
||||
|
||||
def forward(
|
||||
self: GPT2Attention,
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[dict] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[dict] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
||||
assert head_mask is None, "FlashAttention does not support head_mask"
|
||||
if encoder_hidden_states is not None:
|
||||
if not hasattr(self, "q_attn"):
|
||||
raise ValueError(
|
||||
@@ -766,10 +818,9 @@ def get_gpt2_flash_attention_forward():
|
||||
attention_mask = encoder_attention_mask
|
||||
else:
|
||||
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
||||
|
||||
query = split_heads(query, self.num_heads, self.head_dim)
|
||||
key = split_heads(key, self.num_heads, self.head_dim)
|
||||
value = split_heads(value, self.num_heads, self.head_dim)
|
||||
query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||
key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
@@ -781,29 +832,14 @@ def get_gpt2_flash_attention_forward():
|
||||
else:
|
||||
present = None
|
||||
|
||||
if not self.is_cross_attention:
|
||||
attn_mask_type = AttnMaskType.causal
|
||||
flash_attention_mask = None
|
||||
if attention_mask != None:
|
||||
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
|
||||
if not torch.all(flash_attention_mask):
|
||||
if attn_mask_type == AttnMaskType.causal:
|
||||
attn_mask_type == AttnMaskType.paddedcausal
|
||||
else:
|
||||
attn_mask_type = AttnMaskType.padding
|
||||
|
||||
scale = value.size(-1) ** -0.5
|
||||
scale = 1.0
|
||||
if self.scale_attn_weights:
|
||||
scale /= value.size(-1) ** 0.5
|
||||
if self.scale_attn_by_inverse_layer_idx:
|
||||
scale = scale * (1 / float(self.layer_idx + 1))
|
||||
|
||||
# use coloattention
|
||||
if not hasattr(self, "attention"):
|
||||
self.attention = ColoAttention(
|
||||
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
|
||||
)
|
||||
|
||||
attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
|
||||
|
||||
scale /= float(self.layer_idx + 1)
|
||||
dropout_p = self.attn_dropout.p if self.training else 0.0
|
||||
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
|
||||
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
||||
attn_output = self.c_proj(attn_output)
|
||||
attn_output = self.resid_dropout(attn_output)
|
||||
outputs = (attn_output, present, None)
|
||||
@@ -813,6 +849,195 @@ def get_gpt2_flash_attention_forward():
|
||||
return forward
|
||||
|
||||
|
||||
def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig):
|
||||
def forward(
|
||||
self: GPT2Model,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_length,
|
||||
input_shape[-1] + past_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
||||
|
||||
attention_mask, encoder_attention_mask = _get_attention_mask(
|
||||
self,
|
||||
shard_config,
|
||||
hidden_states,
|
||||
past_key_values,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if torch.is_tensor(attention_mask):
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
for k, v in self.device_map.items():
|
||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
def forward(
|
||||
self,
|
||||
@@ -842,10 +1067,10 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
batch_size = input_ids.shape[0]
|
||||
input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
@@ -862,40 +1087,14 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
position_ids = torch.arange(
|
||||
past_length,
|
||||
input_shape[-1] + past_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
# GPT2Attention mask.
|
||||
if attention_mask is not None:
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size has to be defined and > 0")
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
@@ -914,6 +1113,15 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
attention_mask, encoder_attention_mask = _get_attention_mask(
|
||||
self,
|
||||
shard_config,
|
||||
hidden_states,
|
||||
past_key_values,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
@@ -931,7 +1139,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
@@ -942,7 +1152,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
if torch.is_tensor(attention_mask):
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
@@ -996,7 +1206,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
@@ -1008,7 +1220,13 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user