[kernel] more flexible flashatt interface (#1804)

Author: oahzxl
Date: 2022-11-07 17:02:09 +08:00 (committed by GitHub)
Parent: 20e255d4e8
Commit: 9639ea88fc
2 changed files with 121 additions and 49 deletions


@@ -11,7 +11,7 @@ import subprocess
 import torch
 
-def triton_check():
+def triton_cuda_check():
     cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
     cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip()
     cuda_version = cuda_version.split('release ')[1]
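The hunk above ends mid-function; for orientation, here is a minimal sketch of what the renamed check plausibly does end to end, assuming nvcc prints a line such as "Cuda compilation tools, release 11.6, V11.6.124" and that 11.4 is the minimum CUDA version the triton kernels need (the exact comparison in the real file may differ):

import os
import subprocess

def triton_cuda_check():
    # Find nvcc under CUDA_HOME and read the release it reports.
    cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
    nvcc = os.path.join(cuda_home, "bin/nvcc")
    version_text = subprocess.check_output([nvcc, "--version"]).decode().strip()
    release = version_text.split("release ")[1].split(",")[0]    # e.g. "11.6"
    major, minor = (int(part) for part in release.split("."))
    # The surrounding module only enables triton when CUDA >= 11.4.
    return major > 11 or (major == 11 and minor >= 4)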
@@ -27,7 +27,7 @@ def triton_check():
 try:
     import triton
     import triton.language as tl
-    if triton_check():
+    if triton_cuda_check():
         HAS_TRITON = True
     else:
         print("triton requires cuda >= 11.4")
@@ -36,7 +36,11 @@ except ImportError:
     print('please install triton from https://github.com/openai/triton')
    HAS_TRITON = False
 try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_func,
+        flash_attn_unpadded_kvpacked_func,
+        flash_attn_unpadded_qkvpacked_func,
+    )
     HAS_FLASH_ATTN = True
 except ImportError:
     HAS_FLASH_ATTN = False
@@ -405,12 +409,63 @@ if HAS_TRITON:
 if HAS_FLASH_ATTN:
-    def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal=True):
+    def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
         """
         Arguments:
-            q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
-            k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
-            v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
+            qkv: (batch * seqlen, 3, nheads, headdim)
+            batch_size: int.
+            seq_len: int.
+            sm_scale: float. The scaling of QK^T before applying softmax.
+                Default to 1 / sqrt(headdim).
+            dropout_p: float.
+            causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        Return:
+            out: (total, nheads, headdim).
+        """
+        max_s = seq_len
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
+                                  device=qkv.device)
+        out = flash_attn_unpadded_qkvpacked_func(
+            qkv, cu_seqlens, max_s, dropout_p,
+            softmax_scale=sm_scale, causal=causal
+        )
+        return out
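A minimal usage sketch of the new packed-QKV entry point, assuming flash-attn is installed (so HAS_FLASH_ATTN is True), a CUDA device is available, and half-precision inputs; the sizes are made up for illustration:

import torch

batch_size, seq_len, nheads, headdim = 4, 128, 8, 64

# flash_attention_qkv expects queries, keys and values packed into one tensor
# of shape (batch * seqlen, 3, nheads, headdim).
qkv = torch.randn(batch_size * seq_len, 3, nheads, headdim, dtype=torch.float16, device="cuda")

out = flash_attention_qkv(qkv, sm_scale=headdim**-0.5, batch_size=batch_size, seq_len=seq_len)
# out: (batch * seqlen, nheads, headdim)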
+    def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
+        """
+        Arguments:
+            q: (batch * q_seqlen, nheads, headdim)
+            kv: (batch * kv_seqlen, 2, nheads, headdim)
+            batch_size: int.
+            q_seqlen: int.
+            kv_seqlen: int.
+            sm_scale: float. The scaling of QK^T before applying softmax.
+                Default to 1 / sqrt(headdim).
+            dropout_p: float.
+            causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        Return:
+            out: (total, nheads, headdim).
+        """
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
+        cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
+        out = flash_attn_unpadded_kvpacked_func(q,
+                                                kv,
+                                                cu_seqlens_q,
+                                                cu_seqlens_k,
+                                                q_seqlen,
+                                                kv_seqlen,
+                                                dropout_p,
+                                                sm_scale,
+                                                causal)
+        return out
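Likewise, a sketch of the KV-packed variant for the cross-attention-style case where queries and keys/values have different lengths, under the same assumptions and with illustrative sizes:

import torch

batch_size, q_seqlen, kv_seqlen, nheads, headdim = 4, 128, 256, 8, 64

q = torch.randn(batch_size * q_seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
# Keys and values packed together: (batch * kv_seqlen, 2, nheads, headdim).
kv = torch.randn(batch_size * kv_seqlen, 2, nheads, headdim, dtype=torch.float16, device="cuda")

out = flash_attention_q_kv(q, kv, sm_scale=headdim**-0.5, batch_size=batch_size,
                           q_seqlen=q_seqlen, kv_seqlen=kv_seqlen)
# out: (batch * q_seqlen, nheads, headdim)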
+    def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
+        """
+        Arguments:
+            q: (batch * q_seqlen, nheads, headdim)
+            k: (batch * kv_seqlen, nheads, headdim)
+            v: (batch * kv_seqlen, nheads, headdim)
+            batch_size: int.
+            q_seqlen: int.
+            kv_seqlen: int.
+            dropout_p: float. Dropout probability.
@@ -420,16 +475,15 @@ if HAS_FLASH_ATTN:
         Return:
             out: (total, nheads, headdim).
         """
-        lengths = torch.full((batch_size,), fill_value=seq_len, device=q.device)
-        cu_seqlens = torch.zeros((batch_size + 1,), device=q.device, dtype=torch.int32)
-        cu_seqlens[1:] = lengths.cumsum(0)
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
+        cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=k.device)
         return flash_attn_unpadded_func(q,
                                         k,
                                         v,
-                                        cu_seqlens_q=cu_seqlens,
-                                        cu_seqlens_k=cu_seqlens,
-                                        max_seqlen_q=seq_len,
-                                        max_seqlen_k=seq_len,
-                                        dropout_p=dropout_p,
-                                        softmax_scale=sm_scale,
-                                        causal=causal)
+                                        cu_seqlens_q,
+                                        cu_seqlens_kv,
+                                        q_seqlen,
+                                        kv_seqlen,
+                                        dropout_p,
+                                        sm_scale,
+                                        causal)
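Finally, a sketch of the fully unpacked variant under the same assumptions (flash-attn installed, CUDA device, fp16 inputs, illustrative sizes), together with a small check of why the arange-based cu_seqlens can replace the old full/cumsum construction: for fixed-length batches both yield the boundaries [0, L, 2L, ..., B * L].

import torch

batch_size, q_seqlen, kv_seqlen, nheads, headdim = 4, 128, 128, 8, 64

q = torch.randn(batch_size * q_seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch_size * kv_seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch_size * kv_seqlen, nheads, headdim, dtype=torch.float16, device="cuda")

out = flash_attention_q_k_v(q, k, v, sm_scale=headdim**-0.5, batch_size=batch_size,
                            q_seqlen=q_seqlen, kv_seqlen=kv_seqlen, causal=True)

# The new cumulative sequence-length boundaries equal the old ones when every
# sequence in the batch has the same length:
B, L = batch_size, q_seqlen
new_cu = torch.arange(0, (B + 1) * L, step=L, dtype=torch.int32)
old_cu = torch.zeros(B + 1, dtype=torch.int32)
old_cu[1:] = torch.full((B,), L).cumsum(0)
assert torch.equal(new_cu, old_cu)    # both are [0, L, 2L, ..., B * L]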