[shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462)
* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test
* [shardformer] update modeling to fit colo attention (#5465)
* [misc] refactor folder structure
* [shardformer] update llama flash-attn
* [shardformer] fix llama policy
* [devops] update tensornvme install
* [test] update llama test
* [shardformer] update colo attn kernel dispatch
* [shardformer] update blip2
* [shardformer] update chatglm
* [shardformer] update gpt2
* [shardformer] update gptj
* [shardformer] update opt
* [shardformer] update vit
* [shardformer] update colo attention mask prep
* [shardformer] update whisper
* [test] fix shardformer tests (#5514)
* [test] fix shardformer tests
* [test] fix shardformer tests
@@ -1,209 +0,0 @@
import enum
import math
import warnings
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange

from colossalai.accelerator import get_accelerator
from colossalai.kernel.kernel_loader import FlashAttentionLoader

@dataclass
class SeqLenInfo:
    seqlens: Iterable[int] = None
    indices: torch.Tensor = None
    max_seqlen: int = None
    cu_seqlens: torch.Tensor = None

    @staticmethod
    def materialize(
        attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
    ):
        if attn_mask is not None:
            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
        else:
            batch_size, tgt_len = size[0], size[1]
            indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
            # torch.LongTensor does not accept a device keyword; build the tensor on the target device directly.
            seqlens = torch.full((batch_size,), tgt_len, dtype=torch.long, device=device)
        max_seqlen = max(seqlens)
        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
        return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)

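As an illustration of the class above, a minimal sketch (not part of the original file) of what SeqLenInfo.materialize returns for a small padding mask; it assumes a CPU device and the (batch_size, seq_len) mask convention used by ColoAttention.forward below.

import torch

# Illustrative sketch, not part of the original file: two sequences of
# length 3 and 2, padded to a common length of 4.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
info = SeqLenInfo.materialize(attn_mask=mask, device=torch.device("cpu"))
print(info.seqlens)     # [3, 2]
print(info.indices)     # tensor([0, 1, 2, 4, 5]) -- flat positions of the real tokens
print(info.max_seqlen)  # tensor(3, dtype=torch.int32) -- longest sequence
print(info.cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32) -- cumulative lengths with a leading 0
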
class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    paddedcausal = 3

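The enum values are chosen so that ColoAttention.forward below can decode them with simple arithmetic (value % 2 == 1 means the input is padded, value > 1 means the mask is causal); a short sketch, not part of the original file.

# Illustrative sketch, not part of the original file.
for t in AttnMaskType:
    print(t.name, "padded:", t.value % 2 == 1, "causal:", t.value > 1)
# padding padded: True causal: False
# causal padded: False causal: True
# paddedcausal padded: True causal: True
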
class Unpad(torch.autograd.Function):
    """
    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
        ctx.save_for_backward(indices)
        # [b, s, ...]
        assert tensor.ndim >= 3
        ctx.bsz = tensor.shape[0]
        out = rearrange(tensor, "b s ... -> (b s) ...")
        ctx.shape = out.shape
        # [ntokens, ...]
        return out[indices]

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # [ntokens, ...]
        grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
        grad[indices] = grad_output
        grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
        # [b, s, ...]
        return grad, None

class Repad(torch.autograd.Function):
    """
    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
        ctx.save_for_backward(indices)
        # [ntokens, ...]
        out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
        # [b*s, ...]
        out[indices] = tensor
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # [b*s, ...]
        grad = grad_output[indices]
        # [ntokens, ...]
        return grad, None, None, None

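A minimal round-trip sketch (not part of the original file) of how Unpad and Repad compose: packing the real tokens of a padded batch and scattering them back, assuming the toy mask below.

import torch

# Illustrative sketch, not part of the original file: two sequences padded
# to length 4, with 5 real token positions in total.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()

x = torch.randn(2, 4, 8)                       # [b, s, hidden]
packed = Unpad.apply(x, indices)               # [ntokens, hidden] == [5, 8]
restored = Repad.apply(packed, indices, 2, 4)  # [b*s, hidden], zeros at padding
assert torch.equal(restored.view(2, 4, 8)[mask.bool()], x[mask.bool()])
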
class ColoAttention(torch.nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
        if scale is not None:
            self.scale = scale
        else:
            self.scale = 1 / math.sqrt(embed_dim // num_heads)
        self.dropout = dropout

        self.attn = FlashAttentionLoader().load()

    @staticmethod
    def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        return Unpad.apply(tensor, indices)

    @staticmethod
    def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
        return Repad.apply(tensor, indices, batch_size, seq_len)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        origin_attn_mask: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[AttnMaskType] = None,
        bias: Optional[torch.Tensor] = None,
    ):
        """
        ColoAttention

        Args:
            query: (batch, q_seqlen, nheads, headdim)
            key: (batch, kv_seqlen, nheads, headdim)
            value: (batch, kv_seqlen, nheads, headdim)
            origin_attn_mask: (nheads, q_seqlen, kv_seqlen)
            bias: will not be used
        Return:
            attn_out: (batch, q_seqlen, nheads, headdim).
        """
        # if flash attention is not applicable, switch to memory efficient attention
        if self.attn.__name__ == "flash_attention" and (
            query.dtype not in [torch.float16, torch.bfloat16] or bias is not None
        ):
            warnings.warn(
                f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation."
            )
            self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")

        padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
        causal = attn_mask_type is not None and attn_mask_type.value > 1

        batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
        # unpad
        seq_len_info_q = None
        seq_len_info_kv = None
        if padded:
            # bert style, unpad process
            assert (
                attn_mask is not None
            ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
            assert attn_mask.dim() == 2, (
                "attention mask is supposed to have shape (batch_size, seq_len), "
                + f"but got {attn_mask.dim()} dimensions."
            )

            # bert style
            if tgt_len == src_len:
                seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
                if batch_size > 1:
                    query, key, value = self.unpad(
                        torch.stack([query, key, value], dim=2), seq_len_info_q.indices
                    ).unbind(dim=1)
                else:
                    query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
                seq_len_info_kv = seq_len_info_q
            else:
                seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
                seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
                if batch_size > 1:
                    query = rearrange(query, "b s ... -> c (b s) ...", c=1)
                    # only key/value carry padding here; stacking query as well would break the two-way unbind
                    key, value = self.unpad(torch.stack([key, value], dim=2), seq_len_info_kv.indices).unbind(
                        dim=1
                    )
                else:
                    query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)

        out = self.attn(
            query,
            key,
            value,
            seq_len_info_q=seq_len_info_q,
            seq_len_info_kv=seq_len_info_kv,
            origin_attn_mask=origin_attn_mask,
            dropout_p=self.dropout,
            scale=self.scale,
            causal=causal,
            padded=padded,
        )

        # repad
        if padded:
            if batch_size > 1:
                out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
            out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)

        if len(out.shape) == 4:
            out = rearrange(out, "b s h d -> b s (h d)")
        return out
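For context, a hedged usage sketch of the module above (not part of the original file); it assumes a CUDA device with a flash-attention kernel resolvable by FlashAttentionLoader, and follows the shapes documented in forward.

import torch

# Illustrative sketch, not part of the original file.
attn = ColoAttention(embed_dim=512, num_heads=8).cuda()

batch, seq_len, heads, head_dim = 2, 16, 8, 64
q = torch.randn(batch, seq_len, heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# 1 marks real tokens, 0 marks padding; the second sequence is shorter.
attn_mask = torch.ones(batch, seq_len, dtype=torch.int64, device="cuda")
attn_mask[1, 12:] = 0

out = attn(q, k, v, attn_mask=attn_mask, attn_mask_type=AttnMaskType.paddedcausal)
print(out.shape)  # torch.Size([2, 16, 512]) -- heads and head_dim merged on the way out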