[Kernel/Fix] Revise flash attention triton kernel API and add benchmark (#5301)

* fix decoding kernel pytest

* revise and add triton context attn benchmark
This commit is contained in:
Yuanheng Zhao
2024-01-23 17:16:02 +08:00
committed by GitHub
parent 8e606ecc7e
commit 3da9993b0d
5 changed files with 116 additions and 15 deletions

View File

@@ -5,6 +5,8 @@
#
# Inspired and modified from Triton Tutorial - Fused Attention
# https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
from typing import Optional
import torch
import triton
import triton.language as tl
@@ -190,13 +192,8 @@ def context_attention_unpadded(
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence],
block_size: int,
max_seq_len_in_b: Optional[int] = None,
):
# q/k in context stage are supposed to be put into k_cache and v_cache.
# This step can be optimized in future.
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk == Lv
assert Lk in {32, 64, 128, 256}
@@ -210,7 +207,7 @@ def context_attention_unpadded(
num_kv_group = num_heads // num_kv_heads
num_seqs, max_blocks_per_seq = block_tables.shape
max_seq_len = context_lengths.max().item()
max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b
sm_scale = 1.0 / (Lq**0.5)
output = torch.zeros_like(q)
@@ -220,7 +217,7 @@ def context_attention_unpadded(
assert block_size in {16, 32, 64, 128}
BLOCK_M = BLOCK_N = block_size
grid = (num_seqs, num_heads, triton.cdiv(max_seq_len, BLOCK_M))
grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))
_fwd_context_paged_attention_kernel[grid](
q,

View File

@@ -215,10 +215,9 @@ def flash_decoding_attention(
Returns:
Output tensor with shape [bsz, num_heads, q_len, head_dim]
"""
if q.dim() == 3:
bsz, num_heads, head_dim = q.shape
else:
raise ValueError(f"The query dim should be 3, but got {q.dim()}.")
q = q.squeeze() if q.dim() == 4 else q
assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
bsz, num_heads, head_dim = q.shape
assert head_dim in {32, 64, 128, 256}
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (