mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -26,32 +26,32 @@ from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
|
||||
class GPT2PipelineForwards:
|
||||
'''
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of GPT2 models
|
||||
under pipeline setting.
|
||||
'''
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def gpt2_model_forward(
|
||||
self: GPT2Model,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
|
||||
self: GPT2Model,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
|
||||
@@ -62,16 +62,16 @@ class GPT2PipelineForwards:
|
||||
# Preprocess passed in arguments
|
||||
# TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if past_key_values:
|
||||
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.")
|
||||
past_key_values = None
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
@@ -115,7 +115,7 @@ class GPT2PipelineForwards:
|
||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
@@ -156,7 +156,8 @@ class GPT2PipelineForwards:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
@@ -166,9 +167,9 @@ class GPT2PipelineForwards:
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
hidden_states = split_forward_gather_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
|
||||
# Going through held blocks.
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
@@ -186,7 +187,6 @@ class GPT2PipelineForwards:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
@@ -225,9 +225,9 @@ class GPT2PipelineForwards:
|
||||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
if shard_config.enable_sequence_parallelism:
|
||||
hidden_states = gather_forward_split_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
@@ -241,8 +241,10 @@ class GPT2PipelineForwards:
|
||||
if stage_manager.is_last_stage():
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None)
|
||||
v
|
||||
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
@@ -253,62 +255,65 @@ class GPT2PipelineForwards:
|
||||
)
|
||||
else:
|
||||
# always return dict for intermediate stage
|
||||
return {'hidden_states': hidden_states}
|
||||
return {"hidden_states": hidden_states}
|
||||
|
||||
@staticmethod
|
||||
def gpt2_lmhead_model_forward(
|
||||
self: GPT2LMHeadModel,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
self: GPT2LMHeadModel,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
|
||||
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
|
||||
Please refer to original code of transformers for more details.
|
||||
"""
|
||||
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
|
||||
Please refer to original code of transformers for more details.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(
|
||||
self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
return {'hidden_states': outputs['hidden_states']}
|
||||
return {"hidden_states": outputs["hidden_states"]}
|
||||
|
||||
hidden_states = outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
@@ -337,25 +342,26 @@ class GPT2PipelineForwards:
|
||||
|
||||
@staticmethod
|
||||
def gpt2_double_heads_model_forward(
|
||||
self: GPT2DoubleHeadsModel,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
mc_token_ids: Optional[torch.LongTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
mc_labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
|
||||
self: GPT2DoubleHeadsModel,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
mc_token_ids: Optional[torch.LongTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
mc_labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
|
||||
r"""
|
||||
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
|
||||
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
|
||||
@@ -373,26 +379,28 @@ class GPT2PipelineForwards:
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(
|
||||
self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
return {'hidden_states': outputs['hidden_states']}
|
||||
return {"hidden_states": outputs["hidden_states"]}
|
||||
|
||||
hidden_states = outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
@@ -428,22 +436,23 @@ class GPT2PipelineForwards:
|
||||
|
||||
@staticmethod
|
||||
def gpt2_for_question_answering_forward(
|
||||
self: GPT2ForQuestionAnswering,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
start_positions: Optional[torch.LongTensor] = None,
|
||||
end_positions: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
|
||||
self: GPT2ForQuestionAnswering,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
start_positions: Optional[torch.LongTensor] = None,
|
||||
end_positions: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
|
||||
r"""
|
||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
@@ -459,24 +468,26 @@ class GPT2PipelineForwards:
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(
|
||||
self.transformer,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
return {'hidden_states': outputs['hidden_states']}
|
||||
return {"hidden_states": outputs["hidden_states"]}
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
@@ -516,23 +527,24 @@ class GPT2PipelineForwards:
|
||||
|
||||
@staticmethod
|
||||
def gpt2_for_token_classification_forward(
|
||||
self: GPT2ForTokenClassification,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None) -> Union[Dict, Tuple, TokenClassifierOutput]:
|
||||
self: GPT2ForTokenClassification,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Dict, Tuple, TokenClassifierOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
@@ -544,26 +556,28 @@ class GPT2PipelineForwards:
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(
|
||||
self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
return {'hidden_states': outputs['hidden_states']}
|
||||
return {"hidden_states": outputs["hidden_states"]}
|
||||
|
||||
hidden_states = outputs[0]
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
@@ -588,23 +602,24 @@ class GPT2PipelineForwards:
|
||||
|
||||
@staticmethod
|
||||
def gpt2_for_sequence_classification_forward(
|
||||
self: GPT2ForSequenceClassification,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
|
||||
self: GPT2ForSequenceClassification,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
@@ -613,38 +628,41 @@ class GPT2PipelineForwards:
|
||||
|
||||
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
"""
|
||||
"""
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size, _ = input_ids.shape[:2]
|
||||
else:
|
||||
batch_size, _ = hidden_states.shape[:2]
|
||||
assert (self.config.pad_token_id is not None
|
||||
or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
assert (
|
||||
self.config.pad_token_id is not None or batch_size == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config)
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(
|
||||
self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
shard_config=shard_config,
|
||||
)
|
||||
|
||||
# If not at the last stage, return hidden_states as in GPT2Model
|
||||
if not stage_manager.is_last_stage():
|
||||
return {'hidden_states': outputs['hidden_states']}
|
||||
return {"hidden_states": outputs["hidden_states"]}
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
@@ -658,7 +676,8 @@ class GPT2PipelineForwards:
|
||||
sequence_lengths = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`")
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
|
||||
@@ -698,7 +717,6 @@ class GPT2PipelineForwards:
|
||||
|
||||
|
||||
def get_gpt2_flash_attention_forward():
|
||||
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
|
||||
|
||||
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
|
||||
@@ -722,12 +740,12 @@ def get_gpt2_flash_attention_forward():
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
if not hasattr(self, "q_attn"):
|
||||
raise ValueError(
|
||||
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
||||
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")
|
||||
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
|
||||
)
|
||||
|
||||
query = self.q_attn(hidden_states)
|
||||
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
||||
@@ -759,15 +777,14 @@ def get_gpt2_flash_attention_forward():
|
||||
attn_mask_type = AttnMaskType.padding
|
||||
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
|
||||
|
||||
scale = value.size(-1)**-0.5
|
||||
scale = value.size(-1) ** -0.5
|
||||
if self.scale_attn_by_inverse_layer_idx:
|
||||
scale = scale * (1 / float(self.layer_idx + 1))
|
||||
|
||||
# use coloattention
|
||||
attention = ColoAttention(embed_dim=self.embed_dim,
|
||||
num_heads=self.num_heads,
|
||||
dropout=self.attn_dropout.p,
|
||||
scale=scale)
|
||||
attention = ColoAttention(
|
||||
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
|
||||
)
|
||||
|
||||
attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
|
||||
|
||||
@@ -781,7 +798,6 @@ def get_gpt2_flash_attention_forward():
|
||||
|
||||
|
||||
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -799,8 +815,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
@@ -849,7 +866,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
@@ -886,7 +903,8 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
if use_cache:
|
||||
logger = logging.get_logger(__name__)
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
@@ -896,9 +914,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
|
||||
# split the input tensor along sequence dimension
|
||||
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||
hidden_states = split_forward_gather_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
# Model parallel
|
||||
@@ -918,7 +936,6 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
@@ -962,9 +979,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||
hidden_states = gather_forward_split_backward(hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
@@ -974,8 +991,10 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None)
|
||||
v
|
||||
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
|
Reference in New Issue
Block a user