diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py index 1d5a6ce49..4910717b5 100644 --- a/colossalai/kernel/cuda_native/__init__.py +++ b/colossalai/kernel/cuda_native/__init__.py @@ -1,5 +1,8 @@ from .layer_norm import MixedFusedLayerNorm as LayerNorm +from .mha.mha import ColoAttention from .multihead_attention import MultiHeadAttention from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax -__all__ = ['LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax'] +__all__ = [ + 'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention' +] diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py deleted file mode 100644 index 3db737450..000000000 --- a/colossalai/kernel/cuda_native/flash_attention.py +++ /dev/null @@ -1,635 +0,0 @@ -""" -A general attention module using the flash attention kernels from xformers: -https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha -""" - -import math -import os -import subprocess - -import torch - -try: - from xformers.ops.fmha import memory_efficient_attention - HAS_MEM_EFF_ATTN = True -except ImportError: - HAS_MEM_EFF_ATTN = False - print('please install xformers from https://github.com/facebookresearch/xformers') - -if HAS_MEM_EFF_ATTN: - - from typing import Optional - - from einops import rearrange - from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp - from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias - - from .scaled_softmax import AttnMaskType - - allow_alibi = True - for op in MemoryEfficientAttentionCutlassOp: - allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) - - class Unpad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): - ctx.save_for_backward(indices) - # [b, s, ...] - assert tensor.ndim >= 3 - ctx.bsz = tensor.shape[0] - out = rearrange(tensor, 'b s ... -> (b s) ...') - ctx.shape = out.shape - # [1, ntokens, ...] - return out[indices].unsqueeze(0) - - @staticmethod - def backward(ctx, grad_output): - indices, = ctx.saved_tensors - # [b*s, ...] - grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) - grad[indices] = grad_output.squeeze(0) - grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz) - # [b, s, ...] - return grad, None - - class Repad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): - ctx.save_for_backward(indices) - # [ntokens, ...] - tensor = tensor.squeeze(0) - out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) - # [b*s, ...] - out[indices] = tensor - # [b, s, ...] - out = rearrange(out, '(b s) ... -> b s ...', b=batch_size) - return out - - @staticmethod - def backward(ctx, grad_output): - indices, = ctx.saved_tensors - # [b*s, ...] - grad_output = rearrange(grad_output, 'b s ... -> (b s) ...') - grad = grad_output[indices] - # [1, ntokens, ...] - return grad.unsqueeze(0), None, None, None - - class ColoAttention(torch.nn.Module): - - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0): - super().__init__() - assert embed_dim % num_heads == 0, \ - f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." - self.scale = 1 / math.sqrt(embed_dim // num_heads) - self.dropout = dropout - - @staticmethod - def get_seq_info_from_mask(attn_mask: torch.Tensor): - indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten() - seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten().tolist() - return indices, seqlens - - @staticmethod - def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - return Unpad.apply(tensor, indices) - - @staticmethod - def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: - return Repad.apply(tensor, indices, batch_size, seq_len) - - def forward(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: Optional[AttnMaskType] = None, - bias: Optional[torch.Tensor] = None): - batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] - attn_bias = None - if attn_mask_type == AttnMaskType.padding: # bert style - assert attn_mask is not None, \ - f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." - assert attn_mask.dim() == 2, \ - "attention mask is supposed to have shape (batch_size, seq_len), " + \ - f"but got {attn_mask.dim()} dimensions." - if tgt_len == src_len: - q_indices, q_seqlen = self.get_seq_info_from_mask(attn_mask) - kv_seqlen = None - if batch_size > 1: - query, key, value = self.unpad(torch.stack([query, key, value], dim=2), q_indices).unbind(dim=2) - else: - q_indices = torch.arange(batch_size * tgt_len, dtype=torch.int32, device=query.device) - q_seqlen = torch.LongTensor([tgt_len] * batch_size, device=query.device) - kv_indices, kv_seqlen = self.get_seq_info_from_mask(attn_mask) - if batch_size > 1: - query = rearrange(query, "b s ... -> c (b s) ...", c=1) - key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2) - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) - elif attn_mask_type == AttnMaskType.causal: # gpt style - attn_bias = LowerTriangularMask() - - if bias is not None: # alibi / relative position embedding - assert allow_alibi, "flash attention with bias is not supported in this system." - assert attn_mask_type == AttnMaskType.causal, \ - "attention with bias is only supported for causal attention so far." - attn_bias = attn_bias.add_bias(bias) - - out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale) - - if attn_mask_type == AttnMaskType.padding and batch_size > 1: - out = self.repad(out, q_indices, batch_size, tgt_len) - - out = rearrange(out, 'b s h d -> b s (h d)') - return out - - -########################################################################## -# the flash attention functions below that are copied -# from the OpenAI/triton repository will be deprecated -# You can find the repository in Triton https://github.com/openai/triton -# You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py -# Reference: -# 1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf -# 2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf - - -def triton_cuda_check(): - cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda") - cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip() - cuda_version = cuda_version.split('release ')[1] - cuda_version = cuda_version.split(',')[0] - cuda_version = cuda_version.split('.') - if len(cuda_version) == 2 and \ - (int(cuda_version[0]) == 11 and int(cuda_version[1]) >= 4) or \ - int(cuda_version[0]) > 11: - return True - return False - - -try: - import triton - import triton.language as tl - if triton_cuda_check(): - HAS_TRITON = True - else: - print("triton requires cuda >= 11.4") - HAS_TRITON = False -except ImportError: - print('please install triton from https://github.com/openai/triton') - HAS_TRITON = False -try: - from flash_attn.flash_attention import FlashAttention - from flash_attn.flash_attn_interface import ( - flash_attn_unpadded_func, - flash_attn_unpadded_kvpacked_func, - flash_attn_unpadded_qkvpacked_func, - ) - HAS_FLASH_ATTN = True -except ImportError: - HAS_FLASH_ATTN = False - print('please install flash_attn from https://github.com/HazyResearch/flash-attention') - -if HAS_TRITON: - # the following functions are adapted from the OpenAI Triton tutorial - # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py - @triton.jit - def _fwd_kernel( - Q, - K, - V, - sm_scale, - TMP, - L, - M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - Z, - H, - N_CTX, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk - off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk - # Initialize pointers to Q, K, V - q_ptrs = Q + off_q - k_ptrs = K + off_k - v_ptrs = V + off_v - # initialize pointer to m and l - t_ptrs = TMP + off_hz * N_CTX + offs_m - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # load q: it will stay in SRAM throughout - q = tl.load(q_ptrs) - # loop over k, v and update accumulator - for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + start_n * stride_kn) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, trans_b=True) - qk *= sm_scale - qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + start_n * stride_vk) - p = p.to(tl.float16) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # rematerialize offsets to save registers - start_m = tl.program_id(0) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # write back l and m - l_ptrs = L + off_hz * N_CTX + offs_m - m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(l_ptrs, l_i) - tl.store(m_ptrs, m_i) - # initialize pointers to output - offs_n = tl.arange(0, BLOCK_DMODEL) - off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - - @triton.jit - def _bwd_preprocess( - Out, - DO, - L, - NewDO, - Delta, - BLOCK_M: tl.constexpr, - D_HEAD: tl.constexpr, - ): - off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) - off_n = tl.arange(0, D_HEAD) - # load - o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - denom = tl.load(L + off_m).to(tl.float32) - # compute - do = do / denom[:, None] - delta = tl.sum(o * do, axis=1) - # write-back - tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) - tl.store(Delta + off_m, delta) - - @triton.jit - def _bwd_kernel( - Q, - K, - V, - sm_scale, - Out, - DO, - DQ, - DK, - DV, - L, - M, - D, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - Z, - H, - N_CTX, - num_block, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - off_hz = tl.program_id(0) - off_z = off_hz // H - off_h = off_hz % H - # offset pointers for batch/head - Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_qz + off_h * stride_qh - V += off_z * stride_qz + off_h * stride_qh - DO += off_z * stride_qz + off_h * stride_qh - DQ += off_z * stride_qz + off_h * stride_qh - DK += off_z * stride_qz + off_h * stride_qh - DV += off_z * stride_qz + off_h * stride_qh - for start_n in range(0, num_block): - lo = start_n * BLOCK_M - # initialize row/col offsets - offs_qm = lo + tl.arange(0, BLOCK_M) - offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) - offs_m = tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_DMODEL) - # initialize pointers to value-like data - q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) - # pointer to row-wise quantities in value-like data - D_ptrs = D + off_hz * N_CTX - m_ptrs = M + off_hz * N_CTX - # initialize dv amd dk - dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # k and v stay in SRAM throughout - k = tl.load(k_ptrs) - v = tl.load(v_ptrs) - # loop over rows - for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): - offs_m_curr = start_m + offs_m - # load q, k, v, do on-chip - q = tl.load(q_ptrs) - # recompute p = softmax(qk, dim=-1).T - # NOTE: `do` is pre-divided by `l`; no normalization here - qk = tl.dot(q, k, trans_b=True) - qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) - m = tl.load(m_ptrs + offs_m_curr) - p = tl.exp(qk * sm_scale - m[:, None]) - # compute dv - do = tl.load(do_ptrs) - dv += tl.dot(p.to(tl.float16), do, trans_a=True) - # compute dp = dot(v, do) - Di = tl.load(D_ptrs + offs_m_curr) - dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] - dp += tl.dot(do, v, trans_b=True) - # compute ds = p * (dp - delta[:, None]) - ds = p * dp * sm_scale - # compute dk = dot(ds.T, q) - dk += tl.dot(ds.to(tl.float16), q, trans_a=True) - # # compute dq - dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds.to(tl.float16), k) - tl.store(dq_ptrs, dq, eviction_policy="evict_last") - # # increment pointers - dq_ptrs += BLOCK_M * stride_qm - q_ptrs += BLOCK_M * stride_qm - do_ptrs += BLOCK_M * stride_qm - # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) - dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) - - class _TritonFlashAttention(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, k, v, sm_scale): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - o = torch.empty_like(q) - grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) - tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - num_warps = 4 if Lk <= 64 else 8 - - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - tmp, - L, - m, - o, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), - q.shape[0], - q.shape[1], - q.shape[2], - BLOCK_M=BLOCK, - BLOCK_N=BLOCK, - BLOCK_DMODEL=Lk, - num_warps=num_warps, - num_stages=1, - ) - ctx.save_for_backward(q, k, v, o, L, m) - ctx.BLOCK = BLOCK - ctx.grid = grid - ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = Lk - return o - - @staticmethod - def backward(ctx, do): - q, k, v, o, l, m = ctx.saved_tensors - do = do.contiguous() - dq = torch.zeros_like(q, dtype=torch.float32) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - do_scaled = torch.empty_like(do) - delta = torch.empty_like(l) - _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)]( - o, - do, - l, - do_scaled, - delta, - BLOCK_M=ctx.BLOCK, - D_HEAD=ctx.BLOCK_DMODEL, - ) - - # NOTE: kernel currently buggy for other values of `num_warps` - num_warps = 8 - _bwd_kernel[(ctx.grid[1],)]( - q, - k, - v, - ctx.sm_scale, - o, - do_scaled, - dq, - dk, - dv, - l, - m, - delta, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - q.shape[0], - q.shape[1], - q.shape[2], - ctx.grid[0], - BLOCK_M=ctx.BLOCK, - BLOCK_N=ctx.BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, - num_warps=num_warps, - num_stages=1, - ) - return dq, dk, dv, None - - def triton_flash_attention(q, k, v, sm_scale): - """ - Arguments: - q: (batch, nheads, seq, headdim) - k: (batch, nheads, seq, headdim) - v: (batch, nheads, seq, headdim) - sm_scale: float. The scaling of QK^T before applying softmax. - Return: - out: (batch, nheads, seq, headdim) - """ - if HAS_TRITON: - return _TritonFlashAttention.apply(q, k, v, sm_scale) - else: - raise RuntimeError("Triton kernel requires CUDA 11.4+!") - - -if HAS_FLASH_ATTN: - - def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False): - """ - Arguments: - qkv: (batch * seqlen, 3, nheads, headdim) - batch_size: int. - seq_len: int. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - dropout_p: float. - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - out: (total, nheads, headdim). - """ - max_s = seq_len - cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device) - out = flash_attn_unpadded_qkvpacked_func(qkv, - cu_seqlens, - max_s, - dropout_p, - softmax_scale=sm_scale, - causal=causal) - return out - - def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False): - """ - Arguments: - q: (batch * q_seqlen, nheads, headdim) - kv: (batch * kv_seqlen, 2, nheads, headdim) - batch_size: int. - seq_len: int. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - dropout_p: float. - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - out: (total, nheads, headdim). - """ - cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) - cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, - step=kv_seqlen, - dtype=torch.int32, - device=kv.device) - out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p, - sm_scale, causal) - return out - - def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False): - """ - Arguments: - q: (batch * q_seqlen, nheads, headdim) - k: (batch * kv_seqlen, nheads, headdim) - v: (batch * kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - out: (total, nheads, headdim). - """ - cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device) - cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, - step=kv_seqlen, - dtype=torch.int32, - device=k.device) - return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale, - causal) - - -########################################################################## diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py new file mode 100644 index 000000000..6a8d74f70 --- /dev/null +++ b/colossalai/kernel/cuda_native/mha/flash_attn_2.py @@ -0,0 +1,68 @@ +import warnings +from typing import Optional + +import torch + + +def is_ampere_or_better_gpu(): + if torch.cuda.is_available(): + device = torch.device("cuda") + properties = torch.cuda.get_device_properties(device) + if properties.major >= 8: # Ampere GPUs or newer + return True + return False + + +# "Check Ampere GPUs or newer" +HAS_FLASH_ATTN = False +if is_ampere_or_better_gpu(): + HAS_FLASH_ATTN = True +else: + warnings.warn('FlashAttention only supports Ampere GPUs or newer.') + HAS_FLASH_ATTN = False +try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + HAS_FLASH_ATTN = True +except ImportError: + warnings.warn('please install flash_attn from https://github.com/HazyResearch/flash-attention') + HAS_FLASH_ATTN = False + +if HAS_FLASH_ATTN: + from einops import rearrange + + from .utils import SeqLenInfo + + def flash_attention(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: SeqLenInfo, + seq_len_info_kv: SeqLenInfo, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0., + scale: float = None, + causal: bool = False, + padded: bool = False): + """ + Arguments: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + if padded: + if seq_len_info_kv == None: + seq_len_info_kv = seq_len_info_q + + attn_out = flash_attn_varlen_func(q, k, v, seq_len_info_q.cu_seqlens, seq_len_info_kv.cu_seqlens, + seq_len_info_q.max_seqlen, seq_len_info_kv.max_seqlen, dropout_p, scale, + causal) + else: + attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) + return attn_out diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py new file mode 100644 index 000000000..e83beb8b2 --- /dev/null +++ b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py @@ -0,0 +1,70 @@ +import warnings + +HAS_MEM_EFF_ATTN = False +try: + from xformers.ops.fmha import memory_efficient_attention + HAS_MEM_EFF_ATTN = True +except ImportError: + warnings.warn('please install xformers from https://github.com/facebookresearch/xformers') + HAS_MEM_EFF_ATTN = False + +if HAS_MEM_EFF_ATTN: + """ + A general attention module using the flash attention kernels from xformers: + https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha + """ + from typing import Optional + + import torch + from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp + from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, + ) + + from .utils import SeqLenInfo + + allow_alibi = True + for op in MemoryEfficientAttentionCutlassOp: + allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) + + def mem_eff_attention(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: SeqLenInfo, + seq_len_info_kv: SeqLenInfo, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0., + scale: float = None, + causal: bool = False, + padded: bool = False): + + attn_bias = None + if padded: # bert style + if not causal: + attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + elif causal: # gpt style + attn_bias = LowerTriangularMask() + + if bias is not None: # alibi / relative position embedding + assert allow_alibi, "flash attention with bias is not supported in this system." + assert causal, \ + "attention with bias is only supported for causal attention so far." + attn_bias = attn_bias.add_bias(bias) + + if padded: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) + + # shape: (b*s, n, d) + if padded: + out = out.squeeze(0) + + return out diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py new file mode 100644 index 000000000..8f449a138 --- /dev/null +++ b/colossalai/kernel/cuda_native/mha/mha.py @@ -0,0 +1,107 @@ +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange + +from ..scaled_softmax import AttnMaskType +from .flash_attn_2 import HAS_FLASH_ATTN +from .mem_eff_attn import HAS_MEM_EFF_ATTN +from .utils import Repad, SeqLenInfo, Unpad + +if HAS_FLASH_ATTN: + from .flash_attn_2 import flash_attention +if HAS_MEM_EFF_ATTN: + from .mem_eff_attn import mem_eff_attention + + +class ColoAttention(torch.nn.Module): + + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): + super().__init__() + assert embed_dim % num_heads == 0, \ + f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." + if scale is not None: + self.scale = scale + else: + self.scale = 1 / math.sqrt(embed_dim // num_heads) + self.dropout = dropout + + if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN: + raise Exception("flash attention can not support!") + + @staticmethod + def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return Unpad.apply(tensor, indices) + + @staticmethod + def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: + return Repad.apply(tensor, indices, batch_size, seq_len) + + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + attn_mask_type: Optional[AttnMaskType] = None, + bias: Optional[torch.Tensor] = None): + + attn = None + if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None: + attn = flash_attention + else: + attn = mem_eff_attention + + padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 + causal = attn_mask_type is not None and attn_mask_type.value > 1 + + batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] + # unpad + seq_len_info_q = None + seq_len_info_kv = None + if padded: + # bert style, unpad process + assert attn_mask is not None, \ + f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." + assert attn_mask.dim() == 2, \ + "attention mask is supposed to have shape (batch_size, seq_len), " + \ + f"but got {attn_mask.dim()} dimensions." + + # bert style + if tgt_len == src_len: + seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query, key, value = self.unpad(torch.stack([query, key, value], dim=2), + seq_len_info_q.indices).unbind(dim=1) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + seq_len_info_kv = seq_len_info_q + else: + seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) + seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query = rearrange(query, "b s ... -> c (b s) ...", c=1) + key, value = self.unpad(torch.stack([query, key, value], dim=2), + seq_len_info_kv.indices).unbind(dim=1) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + + out = attn(query, + key, + value, + seq_len_info_q, + seq_len_info_kv, + dropout_p=self.dropout, + scale=self.scale, + causal=causal, + padded=padded) + + # repad + if padded: + if batch_size > 1: + out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) + out = rearrange(out, '(b s) h d -> b s h d', b=batch_size) + + out = rearrange(out, 'b s h d -> b s (h d)') + return out diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py new file mode 100644 index 000000000..e3e431fa7 --- /dev/null +++ b/colossalai/kernel/cuda_native/mha/utils.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +from typing import Iterable, Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange + +from colossalai.utils.cuda import get_current_device + + +class Unpad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): + ctx.save_for_backward(indices) + # [b, s, ...] + assert tensor.ndim >= 3 + ctx.bsz = tensor.shape[0] + out = rearrange(tensor, 'b s ... -> (b s) ...') + ctx.shape = out.shape + # [ntokens, ...] + return out[indices] + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + # [ntokens, ...] + grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) + grad[indices] = grad_output + grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz) + # [b, s, ...] + return grad, None + + +class Repad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): + ctx.save_for_backward(indices) + # [ntokens, ...] + tensor = tensor + out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) + # [b*s, ...] + out[indices] = tensor + return out + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + # [b*s, ...] + grad = grad_output[indices] + # [ntokens, ...] + return grad, None, None, None + + +@dataclass +class SeqLenInfo: + seqlens: Iterable[int] = None + indices: torch.Tensor = None + max_seqlen: int = None + cu_seqlens: torch.Tensor = None + + @staticmethod + def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()): + if attn_mask is not None: + indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) + seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() + else: + batch_size, tgt_len = size[0], size[1] + indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) + seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) + max_seqlen = max(seqlens) + cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) + return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py index 24e458bb3..41cd4b20f 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -19,6 +19,7 @@ except ImportError: class AttnMaskType(enum.Enum): padding = 1 causal = 2 + paddedcausal = 3 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): @@ -139,7 +140,7 @@ class FusedScaleMaskSoftmax(nn.Module): if 0 <= sk <= 2048: batch_per_block = self.get_batch_per_block(sq, sk, b, np) - if self.attn_mask_type == AttnMaskType.causal: + if self.attn_mask_type.value > 1: if attn_batches % batch_per_block == 0: return True else: @@ -151,7 +152,7 @@ class FusedScaleMaskSoftmax(nn.Module): b, np, sq, sk = input.size() scale = self.scale if self.scale is not None else 1.0 - if self.attn_mask_type == AttnMaskType.causal: + if self.attn_mask_type.value > 1: assert sq == sk, "causal mask is only for self attention" # input is 3D tensor (attn_batches, sq, sk) diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index 7a28b0157..d41ccd832 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -4,11 +4,15 @@ import pytest import torch from einops import rearrange -from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN +from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN +from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN from colossalai.testing import clear_cache_before_run, parameterize -if HAS_MEM_EFF_ATTN: - from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention +if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: + from colossalai.kernel.cuda_native.mha.mha import ColoAttention + from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType + +DTYPE = [torch.float16, torch.bfloat16, torch.float32] def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale): @@ -22,10 +26,13 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale): return ref_out -@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) -def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): +@parameterize('proj_shape', [(1, 8, 4, 16)]) +@parameterize('dtype', DTYPE) +def test_attention_gpt(proj_shape, dtype): + # TODO check output value + (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") @@ -35,7 +42,11 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): qkv = c_attn(x) q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H) - y = attn(q, k, v, attn_mask_type=AttnMaskType.causal) + + mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)] + mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True) + + y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal) assert list(y.shape) == [B, S, D] @@ -43,10 +54,12 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): y.backward(dy) -@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) -def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): +@parameterize('proj_shape', [(6, 8, 4, 16)]) +@parameterize('dtype', DTYPE) +def test_attention_bert(proj_shape, dtype): + (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") @@ -67,10 +80,12 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): y.backward(dy) -@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) -def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): +@parameterize('proj_shape', [(6, 8, 4, 16)]) +@parameterize('dtype', DTYPE) +def test_attention_no_mask(proj_shape, dtype): + (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") @@ -87,10 +102,12 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): y.backward(dy) -@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") +@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)]) -def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16): +@parameterize('proj_shape', [(6, 24, 8, 4, 16)]) +@parameterize('dtype', DTYPE) +def test_cross_attention(proj_shape, dtype): + (B, S, T, H, D_HEAD) = proj_shape D = H * D_HEAD q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")