[kernel] skip tests of flash_attn and triton when they are not available (#1798)

Author: Jiarui Fang
Date: 2022-11-07 13:41:13 +08:00 (committed by GitHub)
Commit: c248800359
Parent: e34e850a4c

3 changed files with 412 additions and 301 deletions
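
The change follows the usual optional-dependency pattern: attempt the import once at module load, record the result in a module-level flag instead of raising, and let the tests skip themselves via pytest.mark.skipif when the flag is false. A minimal sketch of that pattern, reduced to the triton case (the test body below is illustrative, not code from this commit):

import pytest

# Guarded import: record availability instead of failing at import time.
try:
    import triton    # noqa: F401
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print('please install triton from https://github.com/openai/triton')


# Tests that need the optional package skip cleanly when it is missing.
@pytest.mark.skipif(not HAS_TRITON, reason="triton is not available")
def test_needs_triton():
    import triton.language as tl
    assert tl is not None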

View File

@@ -61,7 +61,7 @@ class GeminiManager:
         self._comp_cuda_demand_time = 0
 
     def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
-        """ Adjust the layout of statefuil tensor according to the information provided
+        """ Adjust the layout of stateful tensors according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
         """
         # find stateful tensor in state COMPUTE

View File

@@ -5,20 +5,24 @@ This is a Triton implementation of the Flash Attention algorithm
 (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
 """
-import torch
-import subprocess
 import os
+import subprocess
+
+import torch
 
 try:
     import triton
     import triton.language as tl
+    HAS_TRITON = True
 except ImportError:
-    raise ImportError('please install triton from https://github.com/openai/triton')
+    print('please install triton from https://github.com/openai/triton')
+    HAS_TRITON = False
 
 try:
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    HAS_FLASH_ATTN = True
 except ImportError:
-    raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')
+    HAS_FLASH_ATTN = False
+    print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
 
 
 def triton_check():
@@ -33,20 +37,42 @@ def triton_check():
         return True
     return False
 
 
 TRITON_AVALIABLE = triton_check()
 
+if TRITON_AVALIABLE:
 
     @triton.jit
     def _fwd_kernel(
-        Q, K, V, sm_scale,
-        TMP, L, M,    # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
+        Q,
+        K,
+        V,
+        sm_scale,
+        TMP,
+        L,
+        M,    # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
         Out,
-        stride_qz, stride_qh, stride_qm, stride_qk,
-        stride_kz, stride_kh, stride_kn, stride_kk,
-        stride_vz, stride_vh, stride_vk, stride_vn,
-        stride_oz, stride_oh, stride_om, stride_on,
-        Z, H, N_CTX,
-        BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+        stride_qz,
+        stride_qh,
+        stride_qm,
+        stride_qk,
+        stride_kz,
+        stride_kh,
+        stride_kn,
+        stride_kk,
+        stride_vz,
+        stride_vh,
+        stride_vk,
+        stride_vn,
+        stride_oz,
+        stride_oh,
+        stride_om,
+        stride_on,
+        Z,
+        H,
+        N_CTX,
+        BLOCK_M: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
     ):
         start_m = tl.program_id(0)
@@ -117,12 +143,15 @@ def _fwd_kernel(
         out_ptrs = Out + off_o
         tl.store(out_ptrs, acc)
 
     @triton.jit
     def _bwd_preprocess(
-        Out, DO, L,
-        NewDO, Delta,
-        BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
+        Out,
+        DO,
+        L,
+        NewDO,
+        Delta,
+        BLOCK_M: tl.constexpr,
+        D_HEAD: tl.constexpr,
     ):
         off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
         off_n = tl.arange(0, D_HEAD)
@@ -137,19 +166,38 @@ def _bwd_preprocess(
         tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
         tl.store(Delta + off_m, delta)
 
     @triton.jit
    def _bwd_kernel(
-        Q, K, V, sm_scale, Out, DO,
-        DQ, DK, DV,
-        L, M,
+        Q,
+        K,
+        V,
+        sm_scale,
+        Out,
+        DO,
+        DQ,
+        DK,
+        DV,
+        L,
+        M,
         D,
-        stride_qz, stride_qh, stride_qm, stride_qk,
-        stride_kz, stride_kh, stride_kn, stride_kk,
-        stride_vz, stride_vh, stride_vk, stride_vn,
-        Z, H, N_CTX,
+        stride_qz,
+        stride_qh,
+        stride_qm,
+        stride_qk,
+        stride_kz,
+        stride_kh,
+        stride_kn,
+        stride_kk,
+        stride_vz,
+        stride_vh,
+        stride_vk,
+        stride_vn,
+        Z,
+        H,
+        N_CTX,
         num_block,
-        BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+        BLOCK_M: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
     ):
         off_hz = tl.program_id(0)
@@ -221,7 +269,6 @@ def _bwd_kernel(
             tl.store(dv_ptrs, dv)
             tl.store(dk_ptrs, dk)
 
-
 class _TritonFlashAttention(torch.autograd.Function):
 
     @staticmethod
@@ -239,16 +286,37 @@ class _TritonFlashAttention(torch.autograd.Function):
         num_warps = 4 if Lk <= 64 else 8
 
         _fwd_kernel[grid](
-            q, k, v, sm_scale,
-            tmp, L, m,
+            q,
+            k,
+            v,
+            sm_scale,
+            tmp,
+            L,
+            m,
             o,
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-            q.shape[0], q.shape[1], q.shape[2],
-            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=Lk, num_warps=num_warps,
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            q.stride(3),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            k.stride(3),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            v.stride(3),
+            o.stride(0),
+            o.stride(1),
+            o.stride(2),
+            o.stride(3),
+            q.shape[0],
+            q.shape[1],
+            q.shape[2],
+            BLOCK_M=BLOCK,
+            BLOCK_N=BLOCK,
+            BLOCK_DMODEL=Lk,
+            num_warps=num_warps,
             num_stages=1,
         )
         ctx.save_for_backward(q, k, v, o, L, m)
@@ -268,31 +336,54 @@ class _TritonFlashAttention(torch.autograd.Function):
         do_scaled = torch.empty_like(do)
         delta = torch.empty_like(l)
         _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
-            o, do, l,
-            do_scaled, delta,
-            BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
+            o,
+            do,
+            l,
+            do_scaled,
+            delta,
+            BLOCK_M=ctx.BLOCK,
+            D_HEAD=ctx.BLOCK_DMODEL,
         )
         # NOTE: kernel currently buggy for other values of `num_warps`
         num_warps = 8
         _bwd_kernel[(ctx.grid[1],)](
-            q, k, v, ctx.sm_scale,
-            o, do_scaled,
-            dq, dk, dv,
-            l, m,
+            q,
+            k,
+            v,
+            ctx.sm_scale,
+            o,
+            do_scaled,
+            dq,
+            dk,
+            dv,
+            l,
+            m,
             delta,
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-            q.shape[0], q.shape[1], q.shape[2],
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            q.stride(3),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            k.stride(3),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            v.stride(3),
+            q.shape[0],
+            q.shape[1],
+            q.shape[2],
             ctx.grid[0],
-            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
+            BLOCK_M=ctx.BLOCK,
+            BLOCK_N=ctx.BLOCK,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL,
+            num_warps=num_warps,
             num_stages=1,
         )
         return dq, dk, dv, None
 
 
 def triton_flash_attention(q, k, v, sm_scale):
     """
     Arguments:
@@ -309,6 +400,8 @@ def triton_flash_attention(q, k, v, sm_scale):
         raise RuntimeError("Triton kernel requires CUDA 11.4+!")
 
 
+if HAS_FLASH_ATTN:
+
     def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal=True):
         """
         Arguments:
@@ -327,5 +420,13 @@ def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal
         lengths = torch.full((batch_size,), fill_value=seq_len, device=q.device)
         cu_seqlens = torch.zeros((batch_size + 1,), device=q.device, dtype=torch.int32)
         cu_seqlens[1:] = lengths.cumsum(0)
-    return flash_attn_unpadded_func(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=seq_len, max_seqlen_k=seq_len,
-                                    dropout_p=dropout_p, softmax_scale=sm_scale, causal=causal)
+        return flash_attn_unpadded_func(q,
+                                        k,
+                                        v,
+                                        cu_seqlens_q=cu_seqlens,
+                                        cu_seqlens_k=cu_seqlens,
+                                        max_seqlen_q=seq_len,
+                                        max_seqlen_k=seq_len,
+                                        dropout_p=dropout_p,
+                                        softmax_scale=sm_scale,
+                                        causal=causal)
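
For reference, the flash_attention wrapper above takes q, k and v already packed as (batch_size * seq_len, n_heads, d_head) tensors and builds the cumulative sequence lengths itself. A minimal usage sketch, assuming a CUDA device, fp16 inputs and an installed flash_attn; the shapes mirror the ones used in the test file below, and sm_scale is set to the usual 1/sqrt(d_head):

import torch

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN

if HAS_FLASH_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import flash_attention

    batch_size, seq_len, n_heads, d_head = 3, 16, 2, 8
    # Packed layout: all sequences concatenated along the first dimension.
    shape = (batch_size * seq_len, n_heads, d_head)
    q, k, v = (torch.randn(shape, dtype=torch.float16, device='cuda', requires_grad=True) for _ in range(3))

    out = flash_attention(q, k, v, d_head**-0.5, batch_size, seq_len, causal=True)
    print(out.shape)    # (batch_size * seq_len, n_heads, d_head)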

View File

@@ -1,7 +1,14 @@
-import torch
 import pytest
+import torch
 from einops import rearrange
-from colossalai.kernel.cuda_native.flash_attention import flash_attention, triton_flash_attention, TRITON_AVALIABLE
+
+from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON, TRITON_AVALIABLE
+
+if HAS_FLASH_ATTN:
+    from colossalai.kernel.cuda_native.flash_attention import flash_attention
+if HAS_TRITON:
+    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
 
 
 def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
@@ -15,6 +22,7 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
     return ref_out
 
 
+@pytest.mark.skipif(not HAS_TRITON, reason="triton is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
@@ -51,6 +59,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
         raise TypeError("Error type not match!")
 
 
+@pytest.mark.skipif(not HAS_FLASH_ATTN, reason="flash_attn is not available")
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
 def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
@@ -73,7 +82,8 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
     tri_out.backward(dout, retain_graph=True)
     tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
-    tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), (tri_out, tri_dq, tri_dk, tri_dv))
+    tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
+                                          (tri_out, tri_dq, tri_dk, tri_dv))
 
     # compare
     assert torch.allclose(ref_out, tri_out, atol=1e-3)
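
As a side note, pytest also offers importorskip, which skips an entire test module when an optional package cannot be imported. A brief sketch of that alternative (not what this commit uses; the HAS_* flags are kept here because the kernel module itself, not only the tests, needs to know whether triton and flash_attn are present):

import pytest

# Skip the whole module at collection time if either optional package is missing.
triton = pytest.importorskip("triton")
flash_attn = pytest.importorskip("flash_attn")


def test_flash_attn_entry_point_exists():
    # Only runs when both imports above succeeded.
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
    assert callable(flash_attn_unpadded_func)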