from typing import Optional, Tuple

import torch
import torch.nn as nn

from colossalai.shardformer.layer import ColoAttention

def forward_fn():
    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        mixed_qkv = self.qkv(hidden_states)

        # modified from original code, which is:
        # mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
        #     2, 0, 3, 1, 4
        # )
        # to:
        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        query_states, key_states, value_states = (
            mixed_qkv[0],
            mixed_qkv[1],
            mixed_qkv[2],
        )

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))

        attention_scores = attention_scores * self.scale

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)

        new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        output = self.projection(context_layer)

        outputs = (output, attention_probs) if output_attentions else (output, None)

        return outputs

    return forward

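# --- Usage sketch (assumption, not part of the original file) -----------------------------
# forward_fn() above returns an unbound eager-attention forward. In ColossalAI it is normally
# installed on transformers' Blip2Attention through a ShardFormer policy; a minimal manual
# patch, with `attn_module` as a hypothetical Blip2Attention instance, could look like this.
def _example_bind_eager_forward(attn_module):
    from types import MethodType

    attn_module.forward = MethodType(forward_fn(), attn_module)
    return attn_module
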
def get_blip2_flash_attention_forward():
    from transformers.models.blip_2.modeling_blip_2 import Blip2Attention

    def forward(
        self: Blip2Attention,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        assert head_mask is None, "head_mask is not supported in FlashAttention"
        bsz, tgt_len, embed_dim = hidden_states.size()
        mixed_qkv = self.qkv(hidden_states)
        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        query_states, key_states, value_states = (
            mixed_qkv[0],
            mixed_qkv[1],
            mixed_qkv[2],
        )

        dropout_p = self.dropout.p if self.training else 0.0
        context_layer = ColoAttention.attention(
            query_states,
            key_states,
            value_states,
            dropout_p=dropout_p,
            scale=self.scale,
        )
        context_layer = context_layer.permute(0, 2, 1, 3).reshape(bsz, tgt_len, self.embed_dim)

        output = self.projection(context_layer)
        outputs = (output, None)

        return outputs

    return forward

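# Shape note and CPU stand-in (inferred from the permutes above, not from ColoAttention's docs):
# query/key/value are passed to ColoAttention.attention as (batch, num_heads, seq_len, head_dim),
# and the returned context keeps that layout, hence the permute/reshape back to
# (batch, seq_len, embed_dim). A plain-PyTorch reference for checking shapes and numerics:
def _sdpa_reference(query_states, key_states, value_states, dropout_p=0.0, scale=None):
    # scaled_dot_product_attention's `scale=` keyword needs PyTorch >= 2.1; the real call above
    # dispatches to flash-attn / SDPA / NPU kernels inside ColoAttention.
    return nn.functional.scaled_dot_product_attention(
        query_states, key_states, value_states, dropout_p=dropout_p, scale=scale
    )
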
def get_jit_fused_blip2_QFormer_self_output_forward():
    from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput

    def forward(
        self: Blip2QFormerSelfOutput,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

    return forward

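# Reference for the `dropout_add` these fused forwards read from the module (an assumption:
# stock transformers Blip2QFormerSelfOutput has no such attribute, so the ShardFormer policy
# is expected to attach one, typically a JIT-fused kernel). A non-fused equivalent:
def _dropout_add_reference(x, residual, prob, training):
    # dropout on the projected states, then the residual add that LayerNorm consumes.
    return nn.functional.dropout(x, p=prob, training=training) + residual
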
def get_jit_fused_blip2_QFormer_output_forward():
    from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput

    def forward(
        self: Blip2QFormerOutput,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

    return forward

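# --- End-to-end wiring sketch (assumption, not part of the original file) ------------------
# ColossalAI normally installs these forwards via ShardFormer policy method replacements; a
# minimal manual patch of a hypothetical Blip2QFormerOutput instance would attach `dropout_add`
# first, since the fused forward reads it from the module, then rebind `forward`.
def _example_patch_qformer_output(qformer_output_module, dropout_add_func):
    from types import MethodType

    qformer_output_module.dropout_add = dropout_add_func
    qformer_output_module.forward = MethodType(
        get_jit_fused_blip2_QFormer_output_forward(), qformer_output_module
    )
    return qformer_output_module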