Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 10:06:44 +00:00)
[Kernel/Fix] Revise flash attention triton kernel API and add benchmark (#5301)
* fix decoding kernel pytest
* revise and add triton context attn benchmark
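The benchmark file this commit adds is not shown in the excerpt below. As a rough illustration of how a Triton-style timing harness for context attention is usually put together, here is a minimal sketch built around `triton.testing.do_bench`; the function name, tensor shapes, and the use of torch SDPA as a stand-in workload are assumptions for illustration only, not the PR's actual benchmark (which exercises `context_attention_unpadded` on a paged KV cache).

```python
# Illustrative only: a minimal do_bench harness. The shapes, the function name,
# and the SDPA stand-in workload are assumptions, not the benchmark added by this PR.
import torch
import triton


def bench_context_attention(bsz=4, num_heads=32, seq_len=1024, head_dim=128):
    q = torch.randn(bsz, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # The real benchmark would time context_attention_unpadded here instead.
    fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v)
    ms = triton.testing.do_bench(fn, warmup=10, rep=100)
    print(f"attention forward: {ms:.3f} ms")


if __name__ == "__main__" and torch.cuda.is_available():
    bench_context_attention()
```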
@@ -5,6 +5,8 @@
 #
 # Inspired and modified from Triton Tutorial - Fused Attention
 # https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
+from typing import Optional
+
 import torch
 import triton
 import triton.language as tl

@@ -190,13 +192,8 @@ def context_attention_unpadded(
     context_lengths: torch.Tensor,  # [num_seqs]
     block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence],
     block_size: int,
+    max_seq_len_in_b: Optional[int] = None,
 ):
-    # q/k in context stage are supposed to be put into k_cache and v_cache.
-    # This step can be optimized in future.
-    q = q.contiguous()
-    k = k.contiguous()
-    v = v.contiguous()
-
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk == Lv
     assert Lk in {32, 64, 128, 256}

@@ -210,7 +207,7 @@ def context_attention_unpadded(
     num_kv_group = num_heads // num_kv_heads

     num_seqs, max_blocks_per_seq = block_tables.shape
-    max_seq_len = context_lengths.max().item()
+    max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b
     sm_scale = 1.0 / (Lq**0.5)

     output = torch.zeros_like(q)

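The new optional `max_seq_len_in_b` lets a caller that already tracks the batch-wide maximum context length on the host pass it in, so the wrapper only pays the device-to-host sync of `context_lengths.max().item()` when nothing is supplied. A minimal sketch of that fallback pattern in isolation (the helper name `resolve_max_seq_len` is illustrative, not from the PR):

```python
from typing import Optional

import torch


def resolve_max_seq_len(context_lengths: torch.Tensor, max_seq_len_in_b: Optional[int] = None) -> int:
    # Same pattern as the diff above: .item() blocks on a device->host copy when
    # context_lengths lives on the GPU, so it is used only as a fallback.
    return context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b


lengths = torch.tensor([7, 128, 33])
print(resolve_max_seq_len(lengths))       # 128, computed from the tensor
print(resolve_max_seq_len(lengths, 256))  # 256, caller-supplied, no reduction needed
```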
@@ -220,7 +217,7 @@ def context_attention_unpadded(
     assert block_size in {16, 32, 64, 128}
     BLOCK_M = BLOCK_N = block_size

-    grid = (num_seqs, num_heads, triton.cdiv(max_seq_len, BLOCK_M))
+    grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))

     _fwd_context_paged_attention_kernel[grid](
         q,

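For reference, `triton.next_power_of_2` rounds an integer up to the next power of two and `triton.cdiv` is ceiling division, so the first grid axis is now padded up from `num_seqs` (the kernel is then expected to skip program ids at or beyond `num_seqs`). A quick standalone check of the two helpers:

```python
import triton

num_seqs, num_heads, max_seq_len, BLOCK_M = 5, 32, 1000, 64
grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))
print(grid)  # (8, 32, 16): 5 rounds up to 8, and ceil(1000 / 64) == 16
```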
@@ -215,10 +215,9 @@ def flash_decoding_attention(
     Returns:
         Output tensor with shape [bsz, num_heads, q_len, head_dim]
     """
-    if q.dim() == 3:
-        bsz, num_heads, head_dim = q.shape
-    else:
-        raise ValueError(f"The query dim should be 3, but got {q.dim()}.")
+    q = q.squeeze() if q.dim() == 4 else q
+    assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
+    bsz, num_heads, head_dim = q.shape

     assert head_dim in {32, 64, 128, 256}
     assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (

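The revised entry check accepts a 4-D query (the docstring's `[bsz, num_heads, q_len, head_dim]`, with `q_len == 1` during decoding) and squeezes it to 3-D before unpacking shapes, instead of rejecting it. A small standalone reproduction of the new handling:

```python
import torch

# 4-D decoding query: [bsz, num_heads, q_len=1, head_dim]
q = torch.randn(2, 8, 1, 64)

q = q.squeeze() if q.dim() == 4 else q  # note: a bare .squeeze() also drops a size-1 batch dim
assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
bsz, num_heads, head_dim = q.shape
print(bsz, num_heads, head_dim)  # 2 8 64
```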