[kernel] skip tests of flash_attn and triton when they are not available (#1798)

Jiarui Fang 2022-11-07 13:41:13 +08:00 committed by GitHub
parent e34e850a4c
commit c248800359
3 changed files with 412 additions and 301 deletions
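
The change follows the standard optional-dependency pattern: guard the import with try/except, record the result in a module-level flag, and let pytest skip the dependent tests when the flag is false. A minimal, self-contained sketch of that pattern (the module name optional_dep.py and the test below are illustrative only, not part of this commit):

# optional_dep.py (illustrative module name)
try:
    import triton    # noqa: F401
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print('please install triton from https://github.com/openai/triton')

# test_optional_dep.py (illustrative test mirroring the skip logic added below)
import pytest

from optional_dep import HAS_TRITON


@pytest.mark.skipif(not HAS_TRITON, reason="triton is not available")
def test_needs_triton():
    import triton.language as tl    # only runs when triton is importable
    assert tl is not None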


@@ -61,7 +61,7 @@ class GeminiManager:
self._comp_cuda_demand_time = 0
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
""" Adjust the layout of statefuil tensor according to the information provided
""" Adjust the layout of stateful tensors according to the information provided
by mem_stats_collector, which should belong to a Sharded Model.
"""
# find stateful tensor in state COMPUTE


@@ -5,20 +5,24 @@ This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
"""
import torch
import subprocess
import os
import subprocess
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
raise ImportError('please install triton from https://github.com/openai/triton')
print('please install triton from https://github.com/openai/triton')
HAS_TRITON = False
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
HAS_FLASH_ATTN = True
except ImportError:
raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')
HAS_FLASH_ATTN = False
print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
def triton_check():
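
The body of triton_check() is elided by this diff; given the os/subprocess imports above and the "Triton kernel requires CUDA 11.4+!" error raised further down, it evidently verifies that the CUDA toolkit is new enough. A simplified sketch of such a check, using torch.version.cuda instead of shelling out to nvcc (an assumption for illustration, not the committed implementation):

import torch

def cuda_is_11_4_or_newer() -> bool:
    # torch.version.cuda is a string such as "11.3" or "11.7", or None on CPU-only builds.
    cuda_version = torch.version.cuda
    if cuda_version is None:
        return False
    major, minor = (int(part) for part in cuda_version.split('.')[:2])
    return (major, minor) >= (11, 4)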
@@ -33,22 +37,44 @@ def triton_check():
return True
return False
TRITON_AVALIABLE = triton_check()
if TRITON_AVALIABLE:
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
TMP,
L,
M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
Z,
H,
N_CTX,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
@@ -117,13 +143,16 @@ def _fwd_kernel(
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
@triton.jit
def _bwd_preprocess(
Out,
DO,
L,
NewDO,
Delta,
BLOCK_M: tl.constexpr,
D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
@@ -137,21 +166,40 @@ def _bwd_preprocess(
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
@triton.jit
def _bwd_kernel(
Q,
K,
V,
sm_scale,
Out,
DO,
DQ,
DK,
DV,
L,
M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
Z,
H,
N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
@@ -221,8 +269,7 @@ def _bwd_kernel(
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
class _TritonFlashAttention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
@@ -239,16 +286,37 @@ class _TritonFlashAttention(torch.autograd.Function):
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q, k, v, sm_scale,
tmp, L, m,
q,
k,
v,
sm_scale,
tmp,
L,
m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
o.stride(0),
o.stride(1),
o.stride(2),
o.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
BLOCK_M=BLOCK,
BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk,
num_warps=num_warps,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
@@ -267,33 +335,56 @@ class _TritonFlashAttention(torch.autograd.Function):
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
o,
do,
l,
do_scaled,
delta,
BLOCK_M=ctx.BLOCK,
D_HEAD=ctx.BLOCK_DMODEL,
)
# NOTE: kernel currently buggy for other values of `num_warps`
num_warps = 8
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
q,
k,
v,
ctx.sm_scale,
o,
do_scaled,
dq,
dk,
dv,
l,
m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
BLOCK_M=ctx.BLOCK,
BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
num_warps=num_warps,
num_stages=1,
)
return dq, dk, dv, None
def triton_flash_attention(q, k, v, sm_scale):
"""
Arguments:
q: (batch, nheads, seq, headdim)
@@ -309,7 +400,9 @@ def triton_flash_attention(q, k, v, sm_scale):
raise RuntimeError("Triton kernel requires CUDA 11.4+!")
def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal=True):
if HAS_FLASH_ATTN:
def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal=True):
"""
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -327,5 +420,13 @@ def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal
lengths = torch.full((batch_size,), fill_value=seq_len, device=q.device)
cu_seqlens = torch.zeros((batch_size + 1,), device=q.device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
return flash_attn_unpadded_func(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=seq_len, max_seqlen_k=seq_len,
dropout_p=dropout_p, softmax_scale=sm_scale, causal=causal)
return flash_attn_unpadded_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
dropout_p=dropout_p,
softmax_scale=sm_scale,
causal=causal)
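
A usage sketch for the two wrappers above, assuming a CUDA device and that the optional packages are installed; tensor shapes follow the docstrings, and sm_scale is the usual 1/sqrt(headdim):

import torch

from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, TRITON_AVALIABLE

batch, nheads, seq, headdim = 3, 2, 16, 8
sm_scale = headdim**-0.5

if TRITON_AVALIABLE and torch.cuda.is_available():
    from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
    # Triton path takes (batch, nheads, seq, headdim) tensors.
    q, k, v = (torch.randn(batch, nheads, seq, headdim, device='cuda', dtype=torch.float16)
               for _ in range(3))
    out = triton_flash_attention(q, k, v, sm_scale)    # -> (batch, nheads, seq, headdim)

if HAS_FLASH_ATTN and torch.cuda.is_available():
    from colossalai.kernel.cuda_native.flash_attention import flash_attention
    # flash_attn path takes packed (total_q, nheads, headdim) tensors, total_q = batch * seq.
    q, k, v = (torch.randn(batch * seq, nheads, headdim, device='cuda', dtype=torch.float16)
               for _ in range(3))
    out = flash_attention(q, k, v, sm_scale, batch, seq)    # -> (total_q, nheads, headdim)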


@@ -1,7 +1,14 @@
import torch
import pytest
import torch
from einops import rearrange
from colossalai.kernel.cuda_native.flash_attention import flash_attention, triton_flash_attention, TRITON_AVALIABLE
from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON, TRITON_AVALIABLE
if HAS_FLASH_ATTN:
from colossalai.kernel.cuda_native.flash_attention import flash_attention
if HAS_TRITON:
from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
@@ -15,6 +22,7 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
return ref_out
@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
@@ -51,6 +59,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
raise TypeError("Error type not match!")
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash_attn is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
@@ -73,7 +82,8 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
tri_out.backward(dout, retain_graph=True)
tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), (tri_out, tri_dq, tri_dk, tri_dv))
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
(tri_out, tri_dq, tri_dk, tri_dv))
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-3)
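
The body of baseline_attention is not shown in this diff; the tests compare the fused kernels against a plain PyTorch attention reference. A hypothetical stand-alone equivalent, named naive_attention here to make clear it is an illustration rather than the repository's helper:

import torch

def naive_attention(q, k, v, sm_scale, causal=True):
    # q, k, v: (batch, nheads, seq, headdim); plain softmax(q @ k^T * scale) @ v, computed in fp32.
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
    if causal:
        seq = q.shape[-2]
        mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool, device=q.device))
        scores = scores.masked_fill(~mask, float('-inf'))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v.float()).to(q.dtype)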