Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462)
* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test
* [shardformer] update modeling to fit colo attention (#5465)
* [misc] refactor folder structure
* [shardformer] update llama flash-attn
* [shardformer] fix llama policy
* [devops] update tensornvme install
* [test] update llama test
* [shardformer] update colo attn kernel dispatch
* [shardformer] update blip2
* [shardformer] update chatglm
* [shardformer] update gpt2
* [shardformer] update gptj
* [shardformer] update opt
* [shardformer] update vit
* [shardformer] update colo attention mask prep
* [shardformer] update whisper
* [test] fix shardformer tests (#5514)
* [test] fix shardformer tests
* [test] fix shardformer tests
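The change visible in the diff below replaces the old per-layer mask tensor with a kwargs dict: ColoAttention.prepare_attn_kwargs builds the mask (and its type) once from the usual HF-style padding mask, and the patched attention forwards simply splat that dict into ColoAttention.attention. The snippet below is a minimal, hypothetical sketch of that convention based only on the calls shown in the diff; the tensor shapes, the positional padding-mask argument, and the GPU/fp16 setup are assumptions, not repository code.

# Hypothetical sketch of the new mask-dict convention (not from the repo).
# Assumes a GPU build of ColossalAI with the refactored ColoAttention available.
import torch

from colossalai.shardformer.layer import ColoAttention

bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64
dtype, device = torch.float16, "cuda"

q = torch.randn(bsz, num_heads, seq_len, head_dim, dtype=dtype, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)

# HF-style padding mask: 1 = real token, 0 = padding.
padding_mask = torch.ones(bsz, seq_len, dtype=torch.int64, device=device)
padding_mask[:, -4:] = 0

# Build the attention kwargs once, as _get_attention_mask does in the diff below.
attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (bsz, 1, seq_len, seq_len),
    dtype,
    device,
    padding_mask,
    is_causal=True,
)

# The patched forwards splat the dict straight into the dispatching attention call.
out = ColoAttention.attention(q, k, v, **attn_kwargs, dropout_p=0.0, scale=head_dim**-0.5)
print(out.shape)  # (bsz, num_heads, seq_len, head_dim); the caller transposes/reshapes afterwards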
@@ -13,41 +13,74 @@ from transformers.modeling_outputs import (
    SequenceClassifierOutput,
)
from transformers.models.whisper.modeling_whisper import (
    WhisperDecoder,
    WhisperEncoder,
    WhisperForAudioClassification,
    WhisperForConditionalGeneration,
    WhisperModel,
    shift_tokens_right,
)
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig

logger = logging.get_logger(__name__)


def _get_attention_mask(
    self: WhisperDecoder,
    shard_config: ShardConfig,
    hidden_states: torch.Tensor,
    past_key_values_length: int,
    attention_mask: Optional[torch.FloatTensor],
):
    batch_size, seq_length = hidden_states.shape[:2]
    mask_seq_length = past_key_values_length + seq_length
    if shard_config.enable_flash_attention:
        attention_mask = ColoAttention.prepare_attn_kwargs(
            (batch_size, 1, seq_length, mask_seq_length),
            hidden_states.dtype,
            hidden_states.device,
            attention_mask,
            is_causal=True,
        )
    else:
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            hidden_states,
            past_key_values_length,
        )
    return attention_mask


def get_whisper_flash_attention_forward():
    from transformers.models.whisper.modeling_whisper import WhisperAttention

    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

    def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
        return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()

    def forward(
        self: WhisperAttention,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[dict] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
        # for encoder, attention_mask is None
        if attention_mask is None:
            attention_mask = {}
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -55,25 +88,25 @@ def get_whisper_flash_attention_forward():
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[1] == key_value_states.shape[1]
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
            value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
            value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
            key_states = torch.cat([past_key_value[0], key_states], dim=1)
            value_states = torch.cat([past_key_value[1], value_states], dim=1)
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
            value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -85,38 +118,22 @@ def get_whisper_flash_attention_forward():
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # get query proj
        query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz)

        src_len = key_states.size(1)
        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )

        attn_type = None
        flash_attention_mask = None

        if self.is_decoder:
            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                    )
                flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
                if not torch.all(flash_attention_mask):
                    attn_type = AttnMaskType.paddedcausal
                else:
                    attn_type = AttnMaskType.causal

        attention = ColoAttention(
            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
        )
        attn_output = attention(
            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_type
        dropout_p = self.dropout if self.training else 0.0
        attn_output = ColoAttention.attention(
            query_states,
            key_states,
            value_states,
            **attention_mask,
            dropout_p=dropout_p,
            scale=self.scaling,
        )
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

@@ -125,6 +142,158 @@ def get_whisper_flash_attention_forward():
    return forward
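get_whisper_flash_attention_forward returns a closure whose signature mirrors WhisperAttention.forward, except that attention_mask is now the kwargs dict described above. In the repository this substitution is wired up by the shardformer policy machinery; the one-liner below is only an illustrative sketch of what the substitution amounts to.

# Illustration only (not the policy mechanism): install the closure as a
# drop-in replacement for WhisperAttention.forward. After this, callers are
# expected to pass the dict produced by ColoAttention.prepare_attn_kwargs
# as `attention_mask`.
from transformers.models.whisper.modeling_whisper import WhisperAttention

WhisperAttention.forward = get_whisper_flash_attention_forward()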

def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
    def forward(
        self: WhisperDecoder,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        attention_mask = _get_attention_mask(self, shard_config, inputs_embeds, past_key_values_length, attention_mask)

        # embed positions
        if input_ids is not None:
            positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
        else:
            positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)

        hidden_states = inputs_embeds + positions
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
                )
                use_cache = False
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                assert attn_mask.size()[0] == (len(self.layers)), (
                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, use_cache)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    None,  # encoder attention mask
                    head_mask[idx] if head_mask is not None else None,
                    (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                    None,  # past_key_value
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )

    return forward

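A quick way to exercise the decoder closure above is to call it directly on a tiny WhisperDecoder. The snippet below is a hypothetical smoke test, not repository code: the ShardConfig flags are assumptions chosen so that the non-flash fallback branch of _get_attention_mask runs on CPU, and it presumes a transformers version that still provides WhisperDecoder._prepare_decoder_attention_mask.

# Hypothetical smoke test (not from the repo) for the decoder forward above.
import torch
from transformers import WhisperConfig
from transformers.models.whisper.modeling_whisper import WhisperDecoder

from colossalai.shardformer.shard import ShardConfig

config = WhisperConfig(d_model=64, decoder_layers=2, decoder_attention_heads=4, decoder_ffn_dim=128)
decoder = WhisperDecoder(config).eval()

# Assumed flags: no tensor parallelism (so no process group is required) and no
# flash attention (so the stock _prepare_decoder_attention_mask path is taken).
shard_config = ShardConfig(enable_tensor_parallelism=False, enable_flash_attention=False)
decoder_forward = get_whisper_decoder_forward_for_flash_attention(shard_config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out = decoder_forward(decoder, input_ids=input_ids, return_dict=True)
print(out.last_hidden_state.shape)  # expected: torch.Size([1, 8, 64])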
def get_jit_fused_whisper_encoder_layer_forward():
    from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer

@@ -292,6 +461,7 @@ class WhisperPipelineForwards:
        all_attentions=None,
        stage_index: Optional[List[int]] = None,
        decoder_starting_stage: Optional[int] = None,
        shard_config: Optional[ShardConfig] = None,
    ):
        r"""
        Args:
@@ -403,7 +573,9 @@ class WhisperPipelineForwards:
            if not return_dict:
                return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
            return BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
                last_hidden_state=hidden_states,
                hidden_states=encoder_states,
                attentions=all_attentions,
            )

        else:
@@ -411,7 +583,7 @@ class WhisperPipelineForwards:

    @staticmethod
    def whisper_decoder_forward(
        self,
        self: WhisperDecoder,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
@@ -427,6 +599,7 @@ class WhisperPipelineForwards:
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        decoder_starting_stage: Optional[int] = None,
        shard_config: Optional[ShardConfig] = None,
    ):
        r"""
        Args:
@@ -535,8 +708,12 @@ class WhisperPipelineForwards:
            else:
                positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)

            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            attention_mask = _get_attention_mask(
                self,
                shard_config,
                inputs_embeds,
                past_key_values_length,
                attention_mask,
            )

            hidden_states = inputs_embeds + positions
@@ -556,8 +733,12 @@ class WhisperPipelineForwards:
                )
            input_shape = hidden_states.size()[:-1]

            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask, input_shape, hidden_states, past_key_values_length
            attention_mask = _get_attention_mask(
                self,
                shard_config,
                hidden_states,
                past_key_values_length,
                attention_mask,
            )

        start_idx, end_idx = stage_index[0], stage_index[1]
@@ -590,7 +771,7 @@ class WhisperPipelineForwards:
                    encoder_hidden_states,
                    None,  # encoder attention mask
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                    None,  # past_key_value
                )
            else:
@@ -626,7 +807,13 @@ class WhisperPipelineForwards:
            if not return_dict:
                return tuple(
                    v
                    for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                    for v in [
                        hidden_states,
                        next_cache,
                        all_hidden_states,
                        all_self_attns,
                        all_cross_attentions,
                    ]
                    if v is not None
                )
            return BaseModelOutputWithPastAndCrossAttentions(
@@ -666,6 +853,7 @@ class WhisperPipelineForwards:
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        decoder_starting_stage: Optional[int] = None,
        shard_config: Optional[ShardConfig] = None,
    ):
        r"""
        Returns:
@@ -735,7 +923,7 @@ class WhisperPipelineForwards:
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                hidden_states=(encoder_outputs[1] if len(encoder_outputs) > 1 else None),
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

@@ -767,6 +955,7 @@ class WhisperPipelineForwards:
            hidden_states=hidden_states,
            stage_index=stage_index,
            decoder_starting_stage=decoder_starting_stage,
            shard_config=shard_config,
        )

        # Directly return outputs of overloaded Whisper forward if not at last stage.
@@ -810,6 +999,7 @@ class WhisperPipelineForwards:
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        decoder_starting_stage: Optional[int] = None,
        shard_config: Optional[ShardConfig] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -870,6 +1060,7 @@ class WhisperPipelineForwards:
            encoder_hidden_states=encoder_hidden_states,
            stage_index=stage_index,
            decoder_starting_stage=decoder_starting_stage,
            shard_config=shard_config,
        )
        if not in_decoder:
            return outputs
@@ -920,6 +1111,7 @@ class WhisperPipelineForwards:
        all_attentions=None,
        stage_index: Optional[List[int]] = None,
        decoder_starting_stage: Optional[int] = None,
        shard_config: Optional[ShardConfig] = None,
    ):
        r"""
        This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward.