[Kernel/Fix] Revise flash attention triton kernel API and add benchmark (#5301)
* fix decoding kernel pytest
* revise and add triton context attn benchmark
@@ -5,6 +5,8 @@
 #
 # Inspired and modified from Triton Tutorial - Fused Attention
 # https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
+from typing import Optional
+
 import torch
 import triton
 import triton.language as tl
@@ -190,13 +192,8 @@ def context_attention_unpadded(
     context_lengths: torch.Tensor,  # [num_seqs]
     block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence],
     block_size: int,
+    max_seq_len_in_b: Optional[int] = None,
 ):
-    # q/k in context stage are supposed to be put into k_cache and v_cache.
-    # This step can be optimized in future.
-    q = q.contiguous()
-    k = k.contiguous()
-    v = v.contiguous()
-
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk == Lv
     assert Lk in {32, 64, 128, 256}
@@ -210,7 +207,7 @@ def context_attention_unpadded(
     num_kv_group = num_heads // num_kv_heads
 
     num_seqs, max_blocks_per_seq = block_tables.shape
-    max_seq_len = context_lengths.max().item()
+    max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b
     sm_scale = 1.0 / (Lq**0.5)
 
     output = torch.zeros_like(q)
@@ -220,7 +217,7 @@ def context_attention_unpadded(
     assert block_size in {16, 32, 64, 128}
     BLOCK_M = BLOCK_N = block_size
 
-    grid = (num_seqs, num_heads, triton.cdiv(max_seq_len, BLOCK_M))
+    grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))
 
     _fwd_context_paged_attention_kernel[grid](
         q,
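Taken together, the hunks above make two functional changes to context_attention_unpadded: a new optional max_seq_len_in_b argument lets a caller that already knows the batch's maximum context length skip the context_lengths.max().item() reduction (and the host-device sync that .item() implies), and the sequence axis of the launch grid is now rounded up with triton.next_power_of_2. Below is a minimal sketch of the revised launch-parameter logic, built only from what the diff shows; the helper name revised_launch_params and the standalone framing are illustrative, not part of the library.

from typing import Optional

import torch
import triton


def revised_launch_params(
    context_lengths: torch.Tensor,  # [num_seqs], context length per sequence
    block_size: int,                # KV-cache block size, one of {16, 32, 64, 128}
    num_heads: int,
    max_seq_len_in_b: Optional[int] = None,
):
    # Illustrative only: mirrors how the revised wrapper derives max_seq_len
    # and the kernel grid after this commit.
    num_seqs = context_lengths.shape[0]
    # If the caller supplies the batch's max sequence length, the device
    # reduction and the .item() host sync are skipped.
    max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b
    BLOCK_M = block_size
    # The first grid axis is now padded to the next power of two.
    grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))
    return grid


# Example with hypothetical values: 5 sequences, 32 heads, block size 16.
# lengths = torch.tensor([12, 7, 30, 30, 5], device="cuda")
# grid = revised_launch_params(lengths, block_size=16, num_heads=32, max_seq_len_in_b=30)
# -> (8, 32, 2): 8 = next_power_of_2(5), 2 = cdiv(30, 16)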