mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462) * [extension] update api * [feature] add colo attention * [feature] update sdpa * [feature] update npu attention * [feature] update flash-attn * [test] add flash attn test * [test] update flash attn test * [shardformer] update modeling to fit colo attention (#5465) * [misc] refactor folder structure * [shardformer] update llama flash-attn * [shardformer] fix llama policy * [devops] update tensornvme install * [test] update llama test * [shardformer] update colo attn kernel dispatch * [shardformer] update blip2 * [shardformer] update chatglm * [shardformer] update gpt2 * [shardformer] update gptj * [shardformer] update opt * [shardformer] update vit * [shardformer] update colo attention mask prep * [shardformer] update whisper * [test] fix shardformer tests (#5514) * [test] fix shardformer tests * [test] fix shardformer tests
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -6,6 +5,7 @@ from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import ColoAttention
|
||||
|
||||
|
||||
def _encoder_forward(
|
||||
@@ -98,7 +98,9 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
|
||||
pixel_values = pixel_values.to(expected_dtype)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
pixel_values,
|
||||
bool_masked_pos=bool_masked_pos,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
hidden_states = embedding_output
|
||||
else:
|
||||
@@ -336,34 +338,27 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
|
||||
def get_vit_flash_self_attention_forward():
|
||||
from transformers.models.vit.modeling_vit import ViTSelfAttention
|
||||
|
||||
from colossalai.nn.layer.colo_attention import ColoAttention
|
||||
|
||||
def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self: ViTSelfAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
assert head_mask is None, "head_mask is not supported for FlashAttention"
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size)
|
||||
value_layer = transpose_for_scores(
|
||||
self.value(hidden_states), self.num_attention_heads, self.attention_head_size
|
||||
)
|
||||
query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size)
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
scale = 1.0 / math.sqrt(self.attention_head_size)
|
||||
attention = ColoAttention(
|
||||
embed_dim=self.all_head_size, num_heads=self.num_attention_heads, dropout=self.dropout.p, scale=scale
|
||||
)
|
||||
context_layer = attention(query_layer, key_layer, value_layer)
|
||||
dropout_p = self.dropout.p if self.training else 0.0
|
||||
context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p)
|
||||
|
||||
outputs = (context_layer,)
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, None) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
Reference in New Issue
Block a user