Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
[shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462)
* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test
* [shardformer] update modeling to fit colo attention (#5465)
* [misc] refactor folder structure
* [shardformer] update llama flash-attn
* [shardformer] fix llama policy
* [devops] update tensornvme install
* [test] update llama test
* [shardformer] update colo attn kernel dispatch
* [shardformer] update blip2
* [shardformer] update chatglm
* [shardformer] update gpt2
* [shardformer] update gptj
* [shardformer] update opt
* [shardformer] update vit
* [shardformer] update colo attention mask prep
* [shardformer] update whisper
* [test] fix shardformer tests (#5514)
* [test] fix shardformer tests
* [test] fix shardformer tests
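The hunks below rework the BLIP-2 attention forward in the shardformer modeling code so that it uses the refactored ColoAttention API: the old per-layer ColoAttention module instance is replaced by a direct ColoAttention.attention(...) call. A minimal sketch of the new call pattern, built only from what this diff shows (the (batch, num_heads, seq_len, head_dim) layout is inferred from the permute in the third hunk; the custom-mask arguments implied by the PR title do not appear in these hunks and are therefore omitted):

import torch

from colossalai.shardformer.layer import ColoAttention

# Illustrative, assumed shapes; they mirror how Blip2Attention lays out Q/K/V in the hunks below.
bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# Refactored entry point: a class-level attention call instead of an nn.Module instance.
out = ColoAttention.attention(q, k, v, dropout_p=0.0, scale=head_dim**-0.5)

# Convert back to (batch, seq_len, hidden_size), as the updated forward does before self.projection.
out = out.permute(0, 2, 1, 3).reshape(bsz, seq_len, num_heads * head_dim)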
@@ -3,6 +3,8 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 
+from colossalai.shardformer.layer import ColoAttention
+
 
 def forward_fn():
     def forward(
@@ -62,8 +64,6 @@ def forward_fn():
 def get_blip2_flash_attention_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
 
-    from colossalai.nn.layer.colo_attention import ColoAttention
-
     def forward(
         self: Blip2Attention,
         hidden_states: torch.Tensor,
@@ -71,16 +71,25 @@ def get_blip2_flash_attention_forward():
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
 
         assert head_mask is None, "head_mask is not supported in FlashAttention"
         bsz, tgt_len, embed_dim = hidden_states.size()
         mixed_qkv = self.qkv(hidden_states)
-        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
-        query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
-
-        attention = ColoAttention(
-            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout.p, scale=self.scale
+        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        query_states, key_states, value_states = (
+            mixed_qkv[0],
+            mixed_qkv[1],
+            mixed_qkv[2],
         )
-        context_layer = attention(query_states, key_states, value_states)
 
+        dropout_p = self.dropout.p if self.training else 0.0
+        context_layer = ColoAttention.attention(
+            query_states,
+            key_states,
+            value_states,
+            dropout_p=dropout_p,
+            scale=self.scale,
+        )
+        context_layer = context_layer.permute(0, 2, 1, 3).reshape(bsz, tgt_len, self.embed_dim)
 
         output = self.projection(context_layer)
         outputs = (output, None)
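Note the Q/K/V layout change in the hunk above: the old permute(2, 0, 1, 3, 4) yields per-tensor shapes of (batch, seq_len, num_heads, head_dim), while the new permute(2, 0, 3, 1, 4) yields (batch, num_heads, seq_len, head_dim), which is the layout the refactored ColoAttention.attention call appears to expect; the added permute/reshape after the attention call converts the output back to (batch, seq_len, hidden_size). A small sketch verifying the two layouts (all sizes are illustrative assumptions):

import torch

bsz, tgt_len, num_heads, head_dim = 2, 16, 8, 64
mixed_qkv = torch.randn(bsz, tgt_len, 3, num_heads, head_dim)

old_layout = mixed_qkv.permute(2, 0, 1, 3, 4)  # (3, batch, seq_len, num_heads, head_dim)
new_layout = mixed_qkv.permute(2, 0, 3, 1, 4)  # (3, batch, num_heads, seq_len, head_dim)
assert old_layout.shape == (3, bsz, tgt_len, num_heads, head_dim)
assert new_layout.shape == (3, bsz, num_heads, tgt_len, head_dim)

# Post-attention conversion used by the new forward (stand-in tensor for the attention output):
context = new_layout[0]  # (batch, num_heads, seq_len, head_dim)
context = context.permute(0, 2, 1, 3).reshape(bsz, tgt_len, num_heads * head_dim)
assert context.shape == (bsz, tgt_len, num_heads * head_dim)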
@@ -93,7 +102,11 @@ def get_blip2_flash_attention_forward():
 def get_jit_fused_blip2_QFormer_self_output_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput
 
-    def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self: Blip2QFormerSelfOutput,
+        hidden_states: torch.Tensor,
+        input_tensor: torch.Tensor,
+    ) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
         hidden_states = self.LayerNorm(hidden_states)
@@ -105,7 +118,11 @@ def get_jit_fused_blip2_QFormer_self_output_forward():
 def get_jit_fused_blip2_QFormer_output_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput
 
-    def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self: Blip2QFormerOutput,
+        hidden_states: torch.Tensor,
+        input_tensor: torch.Tensor,
+    ) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
         hidden_states = self.LayerNorm(hidden_states)
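The two QFormer hunks above only reflow the forward signatures over multiple lines; the fused dropout_add they call is unchanged. For orientation, a fused dropout-add combines dropout on the branch output with the residual addition in one step; a functional reference sketch follows (an assumed reference for the semantics, not ColossalAI's actual fused kernel):

import torch
import torch.nn.functional as F


def dropout_add_reference(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # Assumed reference semantics for a fused dropout + residual add.
    return residual + F.dropout(x, p=prob, training=training)


hidden_states = torch.randn(2, 16, 768)
input_tensor = torch.randn(2, 16, 768)
out = dropout_add_reference(hidden_states, input_tensor, prob=0.1, training=False)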