Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-28 04:55:25 +00:00)
[npu] use extension for op builder (#5172)
* update extension
* update cpu adam
* update is
* add doc for cpu adam
* update kernel
* update commit
* update flash
* update memory efficient
* update flash attn
* update flash attention loader
* update api
* fix
* update doc
* update example time limit
* reverse change
* fix doc
* remove useless kernel
* fix
* not use warning
* update
* update
@@ -1,5 +1,4 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .mha.mha import ColoAttention
from .multihead_attention import MultiHeadAttention
from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax

@@ -8,6 +7,5 @@ __all__ = [
    "MultiHeadAttention",
    "FusedScaleMaskSoftmax",
    "ScaledUpperTriangMaskedSoftmax",
    "ColoAttention",
    "AttnMaskType",
]
@@ -1,3 +0,0 @@
from .mha import ColoAttention

__all__ = ["ColoAttention"]
@@ -1,79 +0,0 @@
import warnings
from typing import Optional

import torch


def is_ampere_or_better_gpu():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        properties = torch.cuda.get_device_properties(device)
        if properties.major >= 8:  # Ampere GPUs or newer
            return True
    return False


# "Check Ampere GPUs or newer"
HAS_FLASH_ATTN = False
if is_ampere_or_better_gpu():
    HAS_FLASH_ATTN = True
else:
    warnings.warn("FlashAttention only supports Ampere GPUs or newer.")
    HAS_FLASH_ATTN = False
try:
    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func

    HAS_FLASH_ATTN = True
except ImportError:
    warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention")
    HAS_FLASH_ATTN = False

if HAS_FLASH_ATTN:
    from .utils import SeqLenInfo

    def flash_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seq_len_info_q: SeqLenInfo,
        seq_len_info_kv: SeqLenInfo,
        bias: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: float = None,
        causal: bool = False,
        padded: bool = False,
    ):
        """
        Arguments:
            q: (batch, q_seqlen, nheads, headdim)
            k: (batch, kv_seqlen, nheads, headdim)
            v: (batch, kv_seqlen, nheads, headdim)
            batch_size: int.
            seq_len: int.
            dropout_p: float. Dropout probability.
            sm_scale: float. The scaling of QK^T before applying softmax.
                Default to 1 / sqrt(headdim).
            causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        Return:
            attn_out: (batch, q_seqlen, nheads, headdim).
        """
        if padded:
            if seq_len_info_kv == None:
                seq_len_info_kv = seq_len_info_q

            attn_out = flash_attn_varlen_func(
                q,
                k,
                v,
                seq_len_info_q.cu_seqlens,
                seq_len_info_kv.cu_seqlens,
                seq_len_info_q.max_seqlen,
                seq_len_info_kv.max_seqlen,
                dropout_p,
                scale,
                causal,
            )
        else:
            attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
        return attn_out
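For context, a minimal standalone sketch of what the dense (non-padded) branch of the removed wrapper boils down to, assuming flash-attn is installed and an Ampere-or-newer CUDA GPU is available; tensor shapes follow the docstring above, (batch, seqlen, nheads, headdim) in half precision:

# Hypothetical sketch; requires flash-attn and a GPU with compute capability >= 8.0.
import math
import torch
from flash_attn.flash_attn_interface import flash_attn_func

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Same call the removed wrapper makes on its dense path.
out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=1 / math.sqrt(headdim), causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64])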
@@ -1,70 +0,0 @@
import warnings

HAS_MEM_EFF_ATTN = False
try:
    from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
    from xformers.ops.fmha.attn_bias import (
        BlockDiagonalCausalMask,
        BlockDiagonalMask,
        LowerTriangularMask,
        LowerTriangularMaskWithTensorBias,
    )

    HAS_MEM_EFF_ATTN = True
except ImportError:
    warnings.warn("please install xformers from https://github.com/facebookresearch/xformers")
    HAS_MEM_EFF_ATTN = False

if HAS_MEM_EFF_ATTN:
    """
    A general attention module using the flash attention kernels from xformers:
    https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
    """
    from typing import Optional

    import torch

    from .utils import SeqLenInfo

    allow_alibi = True
    for op in MemoryEfficientAttentionCutlassOp:
        allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)

    def mem_eff_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seq_len_info_q: SeqLenInfo,
        seq_len_info_kv: SeqLenInfo,
        bias: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: float = None,
        causal: bool = False,
        padded: bool = False,
    ):
        attn_bias = None
        if padded:  # bert style
            if not causal:
                attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
            else:
                attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
        elif causal:  # gpt style
            attn_bias = LowerTriangularMask()

        if bias is not None:  # alibi / relative position embedding
            assert allow_alibi, "flash attention with bias is not supported in this system."
            assert causal, "attention with bias is only supported for causal attention so far."
            attn_bias = attn_bias.add_bias(bias)

        if padded:
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)

        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)

        # shape: (b*s, n, d)
        if padded:
            out = out.squeeze(0)

        return out
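Similarly, a minimal sketch of the causal, non-padded ("gpt style") branch of the removed wrapper, assuming xformers is installed: it reduces to a single memory_efficient_attention call with a LowerTriangularMask bias, using the same imports the file above relied on:

# Hypothetical sketch; requires xformers and a CUDA GPU for the half-precision kernels.
import math
import torch
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import LowerTriangularMask

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Equivalent to the wrapper's causal, unpadded branch.
out = memory_efficient_attention(
    q, k, v, attn_bias=LowerTriangularMask(), p=0.0, scale=1 / math.sqrt(headdim)
)
print(out.shape)  # torch.Size([2, 128, 8, 64])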
@@ -1,114 +0,0 @@
import math
from typing import Optional

import torch
from einops import rearrange

from ..scaled_softmax import AttnMaskType
from .flash_attn_2 import HAS_FLASH_ATTN
from .mem_eff_attn import HAS_MEM_EFF_ATTN
from .utils import Repad, SeqLenInfo, Unpad

if HAS_FLASH_ATTN:
    from .flash_attn_2 import flash_attention
if HAS_MEM_EFF_ATTN:
    from .mem_eff_attn import mem_eff_attention


class ColoAttention(torch.nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
        if scale is not None:
            self.scale = scale
        else:
            self.scale = 1 / math.sqrt(embed_dim // num_heads)
        self.dropout = dropout

        if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN:
            raise Exception("flash attention can not support!")

    @staticmethod
    def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        return Unpad.apply(tensor, indices)

    @staticmethod
    def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
        return Repad.apply(tensor, indices, batch_size, seq_len)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        origin_attn_mask: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[AttnMaskType] = None,
        bias: Optional[torch.Tensor] = None,
    ):
        attn = None
        if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None:
            attn = flash_attention
        else:
            attn = mem_eff_attention

        padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
        causal = attn_mask_type is not None and attn_mask_type.value > 1

        batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
        # unpad
        seq_len_info_q = None
        seq_len_info_kv = None
        if padded:
            # bert style, unpad process
            assert (
                attn_mask is not None
            ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
            assert attn_mask.dim() == 2, (
                "attention mask is supposed to have shape (batch_size, seq_len), "
                + f"but got {attn_mask.dim()} dimensions."
            )

            # bert style
            if tgt_len == src_len:
                seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
                if batch_size > 1:
                    query, key, value = self.unpad(
                        torch.stack([query, key, value], dim=2), seq_len_info_q.indices
                    ).unbind(dim=1)
                else:
                    query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
                seq_len_info_kv = seq_len_info_q
            else:
                seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
                seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
                if batch_size > 1:
                    query = rearrange(query, "b s ... -> c (b s) ...", c=1)
                    key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind(
                        dim=1
                    )
                else:
                    query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)

        out = attn(
            query,
            key,
            value,
            seq_len_info_q,
            seq_len_info_kv,
            dropout_p=self.dropout,
            scale=self.scale,
            causal=causal,
            padded=padded,
        )

        # repad
        if padded:
            if batch_size > 1:
                out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
            out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)

        out = rearrange(out, "b s h d -> b s (h d)")
        return out
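For reference, a minimal pure-PyTorch sketch (an illustration only, not part of the removed module) of what ColoAttention's dense causal path computes: standard scaled-dot-product attention over (batch, seq, heads, head_dim) inputs, with the flash / memory-efficient kernels above serving as faster implementations of the same math and the final rearrange merging the head dimension:

# Hypothetical reference sketch in plain PyTorch; runs on CPU, no extra kernels required.
import torch
import torch.nn.functional as F
from einops import rearrange

batch, seqlen, nheads, headdim = 2, 16, 4, 32
q = torch.randn(batch, seqlen, nheads, headdim)
k = torch.randn_like(q)
v = torch.randn_like(q)

# ColoAttention works on (b, s, h, d); scaled_dot_product_attention expects (b, h, s, d).
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)

# Mirrors the module's final rearrange(out, "b s h d -> b s (h d)").
out = rearrange(out, "b s h d -> b s (h d)")
print(out.shape)  # torch.Size([2, 16, 128])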
@@ -1,82 +0,0 @@
from dataclasses import dataclass
from typing import Iterable, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange

from colossalai.utils.device import get_current_device


class Unpad(torch.autograd.Function):
    """
    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
        ctx.save_for_backward(indices)
        # [b, s, ...]
        assert tensor.ndim >= 3
        ctx.bsz = tensor.shape[0]
        out = rearrange(tensor, "b s ... -> (b s) ...")
        ctx.shape = out.shape
        # [ntokens, ...]
        return out[indices]

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # [ntokens, ...]
        grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
        grad[indices] = grad_output
        grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
        # [b, s, ...]
        return grad, None


class Repad(torch.autograd.Function):
    """
    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
        ctx.save_for_backward(indices)
        # [ntokens, ...]
        tensor = tensor
        out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
        # [b*s, ...]
        out[indices] = tensor
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # [b*s, ...]
        grad = grad_output[indices]
        # [ntokens, ...]
        return grad, None, None, None


@dataclass
class SeqLenInfo:
    seqlens: Iterable[int] = None
    indices: torch.Tensor = None
    max_seqlen: int = None
    cu_seqlens: torch.Tensor = None

    @staticmethod
    def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()):
        if attn_mask is not None:
            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
        else:
            batch_size, tgt_len = size[0], size[1]
            indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
            seqlens = torch.LongTensor([tgt_len] * batch_size, device=device)
        max_seqlen = max(seqlens)
        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
        return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
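To illustrate the bookkeeping these helpers perform, a small pure-PyTorch sketch (for illustration only, independent of the removed classes) of the unpad/repad round trip: flatten (batch, seq) into tokens, keep only the positions where the attention mask is 1, and later scatter them back into a zero-filled padded buffer; the derived indices and cu_seqlens match what SeqLenInfo.materialize computes from the mask:

# Hypothetical illustration of the unpad/repad idea with plain tensor ops.
import torch

batch, seqlen, hidden = 2, 4, 8
x = torch.randn(batch, seqlen, hidden)
attn_mask = torch.tensor([[1, 1, 1, 0],
                          [1, 1, 0, 0]])  # 1 = real token, 0 = padding

# Same quantities SeqLenInfo.materialize derives from the mask.
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

# Unpad: (b, s, d) -> (ntokens, d), dropping padded positions.
unpadded = x.reshape(batch * seqlen, hidden)[indices]

# Repad: scatter the tokens back into a zero-filled (b*s, d) buffer.
repadded = torch.zeros(batch * seqlen, hidden, dtype=x.dtype)
repadded[indices] = unpadded
repadded = repadded.reshape(batch, seqlen, hidden)

# Padded positions stay zero, real tokens round-trip unchanged.
assert torch.equal(repadded * attn_mask.unsqueeze(-1), repadded)
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)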