Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 05:20:33 +00:00
[shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462)
* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test
* [shardformer] update modeling to fit colo attention (#5465)
* [misc] refactor folder structure
* [shardformer] update llama flash-attn
* [shardformer] fix llama policy
* [devops] update tensornvme install
* [test] update llama test
* [shardformer] update colo attn kernel dispatch
* [shardformer] update blip2
* [shardformer] update chatglm
* [shardformer] update gpt2
* [shardformer] update gptj
* [shardformer] update opt
* [shardformer] update vit
* [shardformer] update colo attention mask prep
* [shardformer] update whisper
* [test] fix shardformer tests (#5514)
* [test] fix shardformer tests
* [test] fix shardformer tests
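Before the diff itself, here is a minimal sketch of how the new custom-mask API is driven, based only on the two calls visible in the changes below: ColoAttention.prepare_attn_kwargs builds a dict of mask-related kwargs, which is then splatted into ColoAttention.attention. The tensor layout (batch, num_heads, seq_len, head_dim), the dummy sizes, and the CUDA/fp16 setup are illustrative assumptions by the editor, not part of the commit.

# Editor's sketch, not part of the diff: exercising the custom-mask API added here.
import torch

from colossalai.shardformer.layer import ColoAttention

batch_size, num_heads, seq_len, head_dim = 2, 12, 16, 64
dtype, device = torch.float16, "cuda"

# Query/key/value in the (batch, num_heads, seq_len, head_dim) layout assumed above.
q = torch.rand(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
k = torch.rand_like(q)
v = torch.rand_like(q)

# Optional HF-style padding mask of shape (batch, seq_len): 1 = keep, 0 = padding.
padding_mask = torch.ones(batch_size, seq_len, device=device)
padding_mask[0, -4:] = 0

# prepare_attn_kwargs takes the target mask shape (batch, 1, q_len, kv_len), dtype,
# device, the 2D padding mask, and a causality flag, and returns a kwargs dict.
attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (batch_size, 1, seq_len, seq_len),
    dtype,
    device,
    padding_mask,
    is_causal=True,
)

# The dict is splatted into the unified attention entry point, exactly as the
# patched OPTAttention.forward does with **attention_mask further down.
attn_output = ColoAttention.attention(q, k, v, **attn_kwargs, dropout_p=0.0)
print(attn_output.shape)  # expected: (batch_size, num_heads, seq_len, head_dim)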
@@ -18,6 +18,37 @@ from transformers.models.opt.modeling_opt import (
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig

logger = logging.get_logger(__name__)


def _get_attention_mask(
    self: OPTModel,
    shard_config: ShardConfig,
    hidden_states: torch.Tensor,
    past_key_values_length: int,
    attention_mask: Optional[torch.FloatTensor],
):
    batch_size, seq_length = hidden_states.shape[:2]
    mask_seq_length = past_key_values_length + seq_length
    if shard_config.enable_flash_attention:
        attention_mask = ColoAttention.prepare_attn_kwargs(
            (batch_size, 1, seq_length, mask_seq_length),
            hidden_states.dtype,
            hidden_states.device,
            attention_mask,
            is_causal=True,
        )
    else:
        attention_mask = self.decoder._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            hidden_states,
            past_key_values_length,
        )
    return attention_mask


class OPTPipelineForwards:
@@ -26,46 +57,6 @@ class OPTPipelineForwards:
    under pipeline setting.
    """

    @staticmethod
    def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        from transformers.models.opt.modeling_opt import _make_causal_mask

        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                _dtype,
                device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to(
                device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @staticmethod
    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

    @staticmethod
    def opt_model_forward(
        self: OPTModel,
@@ -81,6 +72,7 @@ class OPTPipelineForwards:
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: Optional[ShardConfig] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
@@ -119,7 +111,7 @@ class OPTPipelineForwards:
            if decoder.project_in is not None:
                inputs_embeds = decoder.project_in(inputs_embeds)
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            _dtype = inputs_embeds.dtype
            inputs_embeds.dtype

        else:
            if hidden_states is None:
@@ -127,7 +119,7 @@ class OPTPipelineForwards:
            input_shape = hidden_states.size()[:-1]
            batch_size, seq_length = input_shape[0], input_shape[1]
            device = hidden_states.device
            _dtype = hidden_states.dtype
            hidden_states.dtype

        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        # required mask seq length can be calculated via length of past
@@ -141,13 +133,24 @@ class OPTPipelineForwards:
                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
            )

        causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(
            attention_mask, input_shape, _dtype, device, past_key_values_length
        )

        if stage_manager.is_first_stage():
            causal_attention_mask = _get_attention_mask(
                self,
                shard_config,
                inputs_embeds,
                past_key_values_length,
                attention_mask,
            )
            pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
            hidden_states = inputs_embeds + pos_embeds
        else:
            causal_attention_mask = _get_attention_mask(
                self,
                shard_config,
                hidden_states,
                past_key_values_length,
                attention_mask,
            )

        if decoder.gradient_checkpointing and decoder.training:
            if use_cache:
@@ -249,7 +252,16 @@ class OPTPipelineForwards:

        if stage_manager.is_last_stage():
            if not return_dict:
                return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
                return tuple(
                    v
                    for v in [
                        hidden_states,
                        next_cache,
                        all_hidden_states,
                        all_self_attns,
                    ]
                    if v is not None
                )

            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
@@ -276,6 +288,7 @@ class OPTPipelineForwards:
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: Optional[ShardConfig] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward.
@@ -303,6 +316,7 @@ class OPTPipelineForwards:
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
        )
        if stage_manager.is_last_stage():
            logits = self.lm_head(outputs[0]).contiguous()
@@ -347,6 +361,7 @@ class OPTPipelineForwards:
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: Optional[ShardConfig] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward.
@@ -371,6 +386,7 @@ class OPTPipelineForwards:
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
        )

        if stage_manager.is_last_stage():
@@ -448,6 +464,7 @@ class OPTPipelineForwards:
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: Optional[ShardConfig] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward.
@@ -469,6 +486,7 @@ class OPTPipelineForwards:
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
        )
        if stage_manager.is_last_stage():
            hidden_states = transformer_outputs[0]
@@ -511,49 +529,47 @@ class OPTPipelineForwards:
        return {"hidden_states": hidden_states}


def get_opt_flash_attention_forward():
def get_opt_flash_attention_forward(shard_config: ShardConfig):
    from transformers.models.opt.modeling_opt import OPTAttention

    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

    def forward(
        self: OPTAttention,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[dict] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
        # get query proj
        query_states = self.q_proj(hidden_states).view(*attention_input_shape)
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k, v, cross_attentions
            key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape)
            value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape)
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states).view(*attention_input_shape)
            value_states = self.v_proj(key_value_states).view(*attention_input_shape)
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self.k_proj(hidden_states).view(*attention_input_shape)
            value_states = self.v_proj(hidden_states).view(*attention_input_shape)
            key_states = torch.cat([past_key_value[0], key_states], dim=1)
            value_states = torch.cat([past_key_value[1], value_states], dim=1)
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states).view(*attention_input_shape)
            value_states = self.v_proj(hidden_states).view(*attention_input_shape)
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -565,38 +581,181 @@ def get_opt_flash_attention_forward():
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        src_len = key_states.size(1)
        if layer_head_mask != None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
        query_states = self._shape(query_states, tgt_len, bsz)

        flash_attention_mask = None
        attn_mask_type = AttnMaskType.causal
        if attention_mask != None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
            if not torch.all(flash_attention_mask):
                attn_mask_type = AttnMaskType.paddedcausal
        dropout_p = self.dropout if self.training else 0.0
        attn_output = ColoAttention.attention(
            query_states,
            key_states,
            value_states,
            **attention_mask,
            dropout_p=dropout_p,
            scale=self.scaling,
        )

        attention = ColoAttention(
            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
        )
        attn_output = attention(
            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
        )
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned aross GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, None, past_key_value

    return forward


def get_opt_decoder_forward_for_flash_attention(shard_config: ShardConfig):
    from transformers.models.opt.modeling_opt import OPTDecoder

    def forward(
        self: OPTDecoder,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values_length + seq_length

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        elif attention_mask.shape[1] != mask_seq_length:
            raise ValueError(
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
            )
        causal_attention_mask = _get_attention_mask(
            self, shard_config, inputs_embeds, past_key_values_length, attention_mask
        )
        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)

        hidden_states = inputs_embeds + pos_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != (len(self.layers)):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {head_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    causal_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    return forward


def get_jit_fused_opt_decoder_layer_forward():
    from transformers.models.opt.modeling_opt import OPTDecoderLayer