[shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests
This commit is contained in:
Hongxin Liu
2024-03-27 11:19:32 +08:00
committed by GitHub
parent 9a3321e9f4
commit 19e1a5cf16
45 changed files with 2543 additions and 1170 deletions

View File

@@ -1,4 +1,3 @@
import math
from typing import List, Optional, Tuple, Union
import torch
@@ -6,6 +5,7 @@ from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
def _encoder_forward(
@@ -98,7 +98,9 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
pixel_values = pixel_values.to(expected_dtype)
embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
pixel_values,
bool_masked_pos=bool_masked_pos,
interpolate_pos_encoding=interpolate_pos_encoding,
)
hidden_states = embedding_output
else:
@@ -336,34 +338,27 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
def get_vit_flash_self_attention_forward():
from transformers.models.vit.modeling_vit import ViTSelfAttention
from colossalai.nn.layer.colo_attention import ColoAttention
def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
x = x.view(new_x_shape)
return x
def forward(
self: ViTSelfAttention,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
assert head_mask is None, "head_mask is not supported for FlashAttention"
mixed_query_layer = self.query(hidden_states)
key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size)
value_layer = transpose_for_scores(
self.value(hidden_states), self.num_attention_heads, self.attention_head_size
)
query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
scale = 1.0 / math.sqrt(self.attention_head_size)
attention = ColoAttention(
embed_dim=self.all_head_size, num_heads=self.num_attention_heads, dropout=self.dropout.p, scale=scale
)
context_layer = attention(query_layer, key_layer, value_layer)
dropout_p = self.dropout.p if self.training else 0.0
context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p)
outputs = (context_layer,)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, None) if output_attentions else (context_layer,)
return outputs