Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-24 22:42:15 +00:00
[kernel] skip tests of flash_attn and triton when they are not available (#1798)
This commit is contained in: parent e34e850a4c, commit c248800359
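The pattern behind this commit: the kernel module exposes availability flags (HAS_TRITON, HAS_FLASH_ATTN, TRITON_AVALIABLE) instead of raising ImportError, and the tests check those flags with pytest.mark.skipif, so machines without the optional packages skip the tests instead of failing at collection time. Below is a minimal, self-contained sketch of that pattern; the file and test names are illustrative, not the actual ColossalAI sources.

```python
# sketch_skip_optional.py -- illustrative sketch, not the ColossalAI code; run with `pytest sketch_skip_optional.py`
import pytest

# Library side: record availability in a flag instead of raising at import time.
try:
    import triton    # noqa: F401
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False


# Test side: a missing backend becomes a skipped test, not a collection error.
@pytest.mark.skipif(not HAS_TRITON, reason="triton is not available")
def test_triton_backed_kernel():
    import triton.language as tl    # only evaluated when triton is installed
    assert tl is not None
```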
@@ -61,7 +61,7 @@ class GeminiManager:
         self._comp_cuda_demand_time = 0

     def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
-        """ Adjust the layout of statefuil tensor according to the information provided
+        """ Adjust the layout of stateful tensors according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
         """
         # find stateful tensor in state COMPUTE
@@ -5,20 +5,24 @@ This is a Triton implementation of the Flash Attention algorithm
 (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
 """

-import torch
-import subprocess
+import os
+import subprocess
+
+import torch

 try:
     import triton
     import triton.language as tl
     HAS_TRITON = True
 except ImportError:
-    raise ImportError('please install triton from https://github.com/openai/triton')
+    print('please install triton from https://github.com/openai/triton')
+    HAS_TRITON = False

 try:
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
     HAS_FLASH_ATTN = True
 except ImportError:
-    raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')
+    HAS_FLASH_ATTN = False
+    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')


@@ -33,22 +37,44 @@ def triton_check():
         return True
     return False


 TRITON_AVALIABLE = triton_check()

+if TRITON_AVALIABLE:
+
-@triton.jit
-def _fwd_kernel(
-    Q, K, V, sm_scale,
-    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
+    @triton.jit
+    def _fwd_kernel(
+        Q,
+        K,
+        V,
+        sm_scale,
+        TMP,
+        L,
+        M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
     Out,
-    stride_qz, stride_qh, stride_qm, stride_qk,
-    stride_kz, stride_kh, stride_kn, stride_kk,
-    stride_vz, stride_vh, stride_vk, stride_vn,
-    stride_oz, stride_oh, stride_om, stride_on,
-    Z, H, N_CTX,
-    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+        stride_qz,
+        stride_qh,
+        stride_qm,
+        stride_qk,
+        stride_kz,
+        stride_kh,
+        stride_kn,
+        stride_kk,
+        stride_vz,
+        stride_vh,
+        stride_vk,
+        stride_vn,
+        stride_oz,
+        stride_oh,
+        stride_om,
+        stride_on,
+        Z,
+        H,
+        N_CTX,
+        BLOCK_M: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
-):
+    ):
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
     # initialize offsets
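One reason the kernel definitions sit under the `if TRITON_AVALIABLE:` guard rather than at module level: decorators run at import time, so a bare `@triton.jit` would raise a NameError whenever the `import triton` above had failed, defeating the new flags. A hedged illustration of the idea follows; the kernel below is a throwaway placeholder, not one of the attention kernels in this diff.

```python
# Illustrative only: shows why @triton.jit definitions must be guarded.
try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False

if HAS_TRITON:

    @triton.jit
    def _add_one(X, Y, N: tl.constexpr):
        # placeholder kernel: writes X[i] + 1 into Y[i] for N contiguous elements
        idx = tl.arange(0, N)
        tl.store(Y + idx, tl.load(X + idx) + 1)

# Without the guard, evaluating "@triton.jit" would raise NameError on machines
# where the import failed, and the module could not be imported at all.
```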
@@ -117,13 +143,16 @@ def _fwd_kernel(
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc)


-@triton.jit
-def _bwd_preprocess(
-    Out, DO, L,
-    NewDO, Delta,
-    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
-):
+    @triton.jit
+    def _bwd_preprocess(
+        Out,
+        DO,
+        L,
+        NewDO,
+        Delta,
+        BLOCK_M: tl.constexpr,
+        D_HEAD: tl.constexpr,
+    ):
     off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
     off_n = tl.arange(0, D_HEAD)
     # load
@@ -137,21 +166,40 @@ def _bwd_preprocess(
     tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
     tl.store(Delta + off_m, delta)


-@triton.jit
-def _bwd_kernel(
-    Q, K, V, sm_scale, Out, DO,
-    DQ, DK, DV,
-    L, M,
+    @triton.jit
+    def _bwd_kernel(
+        Q,
+        K,
+        V,
+        sm_scale,
+        Out,
+        DO,
+        DQ,
+        DK,
+        DV,
+        L,
+        M,
     D,
-    stride_qz, stride_qh, stride_qm, stride_qk,
-    stride_kz, stride_kh, stride_kn, stride_kk,
-    stride_vz, stride_vh, stride_vk, stride_vn,
-    Z, H, N_CTX,
+        stride_qz,
+        stride_qh,
+        stride_qm,
+        stride_qk,
+        stride_kz,
+        stride_kh,
+        stride_kn,
+        stride_kk,
+        stride_vz,
+        stride_vh,
+        stride_vk,
+        stride_vn,
+        Z,
+        H,
+        N_CTX,
     num_block,
-    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+        BLOCK_M: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
-):
+    ):
     off_hz = tl.program_id(0)
     off_z = off_hz // H
     off_h = off_hz % H
@@ -221,8 +269,7 @@ def _bwd_kernel(
     tl.store(dv_ptrs, dv)
     tl.store(dk_ptrs, dk)

-
-class _TritonFlashAttention(torch.autograd.Function):
+    class _TritonFlashAttention(torch.autograd.Function):

     @staticmethod
     def forward(ctx, q, k, v, sm_scale):
@@ -239,16 +286,37 @@ class _TritonFlashAttention(torch.autograd.Function):
         num_warps = 4 if Lk <= 64 else 8

         _fwd_kernel[grid](
-            q, k, v, sm_scale,
-            tmp, L, m,
+            q,
+            k,
+            v,
+            sm_scale,
+            tmp,
+            L,
+            m,
             o,
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-            q.shape[0], q.shape[1], q.shape[2],
-            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=Lk, num_warps=num_warps,
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            q.stride(3),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            k.stride(3),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            v.stride(3),
+            o.stride(0),
+            o.stride(1),
+            o.stride(2),
+            o.stride(3),
+            q.shape[0],
+            q.shape[1],
+            q.shape[2],
+            BLOCK_M=BLOCK,
+            BLOCK_N=BLOCK,
+            BLOCK_DMODEL=Lk,
+            num_warps=num_warps,
             num_stages=1,
         )
         ctx.save_for_backward(q, k, v, o, L, m)
@@ -267,33 +335,56 @@ class _TritonFlashAttention(torch.autograd.Function):
         dv = torch.empty_like(v)
         do_scaled = torch.empty_like(do)
         delta = torch.empty_like(l)
-        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
-            o, do, l,
-            do_scaled, delta,
-            BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
+        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
+            o,
+            do,
+            l,
+            do_scaled,
+            delta,
+            BLOCK_M=ctx.BLOCK,
+            D_HEAD=ctx.BLOCK_DMODEL,
         )

         # NOTE: kernel currently buggy for other values of `num_warps`
         num_warps = 8
         _bwd_kernel[(ctx.grid[1],)](
-            q, k, v, ctx.sm_scale,
-            o, do_scaled,
-            dq, dk, dv,
-            l, m,
+            q,
+            k,
+            v,
+            ctx.sm_scale,
+            o,
+            do_scaled,
+            dq,
+            dk,
+            dv,
+            l,
+            m,
             delta,
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-            q.shape[0], q.shape[1], q.shape[2],
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            q.stride(3),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            k.stride(3),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            v.stride(3),
+            q.shape[0],
+            q.shape[1],
+            q.shape[2],
             ctx.grid[0],
-            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
+            BLOCK_M=ctx.BLOCK,
+            BLOCK_N=ctx.BLOCK,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL,
+            num_warps=num_warps,
             num_stages=1,
         )
         return dq, dk, dv, None


-def triton_flash_attention(q, k, v, sm_scale):
+    def triton_flash_attention(q, k, v, sm_scale):
     """
     Arguments:
         q: (batch, nheads, seq, headdim)
@@ -309,7 +400,9 @@ def triton_flash_attention(q, k, v, sm_scale):
         raise RuntimeError("Triton kernel requires CUDA 11.4+!")


-def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal=True):
+if HAS_FLASH_ATTN:
+
+    def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal=True):
     """
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -327,5 +420,13 @@ def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal
     lengths = torch.full((batch_size,), fill_value=seq_len, device=q.device)
     cu_seqlens = torch.zeros((batch_size + 1,), device=q.device, dtype=torch.int32)
     cu_seqlens[1:] = lengths.cumsum(0)
-    return flash_attn_unpadded_func(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=seq_len, max_seqlen_k=seq_len,
-                                    dropout_p=dropout_p, softmax_scale=sm_scale, causal=causal)
+        return flash_attn_unpadded_func(q,
+                                        k,
+                                        v,
+                                        cu_seqlens_q=cu_seqlens,
+                                        cu_seqlens_k=cu_seqlens,
+                                        max_seqlen_q=seq_len,
+                                        max_seqlen_k=seq_len,
+                                        dropout_p=dropout_p,
+                                        softmax_scale=sm_scale,
+                                        causal=causal)
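For readers of the wrapper reformatted above: `flash_attention` packs all sequences of a batch into `(batch_size * seq_len, nheads, headdim)` tensors and builds the cumulative sequence lengths (`cu_seqlens`) itself, as the hunk shows. A rough usage sketch follows, under the assumption that flash-attn is installed and that its kernels expect fp16 CUDA tensors (that dtype/device requirement comes from flash-attn, not from this diff).

```python
import torch

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import flash_attention

    batch_size, seq_len, nheads, headdim = 2, 16, 4, 32
    # Packed layout: every sequence has length seq_len, so total tokens = batch_size * seq_len.
    q, k, v = (torch.randn(batch_size * seq_len, nheads, headdim, dtype=torch.float16, device='cuda')
               for _ in range(3))
    out = flash_attention(q, k, v, sm_scale=headdim**-0.5, batch_size=batch_size, seq_len=seq_len)
    print(out.shape)    # expected: (batch_size * seq_len, nheads, headdim)
```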
@@ -1,7 +1,14 @@
-import torch
 import pytest
+import torch
 from einops import rearrange
-from colossalai.kernel.cuda_native.flash_attention import flash_attention, triton_flash_attention, TRITON_AVALIABLE
+
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON, TRITON_AVALIABLE
+
+if HAS_FLASH_ATTN:
+    from colossalai.kernel.cuda_native.flash_attention import flash_attention
+
+if HAS_TRITON:
+    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention


 def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
@@ -15,6 +22,7 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     return ref_out


+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
@@ -51,6 +59,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
         raise TypeError("Error type not match!")


+@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
@@ -73,7 +82,8 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
     tri_out.backward(dout, retain_graph=True)
     tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
-    tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), (tri_out, tri_dq, tri_dk, tri_dv))
+    tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
+                                          (tri_out, tri_dq, tri_dk, tri_dv))

     # compare
     assert torch.allclose(ref_out, tri_out, atol=1e-3)