Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -17,10 +17,10 @@ from colossalai.pipeline.stage_manager import PipelineStageManager


 class T5PipelineForwards:
-    '''
+    """
     This class serves as a micro library for forward function substitution of
     T5 models under pipeline setting.
-    '''
+    """

     @staticmethod
     def t5_stack_forward(
@@ -44,7 +44,6 @@ class T5PipelineForwards:
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
     ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
-
         # This function is modified on the basis of transformers.models.t5.modeling_t5.T5Stack.forward.
         # Please refer to original code of transformers for more details.

@@ -52,16 +51,16 @@ class T5PipelineForwards:

         # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
         if past_key_values:
-            logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+            logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.")
             past_key_values = None
         if output_attentions:
-            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
             output_attentions = False
         if output_hidden_states:
-            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
             output_hidden_states = False
         if use_cache:
-            logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
         if use_cache is True:
             if not in_decoder:
@@ -69,7 +68,8 @@ class T5PipelineForwards:
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
                 use_cache = False

         stage = stage_manager.stage
@@ -97,7 +97,8 @@ class T5PipelineForwards:
             else:
                 err_msg_prefix = "decoder_" if in_decoder else ""
                 raise ValueError(
-                    f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
+                    f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
+                )
             if inputs_embeds is None:
                 if self.embed_tokens is None:
                     raise ValueError("You have to initialize the model with valid token embeddings")
@@ -108,7 +109,8 @@ class T5PipelineForwards:
         else:
             if hidden_states is None:
                 raise ValueError(
-                    "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
+                    "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder."
+                )
             input_shape = hidden_states.size()[:-1]
             batch_size, seq_length = input_shape[0], input_shape[1]
             device = hidden_states.device
@@ -153,7 +155,6 @@ class T5PipelineForwards:
         start_idx, end_idx = stage_index[0], stage_index[1]

         for i in range(start_idx, end_idx):
-
             past_key_value = past_key_values[i]
             layer_module = self.block[i]
             layer_head_mask = head_mask[i]
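Context for the hunk above: `stage_index` carries the `[start_idx, end_idx)` range of T5 blocks assigned to the current pipeline stage, so each rank only iterates over its own slice of `self.block`. A minimal standalone sketch of that idea (made-up layer sizes and partitioning, not ColossalAI's stage manager API):

```python
import torch
import torch.nn as nn

# hypothetical: 12 blocks split evenly over 4 pipeline stages
blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(12))
num_stages, stage = 4, 1  # pretend this process owns stage 1
per_stage = len(blocks) // num_stages
start_idx, end_idx = stage * per_stage, (stage + 1) * per_stage


def stage_forward(hidden_states: torch.Tensor) -> torch.Tensor:
    # run only the layers owned by this stage; the rest live on other ranks
    for i in range(start_idx, end_idx):
        hidden_states = blocks[i](hidden_states)
    return hidden_states


out = stage_forward(torch.randn(2, 16))
```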
@@ -163,7 +164,6 @@ class T5PipelineForwards:
             if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):
-
                     def custom_forward(*inputs):
                         return tuple(module(*inputs, use_cache, output_attentions))

@@ -179,7 +179,7 @@ class T5PipelineForwards:
                     encoder_decoder_position_bias,
                     layer_head_mask,
                     cross_attn_layer_head_mask,
-                    None,    # past_key_value is always None with gradient checkpointing
+                    None,  # past_key_value is always None with gradient checkpointing
                 )
             else:
                 layer_outputs = layer_module(
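The `create_custom_forward` closure in the hunks above exists because `torch.utils.checkpoint.checkpoint` only passes positional tensors through to the wrapped function; non-tensor flags such as `use_cache` and `output_attentions` are captured by the closure instead. A rough sketch of the same pattern with a stand-in layer (illustrative only, not the real T5 block signature):

```python
import torch
from torch.utils.checkpoint import checkpoint


def layer(x, mask, use_cache, output_attentions):
    # stand-in for a transformer block that also takes non-tensor flags
    return (x * 2 + mask,)


use_cache, output_attentions = False, False


def create_custom_forward(module):
    def custom_forward(*inputs):
        # the flags are closed over, so checkpoint() only ever sees tensors
        return tuple(module(*inputs, use_cache, output_attentions))

    return custom_forward


x = torch.randn(2, 4, requires_grad=True)
mask = torch.zeros(2, 4)
(out,) = checkpoint(create_custom_forward(layer), x, mask)
```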
@@ -220,13 +220,17 @@ class T5PipelineForwards:
             hidden_states = self.dropout(hidden_states)

             if not return_dict:
-                return tuple(v for v in [
-                    hidden_states,
-                    present_key_value_states,
-                    all_hidden_states,
-                    all_attentions,
-                    all_cross_attentions,
-                ] if v is not None)
+                return tuple(
+                    v
+                    for v in [
+                        hidden_states,
+                        present_key_value_states,
+                        all_hidden_states,
+                        all_attentions,
+                        all_cross_attentions,
+                    ]
+                    if v is not None
+                )
             return BaseModelOutputWithPastAndCrossAttentions(
                 last_hidden_state=hidden_states,
                 past_key_values=present_key_value_states,
@@ -236,10 +240,10 @@ class T5PipelineForwards:
             )
         else:
             return {
-                'hidden_states': hidden_states,
-                'position_bias': position_bias,
-                'encoder_decoder_position_bias': encoder_decoder_position_bias,
-                'backward_tensor_keys': ['hidden_states']
+                "hidden_states": hidden_states,
+                "position_bias": position_bias,
+                "encoder_decoder_position_bias": encoder_decoder_position_bias,
+                "backward_tensor_keys": ["hidden_states"],
             }

     @staticmethod
@@ -269,7 +273,6 @@ class T5PipelineForwards:
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
-
         # This function is modified on the basis of transformers.models.t5.modeling_t5.T5Model.forward.
         # Please refer to original code of transformers for more details.

@@ -287,16 +290,16 @@ class T5PipelineForwards:

         # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
         if past_key_values:
-            logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+            logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.")
             past_key_values = None
         if output_attentions:
-            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
             output_attentions = False
         if output_hidden_states:
-            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
             output_hidden_states = False
         if use_cache:
-            logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False

         # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
@@ -322,10 +325,11 @@ class T5PipelineForwards:
                 position_bias=position_bias,
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
                 stage_index=stage_index,
-                decoder_starting_stage=decoder_starting_stage)
+                decoder_starting_stage=decoder_starting_stage,
+            )
             if stage_manager.stage == decoder_starting_stage - 1:
                 # last stage of encoder
-                return {'encoder_hidden_states': encoder_outputs[0]}
+                return {"encoder_hidden_states": encoder_outputs[0]}
             else:
                 return encoder_outputs

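As the hunk above shows, a stage that is not the final one of the model returns a plain dict (e.g. `{"encoder_hidden_states": encoder_outputs[0]}`) instead of a `ModelOutput`, so the pipeline schedule can ship the named tensors to the next stage and re-inject them as keyword arguments. A toy illustration of that hand-off (hypothetical helper names, not the ColossalAI scheduler):

```python
import torch


def encoder_last_stage(hidden_states: torch.Tensor) -> dict:
    # the final encoder stage packages exactly what the first decoder stage needs
    return {"encoder_hidden_states": hidden_states}


def decoder_first_stage(decoder_input: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
    # the next stage receives the dict entries back as keyword arguments
    return decoder_input + encoder_hidden_states.mean()


stage_output = encoder_last_stage(torch.randn(2, 8, 16))
result = decoder_first_stage(torch.randn(2, 8, 16), **stage_output)
```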
@@ -360,23 +364,26 @@ class T5PipelineForwards:
             position_bias=position_bias,
             encoder_decoder_position_bias=encoder_decoder_position_bias,
             stage_index=stage_index,
-            decoder_starting_stage=decoder_starting_stage)
+            decoder_starting_stage=decoder_starting_stage,
+        )

         # Directly return outputs of overloaded T5Stack forward if not at last stage.
         if not at_last_decoder_stage:
             # encoder_hidden_states should be passed to the next stage
-            decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
+            decoder_outputs["encoder_hidden_states"] = encoder_hidden_states
             return decoder_outputs

         if not return_dict:
             return decoder_outputs + encoder_hidden_states
         else:
-            return Seq2SeqModelOutput(last_hidden_state=decoder_outputs.last_hidden_state,
-                                      past_key_values=decoder_outputs.past_key_values,
-                                      decoder_hidden_states=decoder_outputs.hidden_states,
-                                      decoder_attentions=decoder_outputs.attentions,
-                                      cross_attentions=decoder_outputs.cross_attentions,
-                                      encoder_last_hidden_state=encoder_hidden_states)
+            return Seq2SeqModelOutput(
+                last_hidden_state=decoder_outputs.last_hidden_state,
+                past_key_values=decoder_outputs.past_key_values,
+                decoder_hidden_states=decoder_outputs.hidden_states,
+                decoder_attentions=decoder_outputs.attentions,
+                cross_attentions=decoder_outputs.cross_attentions,
+                encoder_last_hidden_state=encoder_hidden_states,
+            )

     @staticmethod
     def t5_for_conditional_generation_forward(
@@ -406,7 +413,6 @@ class T5PipelineForwards:
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
-
         # This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward.
         # Please refer to original code of transformers for more details.

@@ -424,16 +430,16 @@ class T5PipelineForwards:

         # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
         if past_key_values:
-            logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+            logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.")
             past_key_values = None
         if output_attentions:
-            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
             output_attentions = False
         if output_hidden_states:
-            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
             output_hidden_states = False
         if use_cache:
-            logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False

         # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
@@ -460,10 +466,11 @@ class T5PipelineForwards:
                 position_bias=position_bias,
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
                 stage_index=stage_index,
-                decoder_starting_stage=decoder_starting_stage)
+                decoder_starting_stage=decoder_starting_stage,
+            )
             if stage_manager.stage == decoder_starting_stage - 1:
                 # last stage of encoder
-                return {'encoder_hidden_states': encoder_outputs[0]}
+                return {"encoder_hidden_states": encoder_outputs[0]}
             else:
                 return encoder_outputs

@@ -502,12 +509,13 @@ class T5PipelineForwards:
             position_bias=position_bias,
             encoder_decoder_position_bias=encoder_decoder_position_bias,
             stage_index=stage_index,
-            decoder_starting_stage=decoder_starting_stage)
+            decoder_starting_stage=decoder_starting_stage,
+        )

         # Directly return outputs of overloaded T5Stack forward if not at last stage.
         if not at_last_decoder_stage:
             # encoder_hidden_states should be passed to the next stage
-            decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
+            decoder_outputs["encoder_hidden_states"] = encoder_hidden_states
             return decoder_outputs

         sequence_output = decoder_outputs[0]
@@ -530,13 +538,15 @@ class T5PipelineForwards:
             output = (lm_logits,) + decoder_outputs[1:] + encoder_hidden_states
             return ((loss,) + output) if loss is not None else output

-        return Seq2SeqLMOutput(loss=loss,
-                               logits=lm_logits,
-                               past_key_values=decoder_outputs.past_key_values,
-                               decoder_hidden_states=decoder_outputs.hidden_states,
-                               decoder_attentions=decoder_outputs.attentions,
-                               cross_attentions=decoder_outputs.cross_attentions,
-                               encoder_last_hidden_state=encoder_hidden_states)
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_hidden_states,
+        )

     @staticmethod
     def t5_encoder_model_forward(
@@ -562,26 +572,27 @@ class T5PipelineForwards:
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        outputs = T5PipelineForwards.t5_stack_forward(self.encoder,
-                                                      input_ids=input_ids,
-                                                      attention_mask=attention_mask,
-                                                      inputs_embeds=inputs_embeds,
-                                                      head_mask=head_mask,
-                                                      output_attentions=output_attentions,
-                                                      output_hidden_states=output_hidden_states,
-                                                      return_dict=return_dict,
-                                                      stage_manager=stage_manager,
-                                                      hidden_states=hidden_states,
-                                                      position_bias=position_bias,
-                                                      encoder_decoder_position_bias=encoder_decoder_position_bias,
-                                                      stage_index=stage_index,
-                                                      decoder_starting_stage=decoder_starting_stage)
+        outputs = T5PipelineForwards.t5_stack_forward(
+            self.encoder,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            stage_manager=stage_manager,
+            hidden_states=hidden_states,
+            position_bias=position_bias,
+            encoder_decoder_position_bias=encoder_decoder_position_bias,
+            stage_index=stage_index,
+            decoder_starting_stage=decoder_starting_stage,
+        )

         return outputs


 def get_t5_flash_attention_forward():
-
     try:
         from xformers.ops import memory_efficient_attention as me_attention
     except:
@@ -655,19 +666,21 @@ def get_t5_flash_attention_forward():
            return hidden_states

        # get query states
-        query_states = shape(self.q(hidden_states))    # (batch_size, n_heads, seq_length, dim_per_head)
+        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
-        key_states = project(hidden_states, self.k, key_value_states,
-                             past_key_value[0] if past_key_value is not None else None)
-        value_states = project(hidden_states, self.v, key_value_states,
-                               past_key_value[1] if past_key_value is not None else None)
+        key_states = project(
+            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
+        )
+        value_states = project(
+            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
+        )

        if position_bias is None:
            if not self.has_relative_attention_bias:
-                position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length),
-                                            device=query_states.device,
-                                            dtype=query_states.dtype)
+                position_bias = torch.zeros(
+                    (1, self.n_heads, real_seq_length, key_length), device=query_states.device, dtype=query_states.dtype
+                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
@@ -676,10 +689,10 @@ def get_t5_flash_attention_forward():
                # if key and values are already calculated
                # we want only the last query position bias
                if past_key_value is not None:
-                    position_bias = position_bias[:, :, -hidden_states.size(1):, :]
+                    position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
-                position_bias = position_bias + mask    # (batch_size, n_heads, seq_length, key_length)
+                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
@@ -689,12 +702,9 @@ def get_t5_flash_attention_forward():
            position_bias_masked = position_bias

        position_bias_masked = position_bias_masked.contiguous()
-        attn_output = me_attention(query_states,
-                                   key_states,
-                                   value_states,
-                                   attn_bias=position_bias_masked,
-                                   p=self.dropout,
-                                   scale=1.0)
+        attn_output = me_attention(
+            query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0
+        )
        attn_output = unshape(attn_output)
        attn_output = self.o(attn_output)

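For reference, `me_attention` here is `xformers.ops.memory_efficient_attention` (see the guarded import above), which consumes `(batch, seq_len, n_heads, head_dim)` tensors and an additive attention bias; `scale=1.0` matches T5's unscaled dot-product attention. A minimal sketch of such a call with a dense bias tensor (shapes are illustrative; the xformers kernels generally require a CUDA device and a contiguous, suitably aligned bias):

```python
import torch
from xformers.ops import memory_efficient_attention

batch, seq_len, n_heads, head_dim = 2, 16, 8, 64
q = torch.randn(batch, seq_len, n_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
# additive bias, e.g. T5's relative position bias plus the padding mask
bias = torch.zeros(batch, n_heads, seq_len, seq_len, device="cuda", dtype=torch.float16).contiguous()

out = memory_efficient_attention(q, k, v, attn_bias=bias, p=0.0, scale=1.0)
print(out.shape)  # (batch, seq_len, n_heads, head_dim)
```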
@@ -708,7 +718,6 @@ def get_t5_flash_attention_forward():


 def get_jit_fused_T5_layer_ff_forward():
-
     from transformers.models.t5.modeling_t5 import T5LayerFF

     def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -721,7 +730,6 @@ def get_jit_fused_T5_layer_ff_forward():


 def get_T5_layer_self_attention_forward():
-
     from transformers.models.t5.modeling_t5 import T5LayerSelfAttention

     def forward(
@@ -745,14 +753,13 @@ def get_T5_layer_self_attention_forward():
             output_attentions=output_attentions,
         )
         hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
-        outputs = (hidden_states,) + attention_output[1:]    # add attentions if we output them
+        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
         return outputs

     return forward


 def get_T5_layer_cross_attention_forward():
-
     from transformers.models.t5.modeling_t5 import T5LayerCrossAttention

     def forward(
@@ -780,7 +787,7 @@ def get_T5_layer_cross_attention_forward():
             output_attentions=output_attentions,
         )
         layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
-        outputs = (layer_output,) + attention_output[1:]    # add attentions if we output them
+        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
         return outputs

     return forward
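The `self.dropout_add(...)` calls in the two forwards above fuse dropout with the residual add (the `get_jit_fused_*` helpers bind such a fused op onto the layer). A minimal sketch of what a TorchScript-fused version of that helper could look like (an assumption for illustration, not ColossalAI's exact kernel):

```python
import torch
import torch.nn.functional as F


@torch.jit.script
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # dropout + residual add in one scripted function so the JIT can fuse the elementwise ops
    return F.dropout(x, p=prob, training=training) + residual


hidden = torch.randn(2, 8, 512)
attn_out = torch.randn(2, 8, 512)
out = dropout_add(attn_out, hidden, 0.1, True)
```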