[kernel] Add flash decoding triton kernel for blocked kv cache (#5249)

* add flash decoding unpad triton kernel

* rename flash decoding kernel

* add kernel testing (draft)

* revise pytest

* support kv group (GQA)

* (trivial) fix api and pytest

* (trivial) func renaming

* (trivial) func/file renaming

* refactor pytest for attention

* (trivial) format and consistent vars of context/decode attn

* (trivial) remove test redundancy

Author: Yuanheng Zhao
Date: 2024-01-11 18:06:39 +08:00
Committed by: FrankLeeeee
Parent: fded91d049
Commit: 1513f20f4d
6 changed files with 576 additions and 153 deletions
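Note: the kernel touched below writes prefill-stage K/V into a blocked ("paged") cache with layout [num_blocks, num_kv_heads, head_dim, block_size], per the shape comments in the diff. The following is a minimal, reference-only PyTorch sketch of that layout; the helper name and loop structure are illustrative and not part of this commit.

import torch

def copy_k_to_blocked_cache_sketch(k, k_cache, context_lengths, block_tables, block_size):
    """Illustrative only: mirrors the cache layout the Triton kernel fills.

    k:               [num_tokens, num_kv_heads, head_dim]   (unpadded; sequences concatenated)
    k_cache:         [num_blocks, num_kv_heads, head_dim, block_size]
    context_lengths: [num_seqs]
    block_tables:    [num_seqs, max_blocks_per_sequence]
    """
    start = 0
    for seq_idx, seq_len in enumerate(context_lengths.tolist()):
        seq_k = k[start : start + seq_len]  # this sequence's tokens
        num_blocks_used = (seq_len + block_size - 1) // block_size
        for i in range(num_blocks_used):
            block_id = int(block_tables[seq_idx, i])
            tokens = seq_k[i * block_size : (i + 1) * block_size]  # [n, num_kv_heads, head_dim]
            # tokens are stored along the last cache dim: [num_kv_heads, head_dim, n]
            k_cache[block_id, :, :, : tokens.shape[0]] = tokens.permute(1, 2, 0)
        start += seq_len
    return k_cache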

@@ -42,7 +42,7 @@ def _fwd_context_paged_attention_kernel(
     sm_scale,
     KV_GROUPS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -66,38 +66,38 @@ def _fwd_context_paged_attention_kernel(
     for i in range(0, cur_seq_idx):
         prev_seq_len_sum += tl.load(context_lengths + i)
-    q_offset = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
-    kv_offset = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
+    offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
+    offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
     Q_block_ptr = tl.make_block_ptr(
-        base=Q + q_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=Q + offset_q,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_qt, stride_qd),
         offsets=(block_start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        block_shape=(BLOCK_M, HEAD_DIM),
         order=(1, 0),
     )
     K_block_ptr = tl.make_block_ptr(
-        base=K + kv_offset,
-        shape=(BLOCK_DMODEL, cur_seq_len),
+        base=K + offset_kv,
+        shape=(HEAD_DIM, cur_seq_len),
         strides=(stride_kd, stride_kt),
         offsets=(0, 0),
-        block_shape=(BLOCK_DMODEL, BLOCK_N),
+        block_shape=(HEAD_DIM, BLOCK_N),
         order=(0, 1),
     )
     V_block_ptr = tl.make_block_ptr(
-        base=V + kv_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=V + offset_kv,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_vt, stride_vd),
         offsets=(0, 0),
-        block_shape=(BLOCK_N, BLOCK_DMODEL),
+        block_shape=(BLOCK_N, HEAD_DIM),
         order=(1, 0),
     )
     O_block_ptr = tl.make_block_ptr(
-        base=O + q_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=O + offset_q,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_ot, stride_od),
         offsets=(block_start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        block_shape=(BLOCK_M, HEAD_DIM),
         order=(1, 0),
     )
@@ -108,13 +108,13 @@ def _fwd_context_paged_attention_kernel(
     # as we have BLOCK_M the same size as the block size.
     cur_block_table_idx = block_start_m
     cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb)
-    kvcache_offset = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
+    offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
     offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offsets_n = tl.arange(0, BLOCK_N)
     m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
     if block_start_m * BLOCK_M >= cur_seq_len:
         return
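The KV_GROUPS constexpr used in the hunk below (`if cur_head_idx % KV_GROUPS == 0`) ties each query head to a shared KV head for grouped-query attention (the "support kv group (GQA)" item in the commit message). A minimal sketch of the assumed head-to-KV-head mapping; the variable names here are illustrative, not from the kernel:

# Illustrative only: assumed GQA mapping matching KV_GROUPS in the kernel.
num_heads, num_kv_heads = 16, 4
kv_groups = num_heads // num_kv_heads          # passed to the kernel as KV_GROUPS

for head_idx in range(num_heads):
    kv_head_idx = head_idx // kv_groups        # KV head shared by this query head's group
    copies_kv = (head_idx % kv_groups == 0)    # only one head per group writes K/V to the cache
    print(head_idx, kv_head_idx, copies_kv)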
@@ -152,43 +152,41 @@ def _fwd_context_paged_attention_kernel(
     if cur_head_idx % KV_GROUPS == 0:
         # Copy k to corresponding cache block
-        kd_offsets = tl.arange(0, BLOCK_DMODEL)
-        kt_offsets = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        k_offsets = K + kv_offset + kd_offsets[:, None] * stride_kd + kt_offsets[None, :] * stride_kt
-        k = tl.load(k_offsets, mask=kt_offsets[None, :] < cur_seq_len, other=0.0)
-        kcached_offsets = tl.arange(0, BLOCK_DMODEL)
-        kcachebs_offsets = tl.arange(0, BLOCK_SIZE)
-        kcache_offsets = (
+        offsets_dmodel = tl.arange(0, HEAD_DIM)
+        offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
+        k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
+        offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
+        offsets_kcache = (
             KCache
-            + kvcache_offset
-            + kcached_offsets[:, None] * stride_cached
-            + kcachebs_offsets[None, :] * stride_cachebs
+            + offset_kvcache
+            + offsets_dmodel[:, None] * stride_cached
+            + offsets_kcachebs[None, :] * stride_cachebs
         )
-        tl.store(kcache_offsets, k, mask=kcachebs_offsets[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
+        tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
         # Copy v to corresponding cache block
-        vd_offsets = kd_offsets
-        vt_offsets = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
-        v_offsets = V + kv_offset + vt_offsets[:, None] * stride_vt + vd_offsets[None, :] * stride_vd
-        v = tl.load(v_offsets, mask=vt_offsets[:, None] < cur_seq_len, other=0.0)
-        vcached_offsets = kcached_offsets
-        vcachebs_offsets = kcachebs_offsets
-        vcache_offsets = (
+        offsets_vd = offsets_dmodel
+        offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
+        offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
+        v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
+        offsets_vcachebs = offsets_kcachebs  # same block size range, just to notify here
+        offsets_vcache = (
             VCache
-            + kvcache_offset
-            + vcachebs_offsets[:, None] * stride_cachebs
-            + vcached_offsets[None, :] * stride_cached
+            + offset_kvcache
+            + offsets_vcachebs[:, None] * stride_cachebs
+            + offsets_dmodel[None, :] * stride_cached
         )
-        tl.store(vcache_offsets, v, mask=vcachebs_offsets[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
+        tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)

     return


 def context_attention_unpadded(
-    q: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
-    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
-    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
-    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
+    q: torch.Tensor,  # [num_tokens, num_heads, head_dim]
+    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
+    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
+    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
+    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
     context_lengths: torch.Tensor,  # [num_seqs]
     block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence],
     block_size: int,
@@ -254,7 +252,7 @@ def context_attention_unpadded(
         sm_scale,
         num_kv_group,
         block_size,
-        BLOCK_DMODEL=Lk,
+        HEAD_DIM=Lk,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
     )
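For reference, a hedged sketch of how the updated `context_attention_unpadded` API might be invoked, with tensor shapes taken from the signature comments above. The import path, dtypes, sizes, and the assumption that no required arguments follow `block_size` are not confirmed by this diff; check the full file before relying on them.

import torch
# Import path is an assumption based on the repo layout at the time of this commit.
from colossalai.kernel.triton import context_attention_unpadded

num_seqs, num_heads, num_kv_heads, head_dim = 4, 16, 4, 128
block_size, max_blocks_per_seq = 32, 8
dtype, device = torch.float16, "cuda"

# Unpadded inputs: sequences are concatenated along the token dimension.
context_lengths = torch.randint(1, block_size * max_blocks_per_seq, (num_seqs,), device=device)
num_tokens = int(context_lengths.sum())
q = torch.randn(num_tokens, num_heads, head_dim, dtype=dtype, device=device)
k = torch.randn(num_tokens, num_kv_heads, head_dim, dtype=dtype, device=device)
v = torch.randn_like(k)

# Blocked KV cache plus a trivial block table giving each sequence its own blocks.
num_blocks = num_seqs * max_blocks_per_seq
k_cache = torch.zeros(num_blocks, num_kv_heads, head_dim, block_size, dtype=dtype, device=device)
v_cache = torch.zeros_like(k_cache)
block_tables = torch.arange(num_blocks, dtype=torch.int32, device=device).view(num_seqs, max_blocks_per_seq)

# Assumes no required parameters beyond block_size; the actual signature may differ.
out = context_attention_unpadded(q, k, v, k_cache, v_cache, context_lengths, block_tables, block_size)
print(out.shape)  # expected: [num_tokens, num_heads, head_dim]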