[Feature] Split cross-entropy computation in SP (#5959)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* adapt chatglm, command-R, qwen

* debug

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* add comments

* q1 index only once

* remove events to simplify stream sync

* simplify forward/backward logic

* 2d ring forward passed

* 2d ring backward passed

* fixes

* fix ring attn loss

* 2D ring backward + llama passed

* merge

* update logger

* fix typo

* rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* remove typos

* fixes

* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Wenxuan Tan
2024-09-10 12:06:50 +08:00
committed by GitHub
parent b3db1058ec
commit 8fd25d6e09
25 changed files with 527 additions and 1173 deletions

View File

@@ -21,8 +21,9 @@ from transformers.models.gpt2.modeling_gpt2 import (
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.layer import ColoAttention, RingAttention
from colossalai.shardformer.layer._operation import gather_sp_output, split_forward_gather_backward
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
from colossalai.shardformer.shard import ShardConfig
from ..layer import dist_cross_entropy
@@ -39,10 +40,16 @@ def _get_attention_mask(
encoder_hidden_states: Optional[torch.Tensor],
encoder_attention_mask: Optional[torch.FloatTensor],
) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
batch_size, seq_len = hidden_states.shape[:2]
# Received input is already split for non-first pipeline stages,
# but attn mask isn't
batch_size = hidden_states.size(0)
seq_len = attention_mask.size(-1)
sp_mode = shard_config.sequence_parallelism_mode
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
assert not sp_mode == "ring_attn", "Ring Attention only supports decoder-only."
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
if shard_config.enable_flash_attention:
encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
@@ -62,6 +69,7 @@ def _get_attention_mask(
encoder_attention_mask = {"attention_mask": None}
else:
encoder_attention_mask = None
# GPT2Attention mask.
past_key_values_length = 0
if past_key_values is not None and past_key_values[0] is not None:
@@ -69,6 +77,7 @@ def _get_attention_mask(
if shard_config.enable_flash_attention:
if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = ColoAttention.prepare_attn_kwargs(
(batch_size, 1, seq_len, seq_len + past_key_values_length),
hidden_states.dtype,
@@ -123,6 +132,7 @@ class GPT2PipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
force_sp_gather: Optional[bool] = True,
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details.
@@ -146,16 +156,15 @@ class GPT2PipelineForwards:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if stage_manager.is_first_stage():
disable_pp = stage_manager is None
if disable_pp or stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
@@ -176,7 +185,7 @@ class GPT2PipelineForwards:
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if stage_manager.is_first_stage():
if disable_pp or stage_manager.is_first_stage():
if position_ids is None:
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
@@ -190,9 +199,7 @@ class GPT2PipelineForwards:
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
attention_mask, encoder_attention_mask = _get_attention_mask(
attn_kwargs, encoder_attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
@@ -215,23 +222,43 @@ class GPT2PipelineForwards:
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
if disable_pp or stage_manager.is_first_stage():
# Ring Attention's special zigzag batch processing
if sp_mode == "ring_attn":
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
if not attention_mask.bool().all():
hidden_states, attn_kwargs, position_ids = RingAttention.prepare_varlen_batch(
attention_mask, sp_group, hidden_states, position_ids
)
else:
hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
# Other sp modes
else:
if sp_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
)
elif sp_mode == "ring_attn":
# Later stages already received split hidden states
_, attn_kwargs, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group)
del attention_mask
# Going through held blocks.
start_idx, end_idx = stage_index[0], stage_index[1]
if disable_pp:
start_idx, end_idx = 0, len(self.h)
else:
start_idx, end_idx = stage_index[0], stage_index[1]
for i in range(start_idx, end_idx):
block = self.h[i]
torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states
if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device)
if torch.is_tensor(attn_kwargs):
attn_kwargs = attn_kwargs.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
@@ -242,7 +269,7 @@ class GPT2PipelineForwards:
block.__call__,
hidden_states,
None,
attention_mask,
attn_kwargs,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
@@ -253,7 +280,7 @@ class GPT2PipelineForwards:
outputs = block(
hidden_states,
layer_past=None,
attention_mask=attention_mask,
attention_mask=attn_kwargs,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
@@ -270,26 +297,25 @@ class GPT2PipelineForwards:
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# When sequence parallelism done, gather the output tensor in forward and split it in backward
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
# When sequence parallelism is done, gather the output tensor in forward and split it in backward
gather_output = (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode)
if disable_pp or stage_manager.is_last_stage():
if gather_output:
hidden_states = gather_sp_output(hidden_states, shard_config)
if stage_manager.is_last_stage():
# gather_sp_output could've changed seq length.
input_shape = (*input_shape[:-1], hidden_states.size(-2))
output_shape = input_shape + (hidden_states.size(-1),)
if disable_pp or stage_manager.is_last_stage():
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
if disable_pp or stage_manager.is_last_stage():
if not return_dict:
return tuple(
v
@@ -366,17 +392,29 @@ class GPT2PipelineForwards:
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
force_sp_gather=False,
)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
disable_pp = stage_manager is None
if (not disable_pp) and (not stage_manager.is_last_stage()):
return {"hidden_states": outputs["hidden_states"]}
hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)
if shard_config.sequence_parallelism_mode == "ring_attn":
# Split labels in a zigzag fashion too
sp_group = shard_config.sequence_parallel_process_group
if not attention_mask.bool().all():
# [B, max_seqlen // sp_size]
labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
else:
labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
if labels is not None:
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
)
if not return_dict:
output = (lm_logits,) + outputs[1:]
@@ -770,7 +808,7 @@ class GPT2PipelineForwards:
)
def get_gpt2_flash_attention_forward():
def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None):
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
def forward(
@@ -817,7 +855,22 @@ def get_gpt2_flash_attention_forward():
if self.scale_attn_by_inverse_layer_idx:
scale /= float(self.layer_idx + 1)
dropout_p = self.attn_dropout.p if self.training else 0.0
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
if sp_mode == "ring_attn":
attn_output = RingAttention.attention(
query,
key,
value,
sp_group,
**attention_mask,
dropout_p=dropout_p,
scale=scale,
inner_ring_size=shard_config.inner_ring_size,
)
else:
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
@@ -828,466 +881,6 @@ def get_gpt2_flash_attention_forward():
return forward
def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig):
def forward(
self: GPT2Model,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
attention_mask, encoder_attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
past_key_values,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
return forward
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
attention_mask, encoder_attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
past_key_values,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger = logging.get_logger(__name__)
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
# When sequence parallelism done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import GPT2LMHeadModel
def forward(
self: GPT2LMHeadModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
return forward
def get_jit_fused_gpt2_mlp_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP