[shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests
Author: Hongxin Liu
Date: 2024-03-27 11:19:32 +08:00
Committed by: GitHub
Parent: 9a3321e9f4
Commit: 19e1a5cf16
45 changed files with 2543 additions and 1170 deletions
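
Taken together, these commits swap the old per-layer ColoAttention module (colossalai.nn.layer.colo_attention, with its AttnMaskType enum) for the static colossalai.shardformer.layer.ColoAttention API used throughout the diff below: prepare_attn_kwargs builds a dict of mask-related keyword arguments once per forward pass, and attention consumes it. A minimal sketch of that call pattern, with batch size, head count, sequence length, dtype, and device chosen purely for illustration (the contents of the returned dict are an internal detail of ColoAttention):

import torch

from colossalai.shardformer.layer import ColoAttention

bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.rand(bsz, num_heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.rand_like(q), torch.rand_like(q)

# HF-style padding mask: 1 = keep, 0 = padded, shape (batch, kv_len).
padding_mask = torch.ones(bsz, seq_len, dtype=torch.long, device="cuda")

# Build the mask kwargs once; is_causal=True layers the causal mask on top of the
# padding mask, matching the decoder path in this diff.
attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (bsz, 1, seq_len, seq_len), q.dtype, q.device, padding_mask, is_causal=True
)

# The dict is unpacked straight into the attention call; passing no mask kwargs at all
# (the encoder case) runs unmasked attention.
out = ColoAttention.attention(q, k, v, **attn_kwargs, dropout_p=0.0, scale=head_dim**-0.5)
# The output keeps the (batch, num_heads, seq_len, head_dim) layout, which is why the
# Whisper forward below transposes and reshapes it back to (batch, seq_len, embed_dim).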


@@ -13,41 +13,74 @@ from transformers.modeling_outputs import (
SequenceClassifierOutput,
)
from transformers.models.whisper.modeling_whisper import (
WhisperDecoder,
WhisperEncoder,
WhisperForAudioClassification,
WhisperForConditionalGeneration,
WhisperModel,
shift_tokens_right,
)
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig
logger = logging.get_logger(__name__)
def _get_attention_mask(
self: WhisperDecoder,
shard_config: ShardConfig,
hidden_states: torch.Tensor,
past_key_values_length: int,
attention_mask: Optional[torch.FloatTensor],
):
batch_size, seq_length = hidden_states.shape[:2]
mask_seq_length = past_key_values_length + seq_length
if shard_config.enable_flash_attention:
attention_mask = ColoAttention.prepare_attn_kwargs(
(batch_size, 1, seq_length, mask_seq_length),
hidden_states.dtype,
hidden_states.device,
attention_mask,
is_causal=True,
)
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)
return attention_mask
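# NOTE: with shard_config.enable_flash_attention set, the value returned above is no
# longer an additive float mask but the kwargs dict built by
# ColoAttention.prepare_attn_kwargs; the fallback branch still returns the usual
# (bsz, 1, tgt_len, src_len) float mask. A rough sketch of how the dict is consumed
# (names are illustrative only):
#   attn_kwargs = _get_attention_mask(decoder, shard_config, inputs_embeds, 0, padding_mask)
#   out = ColoAttention.attention(q, k, v, **attn_kwargs, dropout_p=0.0, scale=scaling)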
def get_whisper_flash_attention_forward():
from transformers.models.whisper.modeling_whisper import WhisperAttention
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
def forward(
self: WhisperAttention,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[dict] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
# for encoder, attention_mask is None
if attention_mask is None:
attention_mask = {}
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -55,25 +88,25 @@ def get_whisper_flash_attention_forward():
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[1] == key_value_states.shape[1]
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
key_states = torch.cat([past_key_value[0], key_states], dim=1)
value_states = torch.cat([past_key_value[1], value_states], dim=1)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -85,38 +118,22 @@ def get_whisper_flash_attention_forward():
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
# get query proj
query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz)
src_len = key_states.size(1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_type = None
flash_attention_mask = None
if self.is_decoder:
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
if not torch.all(flash_attention_mask):
attn_type = AttnMaskType.paddedcausal
else:
attn_type = AttnMaskType.causal
attention = ColoAttention(
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
)
attn_output = attention(
query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_type
dropout_p = self.dropout if self.training else 0.0
attn_output = ColoAttention.attention(
query_states,
key_states,
value_states,
**attention_mask,
dropout_p=dropout_p,
scale=self.scaling,
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
@@ -125,6 +142,158 @@ def get_whisper_flash_attention_forward():
return forward
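# Notes on the rewritten attention forward above: query, key, and value states keep HF's
# (bsz, num_heads, seq_len, head_dim) layout produced by self._shape, so cached key/value
# states are concatenated along dim=2 rather than dim=1; the encoder path, which has no
# mask, passes an empty dict, which effectively amounts to (sketch):
#   ColoAttention.attention(query_states, key_states, value_states, dropout_p=dropout_p, scale=self.scaling)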
def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
def forward(
self: WhisperDecoder,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
attention_mask = _get_attention_mask(self, shard_config, inputs_embeds, past_key_values_length, attention_mask)
# embed positions
if input_ids is not None:
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (len(self.layers)), (
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop:
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
encoder_hidden_states,
None, # encoder attention mask
head_mask[idx] if head_mask is not None else None,
(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
None, # past_key_value
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
return forward
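# The kwargs dict built by _get_attention_mask is threaded through every decoder layer
# unchanged: as attention_mask= in the plain branch, and positionally (second checkpoint
# argument) in the gradient-checkpointing branch, so WhisperAttention.forward receives it
# in its attention_mask slot and unpacks it via **attention_mask.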
def get_jit_fused_whisper_encoder_layer_forward():
from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
@@ -292,6 +461,7 @@ class WhisperPipelineForwards:
all_attentions=None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
shard_config: Optional[ShardConfig] = None,
):
r"""
Args:
@@ -403,7 +573,9 @@ class WhisperPipelineForwards:
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
else:
@@ -411,7 +583,7 @@ class WhisperPipelineForwards:
@staticmethod
def whisper_decoder_forward(
self,
self: WhisperDecoder,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
@@ -427,6 +599,7 @@ class WhisperPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
shard_config: Optional[ShardConfig] = None,
):
r"""
Args:
@@ -535,8 +708,12 @@ class WhisperPipelineForwards:
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
attention_mask = _get_attention_mask(
self,
shard_config,
inputs_embeds,
past_key_values_length,
attention_mask,
)
hidden_states = inputs_embeds + positions
@@ -556,8 +733,12 @@ class WhisperPipelineForwards:
)
input_shape = hidden_states.size()[:-1]
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, hidden_states, past_key_values_length
attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
past_key_values_length,
attention_mask,
)
start_idx, end_idx = stage_index[0], stage_index[1]
@@ -590,7 +771,7 @@ class WhisperPipelineForwards:
encoder_hidden_states,
None, # encoder attention mask
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
None, # past_key_value
)
else:
@@ -626,7 +807,13 @@ class WhisperPipelineForwards:
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
@@ -666,6 +853,7 @@ class WhisperPipelineForwards:
encoder_hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
shard_config: Optional[ShardConfig] = None,
):
r"""
Returns:
@@ -735,7 +923,7 @@ class WhisperPipelineForwards:
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
hidden_states=(encoder_outputs[1] if len(encoder_outputs) > 1 else None),
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
@@ -767,6 +955,7 @@ class WhisperPipelineForwards:
hidden_states=hidden_states,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage,
shard_config=shard_config,
)
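# shard_config is now threaded through every pipeline-stage forward down to the decoder,
# so _get_attention_mask can choose between the ColoAttention kwargs dict and the standard
# additive mask at each stage; the hunks below add the same parameter to the remaining
# Whisper pipeline forwards.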
# Directly return outputs of overloaded Whisper forward if not at last stage.
@@ -810,6 +999,7 @@ class WhisperPipelineForwards:
encoder_hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -870,6 +1060,7 @@ class WhisperPipelineForwards:
encoder_hidden_states=encoder_hidden_states,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage,
shard_config=shard_config,
)
if not in_decoder:
return outputs
@@ -920,6 +1111,7 @@ class WhisperPipelineForwards:
all_attentions=None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
shard_config: Optional[ShardConfig] = None,
):
r"""
This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward.