Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 05:20:33 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
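The formatting pass described above can be reproduced locally with the standard pre-commit CLI (`pre-commit install`, then `pre-commit run --all-files`). The snippet below is a minimal sketch that drives those commands from Python; the wrapper script itself is purely illustrative and assumes pre-commit is already installed and a .pre-commit-config.yaml sits at the repository root.

# Minimal sketch: reproduce a repository-wide pre-commit formatting pass.
# Assumes `pip install pre-commit` has been run and the repository contains
# a .pre-commit-config.yaml; the wrapper script itself is illustrative only.
import subprocess


def run_pre_commit_on_all_files(repo_root: str = ".") -> None:
    # Install the git hook so future commits are checked automatically.
    subprocess.run(["pre-commit", "install"], cwd=repo_root, check=True)
    # Run every configured hook against every tracked file, as this commit did.
    # A non-zero exit code here usually just means some hooks reformatted files.
    subprocess.run(["pre-commit", "run", "--all-files"], cwd=repo_root)


if __name__ == "__main__":
    run_pre_commit_on_all_files()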
@@ -21,16 +21,17 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 class OPTPipelineForwards:
-    '''
+    """
     This class serves as a micro library for forward function substitution of OPT models
     under pipeline setting.
-    '''
+    """
 
     @staticmethod
     def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         from transformers.models.opt.modeling_opt import _make_causal_mask
 
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
@@ -42,10 +43,12 @@ class OPTPipelineForwards:
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype,
-                                                                  tgt_len=input_shape[-1]).to(device)
-            combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
-                                       combined_attention_mask)
+            expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to(
+                device
+            )
+            combined_attention_mask = (
+                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+            )
 
         return combined_attention_mask
 
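For context, `_prepare_decoder_attention_mask` builds the additive mask the decoder consumes: a causal (lower-triangular) component from `_make_causal_mask`, plus the expanded padding mask from `_expand_mask` when an `attention_mask` is given. The sketch below reproduces that combination with plain PyTorch for the simple case of no cached past key values; it is an illustrative approximation, not the transformers implementation.

# Illustrative sketch of the causal + padding mask combination above
# (assumes no past key values, so tgt_len == src_len == seq_len).
import torch


def combine_masks(attention_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    bsz, seq_len = attention_mask.shape
    min_value = torch.finfo(dtype).min

    # Causal part: token i may only attend to tokens <= i (additive form, 0 = keep).
    causal = torch.triu(torch.full((seq_len, seq_len), min_value, dtype=dtype), diagonal=1)
    causal = causal[None, None, :, :].expand(bsz, 1, seq_len, seq_len)

    # Padding part: [bsz, seq_len] -> [bsz, 1, seq_len, seq_len], masked keys get min_value.
    padding = (1.0 - attention_mask[:, None, None, :].to(dtype)) * min_value
    padding = padding.expand(bsz, 1, seq_len, seq_len)

    # The decoder adds both contributions, just as combined_attention_mask is built above.
    return causal + padding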
@@ -79,17 +82,19 @@ class OPTPipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        '''
+        """
         This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
-        '''
+        """
 
         from transformers.modeling_outputs import BaseModelOutputWithPast
         from transformers.utils import logging
 
         logger = logging.get_logger(__name__)
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (output_hidden_states
-                                if output_hidden_states is not None else self.config.output_hidden_states)
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -133,10 +138,12 @@ class OPTPipelineForwards:
         elif attention_mask.shape[1] != mask_seq_length:
             raise ValueError(
                 f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
-                f"{mask_seq_length} (sum of the lengths of current and past inputs)")
+                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+            )
 
-        causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype,
-                                                                                    device, past_key_values_length)
+        causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(
+            attention_mask, input_shape, _dtype, device, past_key_values_length
+        )
 
         if stage_manager.is_first_stage():
             pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
@@ -145,21 +152,22 @@ class OPTPipelineForwards:
         if decoder.gradient_checkpointing and decoder.training:
             if use_cache:
                 logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
                 use_cache = False
 
         # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
         if past_key_values:
-            logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+            logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.")
             past_key_values = None
         if output_attentions:
-            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
             output_attentions = False
         if output_hidden_states:
-            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
             output_hidden_states = False
         if use_cache:
-            logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
             use_cache = False
 
         # decoder layers
@@ -173,7 +181,8 @@ class OPTPipelineForwards:
                 if attn_mask.size()[0] != (len(decoder.layers)):
                     raise ValueError(
                         f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for"
-                        f" {head_mask.size()[0]}.")
+                        f" {head_mask.size()[0]}."
+                    )
 
         start_idx, end_idx = stage_index[0], stage_index[1]
 
@@ -195,7 +204,6 @@ class OPTPipelineForwards:
             if decoder.gradient_checkpointing and decoder.training:
 
                 def create_custom_forward(module):
-
                     def custom_forward(*inputs):
                         # None for past_key_value
                         return module(*inputs, output_attentions, None)
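The `create_custom_forward` closure in this hunk is the standard pattern for `torch.utils.checkpoint`: the checkpointed callable should take only tensor inputs, so non-tensor arguments such as `output_attentions` and the `None` placeholder for `past_key_value` are captured by the closure instead. A self-contained sketch of the same pattern follows; `DummyLayer` is a hypothetical stand-in for an OPT decoder layer, used only to keep the example runnable.

# Sketch of the closure + activation-checkpoint pattern shown above.
# DummyLayer is a hypothetical stand-in for transformers' OPTDecoderLayer.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class DummyLayer(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask, output_attentions, past_key_value):
        # Real decoder layers return a tuple of outputs; mirror that shape here.
        return (self.proj(hidden_states),)


def create_custom_forward(module, output_attentions=False):
    def custom_forward(*inputs):
        # Non-tensor arguments are baked into the closure; None stands for past_key_value.
        return module(*inputs, output_attentions, None)

    return custom_forward


layer = DummyLayer()
hidden_states = torch.randn(2, 4, 16, requires_grad=True)
attention_mask = torch.zeros(2, 1, 4, 4)
# Activations are recomputed during backward instead of being stored during forward.
layer_outputs = checkpoint(create_custom_forward(layer), hidden_states, attention_mask)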
@@ -250,7 +258,7 @@ class OPTPipelineForwards:
                 attentions=all_self_attns,
             )
         else:
-            return {'hidden_states': hidden_states}
+            return {"hidden_states": hidden_states}
 
     @staticmethod
     def opt_for_causal_lm_forward(
@@ -275,8 +283,9 @@ class OPTPipelineForwards:
         """
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (output_hidden_states
-                                if output_hidden_states is not None else self.config.output_hidden_states)
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -319,8 +328,8 @@ class OPTPipelineForwards:
                 attentions=outputs.attentions,
             )
         else:
-            hidden_states = outputs.get('hidden_states')
-            return {'hidden_states': hidden_states}
+            hidden_states = outputs.get("hidden_states")
+            return {"hidden_states": hidden_states}
 
     @staticmethod
     def opt_for_sequence_classification_forward(
@@ -348,19 +357,21 @@ class OPTPipelineForwards:
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
-                                                                    input_ids,
-                                                                    past_key_values=past_key_values,
-                                                                    attention_mask=attention_mask,
-                                                                    head_mask=head_mask,
-                                                                    inputs_embeds=inputs_embeds,
-                                                                    use_cache=use_cache,
-                                                                    output_attentions=output_attentions,
-                                                                    output_hidden_states=output_hidden_states,
-                                                                    return_dict=return_dict,
-                                                                    stage_manager=stage_manager,
-                                                                    hidden_states=hidden_states,
-                                                                    stage_index=stage_index)
+        transformer_outputs = OPTPipelineForwards.opt_model_forward(
+            self.model,
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            stage_manager=stage_manager,
+            hidden_states=hidden_states,
+            stage_index=stage_index,
+        )
 
         if stage_manager.is_last_stage():
             hidden_states = transformer_outputs[0]
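As the surrounding hunks show, `opt_model_forward` is stage-aware: the first pipeline stage embeds `input_ids`, later stages receive `hidden_states` from the previous stage, each stage only runs the decoder layers selected by `stage_index`, and only the last stage assembles the full model output while every other stage returns `{"hidden_states": hidden_states}`. The sketch below paraphrases that control flow; `embed`, `layers`, and `to_output` are hypothetical stand-ins, not the ColossalAI API.

# Schematic control flow of a pipeline-parallel forward, paraphrasing the code above.
# embed, layers, and to_output are hypothetical placeholders for the real modules.
from typing import List, Optional

import torch


def staged_forward(
    stage_manager,                  # exposes is_first_stage() / is_last_stage()
    stage_index: List[int],         # [start_layer, end_layer) owned by this stage
    layers,                         # the full list of decoder layers
    embed,                          # embedding callable, used on the first stage only
    to_output,                      # callable building the final model output
    input_ids: Optional[torch.Tensor] = None,
    hidden_states: Optional[torch.Tensor] = None,
):
    if stage_manager.is_first_stage():
        hidden_states = embed(input_ids)        # the first stage starts from token ids
    # later stages start from the hidden_states handed over by the previous stage

    start_idx, end_idx = stage_index[0], stage_index[1]
    for layer in layers[start_idx:end_idx]:     # run only this stage's slice of layers
        hidden_states = layer(hidden_states)

    if stage_manager.is_last_stage():
        return to_output(hidden_states)         # e.g. a BaseModelOutputWithPast
    return {"hidden_states": hidden_states}     # intermediate stages pass activations on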
@@ -377,7 +388,8 @@ class OPTPipelineForwards:
                 sequence_lengths = -1
                 logger.warning(
                     f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`")
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
 
             pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
@@ -416,8 +428,8 @@ class OPTPipelineForwards:
                 attentions=transformer_outputs.attentions,
             )
         else:
-            hidden_states = transformer_outputs.get('hidden_states')
-            return {'hidden_states': hidden_states}
+            hidden_states = transformer_outputs.get("hidden_states")
+            return {"hidden_states": hidden_states}
 
     @staticmethod
     def opt_for_question_answering_forward(
@@ -443,19 +455,21 @@ class OPTPipelineForwards:
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
-                                                                    input_ids,
-                                                                    past_key_values=past_key_values,
-                                                                    attention_mask=attention_mask,
-                                                                    head_mask=head_mask,
-                                                                    inputs_embeds=inputs_embeds,
-                                                                    use_cache=use_cache,
-                                                                    output_attentions=output_attentions,
-                                                                    output_hidden_states=output_hidden_states,
-                                                                    return_dict=return_dict,
-                                                                    stage_manager=stage_manager,
-                                                                    hidden_states=hidden_states,
-                                                                    stage_index=stage_index)
+        transformer_outputs = OPTPipelineForwards.opt_model_forward(
+            self.model,
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            stage_manager=stage_manager,
+            hidden_states=hidden_states,
+            stage_index=stage_index,
+        )
         if stage_manager.is_last_stage():
             hidden_states = transformer_outputs[0]
 
@@ -493,12 +507,11 @@ class OPTPipelineForwards:
                 attentions=transformer_outputs.attentions,
             )
         else:
-            hidden_states = transformer_outputs.get('hidden_states')
-            return {'hidden_states': hidden_states}
+            hidden_states = transformer_outputs.get("hidden_states")
+            return {"hidden_states": hidden_states}
 
 
 def get_opt_flash_attention_forward():
-
     from transformers.models.opt.modeling_opt import OPTAttention
 
     from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
@@ -555,27 +568,27 @@ def get_opt_flash_attention_forward():
         src_len = key_states.size(1)
         if layer_head_mask != None:
             if layer_head_mask.size() != (self.num_heads,):
-                raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
-                                 f" {layer_head_mask.size()}")
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
 
         flash_attention_mask = None
         attn_mask_type = AttnMaskType.causal
         if attention_mask != None:
             if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                 raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}")
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
             attn_mask_type = AttnMaskType.paddedcausal
 
-        attention = ColoAttention(embed_dim=self.embed_dim,
-                                  num_heads=self.num_heads,
-                                  dropout=self.dropout,
-                                  scale=self.scaling)
-        attn_output = attention(query_states,
-                                key_states,
-                                value_states,
-                                attn_mask=flash_attention_mask,
-                                attn_mask_type=attn_mask_type)
+        attention = ColoAttention(
+            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
+        )
+        attn_output = attention(
+            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
+        )
 
         attn_output = self.out_proj(attn_output)
         return attn_output, None, past_key_value
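The replacement forward above swaps `OPTAttention`'s manual attention computation for ColossalAI's fused `ColoAttention` kernel: the additive HuggingFace mask is collapsed into `flash_attention_mask`, the mask type switches from `causal` to `paddedcausal` when an attention mask is supplied, and the kernel is called on the projected query/key/value states. The sketch below assembles only the calls visible in this diff into a standalone snippet; the q/k/v layout of [batch, seq_len, num_heads, head_dim], the fp16/CUDA setup, and the dummy tensor shapes are assumptions made for illustration.

# Hedged sketch of the ColoAttention call pattern used in the diff above.
# Constructor and call arguments mirror the code; the tensor layout and
# fp16/CUDA requirements are assumptions, not documented API guarantees.
import torch

from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

bsz, tgt_len, num_heads, head_dim = 2, 8, 4, 16
embed_dim = num_heads * head_dim

query_states = torch.randn(bsz, tgt_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
key_states = torch.randn_like(query_states)
value_states = torch.randn_like(query_states)

# Additive HuggingFace-style mask of shape [bsz, 1, tgt_len, src_len]; all zeros = nothing masked.
attention_mask = torch.zeros(bsz, 1, tgt_len, tgt_len, device="cuda", dtype=torch.float16)
# Converted into a boolean flash-attention mask exactly as in the forward above.
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal

attention = ColoAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=0.0, scale=head_dim**-0.5)
attn_output = attention(
    query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
)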
@@ -584,7 +597,6 @@ def get_opt_flash_attention_forward():
 
 
 def get_jit_fused_opt_decoder_layer_forward():
-
     from transformers.models.opt.modeling_opt import OPTDecoderLayer
 
     def forward(