mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-21 13:11:27 +00:00
upgrade_gptj
This commit is contained in:
parent
2223b64931
commit
e1925b36c4
@ -2,6 +2,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
@ -79,7 +80,7 @@ class GPTJPipelineForwards:
|
||||
def gptj_model_forward(
|
||||
self: GPTJModel,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@ -89,12 +90,13 @@ class GPTJPipelineForwards:
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Dict, Tuple, BaseModelOutputWithPast]:
|
||||
# This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJModel.forward.
|
||||
# This function is modified on the v4.51.3 transformers.models.gptj.modeling_gptj.GPTJModel.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
# GPTJ has no cross attention in comparison to GPT2
|
||||
|
||||
@ -118,8 +120,8 @@ class GPTJPipelineForwards:
|
||||
use_cache = False
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
input_shape = input_ids.size()
|
||||
@ -130,17 +132,34 @@ class GPTJPipelineForwards:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, seq_length)
|
||||
else:
|
||||
if hidden_states is None:
|
||||
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
|
||||
input_shape = hidden_states.size()[:-1]
|
||||
batch_size, seq_length = input_shape[0], input_shape[1]
|
||||
device = hidden_states.device
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, seq_length)
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
seq_length = hidden_states.shape[1]
|
||||
if cache_position is None:
|
||||
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, hidden_states, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
@ -148,33 +167,9 @@ class GPTJPipelineForwards:
|
||||
# head_mask has shape n_layer x batch x num_attention_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
# position id to be assigned not just for the first stage for attn input
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
if stage_manager.is_first_stage():
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states)
|
||||
output_shape = (-1, seq_length, hidden_states.size(-1))
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
attention_mask = _get_attention_mask(
|
||||
self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
next_decoder_cache = None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
@ -207,29 +202,26 @@ class GPTJPipelineForwards:
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
output_attentions,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states=hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=attention_mask,
|
||||
layer_past=past_key_values,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
@ -248,22 +240,17 @@ class GPTJPipelineForwards:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
]
|
||||
if v is not None
|
||||
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
@ -275,7 +262,7 @@ class GPTJPipelineForwards:
|
||||
def gptj_causallm_model_forward(
|
||||
self: GPTJForCausalLM,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@ -286,6 +273,7 @@ class GPTJPipelineForwards:
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
@ -315,6 +303,7 @@ class GPTJPipelineForwards:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
@ -326,18 +315,28 @@ class GPTJPipelineForwards:
|
||||
return {"hidden_states": transformer_outputs["hidden_states"]}
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
# Set device for model parallelism
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(self.transformer.first_device)
|
||||
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
||||
|
||||
# v4.51.3 tranformers loss calculation
|
||||
# make sure sampling in fp16 works correctly and
|
||||
# compute loss in fp32 to match with mesh-tf version
|
||||
# https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
|
||||
lm_logits = self.lm_head(hidden_states).to(torch.float32)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(lm_logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
loss = self.loss_function(
|
||||
lm_logits,
|
||||
labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
)
|
||||
|
||||
loss = loss.to(hidden_states.dtype)
|
||||
|
||||
@ -357,7 +356,7 @@ class GPTJPipelineForwards:
|
||||
def gptj_for_sequence_classification_forward(
|
||||
self: GPTJForSequenceClassification,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@ -379,7 +378,7 @@ class GPTJPipelineForwards:
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
# This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification.forward.
|
||||
# This function is modified on the v4.51.3 transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
"""
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -581,6 +580,8 @@ def get_gptj_flash_attention_forward():
|
||||
Tuple[torch.Tensor, Tuple[torch.Tensor]],
|
||||
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
|
||||
]:
|
||||
# This function is modified on the v4.51.3 transformers.models.gptj.modeling_gptj.GPTJAttention.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
assert head_mask is None, "head_mask is not supported for FlashAttention"
|
||||
query = self.q_proj(hidden_states)
|
||||
key = self.k_proj(hidden_states)
|
||||
|
Loading…
Reference in New Issue
Block a user