mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[shardformer/sequence parallel] support gpt2 seq parallel with pp/dp/tp (#4460)
* support gpt2 seq parallel with pp/dp/tp * fix a bug when waiting for stream done * delete unused gpt2_seq file
This commit is contained in:
@@ -21,6 +21,8 @@ from transformers.models.gpt2.modeling_gpt2 import (
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
|
||||
class GPT2PipelineForwards:
|
||||
@@ -47,7 +49,8 @@ class GPT2PipelineForwards:
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
|
||||
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
@@ -159,6 +162,13 @@ class GPT2PipelineForwards:
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
hidden_states = split_forward_gather_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
|
||||
# Going through held blocks.
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
for i in range(start_idx, end_idx):
|
||||
@@ -212,6 +222,12 @@ class GPT2PipelineForwards:
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
hidden_states = gather_forward_split_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
@@ -257,7 +273,8 @@ class GPT2PipelineForwards:
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
@@ -285,7 +302,8 @@ class GPT2PipelineForwards:
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index)
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
@@ -335,7 +353,8 @@ class GPT2PipelineForwards:
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
|
||||
r"""
|
||||
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
|
||||
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
|
||||
@@ -367,7 +386,8 @@ class GPT2PipelineForwards:
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index)
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
@@ -421,7 +441,8 @@ class GPT2PipelineForwards:
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
|
||||
r"""
|
||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
@@ -449,7 +470,8 @@ class GPT2PipelineForwards:
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index)
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
@@ -508,7 +530,8 @@ class GPT2PipelineForwards:
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, TokenClassifierOutput]:
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None) -> Union[Dict, Tuple, TokenClassifierOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
@@ -534,7 +557,8 @@ class GPT2PipelineForwards:
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index)
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
@@ -578,7 +602,8 @@ class GPT2PipelineForwards:
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
@@ -613,7 +638,8 @@ class GPT2PipelineForwards:
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index)
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
@@ -696,7 +722,6 @@ def get_gpt2_flash_attention_forward():
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
||||
_, tgt_len, _ = hidden_states.size()
|
||||
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
if not hasattr(self, "q_attn"):
|
||||
@@ -753,3 +778,210 @@ def get_gpt2_flash_attention_forward():
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
batch_size = input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
# GPT2Attention mask.
|
||||
if attention_mask is not None:
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size has to be defined and > 0")
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
hidden_states = split_forward_gather_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
for k, v in self.device_map.items():
|
||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
hidden_states = gather_forward_split_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
return forward
|
||||
|
Reference in New Issue
Block a user