[kernel] Support New KCache Layout - Triton Kernel (#5677)

* kvmemcpy triton for new kcache layout

* revise tests for new kcache layout

* naive triton flash decoding - new kcache layout

* rotary triton kernel - new kcache layout

* remove redundancy - triton decoding

* remove redundancy - triton kvcache copy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Yuanheng Zhao
2024-05-03 17:20:45 +08:00
committed by GitHub
parent 9df016fc45
commit 537a3cbc4d
10 changed files with 428 additions and 206 deletions

View File

@@ -11,20 +11,29 @@ import triton.language as tl
def _flash_decoding_fwd_kernel(
Q, # [batch_size * q_len, head_num, head_dim]
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
VCache, # [num_blocks, num_kv_heads, block_size, head_dim],
# or [num_blocks, num_kv_heads, head_dim//x, block_size, x], depends on strides provided
block_tables, # [batch_size, max_blocks_per_sequence]
mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim]
mid_o_lse, # [batch_size * q_len, head_num, kv_split_num]
kv_seq_len, # [batch_size]
q_len,
batch_size,
kv_group_num,
x,
sm_scale,
stride_qt,
stride_qh,
stride_qd,
stride_cacheb,
stride_cacheh,
stride_cachebs,
stride_cached,
stride_kcb,
stride_kch,
stride_kcsplit_x,
stride_kcs,
stride_kcd,
stride_vcb,
stride_vch,
stride_vcs,
stride_vcd,
stride_bts,
stride_btb,
stride_mid_ot,
@@ -34,8 +43,6 @@ def _flash_decoding_fwd_kernel(
stride_mid_o_lset,
stride_mid_o_lseh,
stride_mid_o_lseb,
sm_scale,
KV_GROUPS: tl.constexpr,
BLOCK_KV: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr,
@@ -57,10 +64,9 @@ def _flash_decoding_fwd_kernel(
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
return
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
q = tl.load(Q + offsets_q)
offsets_block = tl.arange(0, BLOCK_SIZE)
# block table for the current sequence
block_table_ptr = block_tables + cur_seq_idx * stride_bts
# cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
@@ -71,25 +77,25 @@ def _flash_decoding_fwd_kernel(
)
tl.device_assert(cur_occupied_size >= 0)
cur_kv_head_idx = cur_head_idx // KV_GROUPS
offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
K_block_ptr = tl.make_block_ptr(
base=KCache + offset_kvcache,
shape=(cur_occupied_size, HEAD_DIM),
strides=(stride_cachebs, stride_cached),
offsets=(0, 0),
block_shape=(BLOCK_SIZE, HEAD_DIM),
order=(0, 1),
offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
q = tl.load(Q + offsets_q)
cur_kv_head_idx = cur_head_idx // kv_group_num
offset_kvcache = cur_block_id * stride_kcb + cur_kv_head_idx * stride_kch
offsets_k = (
offset_kvcache
+ (offsets_dmodel[None, :] // x) * stride_kcsplit_x
+ (offsets_dmodel[None, :] % x) * stride_kcd
+ offsets_block[:, None] * stride_kcs
)
k_cur_block = tl.load(KCache + offsets_k)
V_block_ptr = tl.make_block_ptr(
base=VCache + offset_kvcache,
shape=(cur_occupied_size, HEAD_DIM),
strides=(stride_cachebs, stride_cached),
strides=(stride_vcs, stride_vcd),
offsets=(0, 0),
block_shape=(BLOCK_SIZE, HEAD_DIM),
order=(0, 1),
)
k_cur_block = tl.load(K_block_ptr)
v_cur_block = tl.load(V_block_ptr)
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
# use block size of the paged/blocked kv cache
@@ -100,7 +106,7 @@ def _flash_decoding_fwd_kernel(
# Refer to https://github.com/openai/triton/discussions/895
S_ij += tl.sum(q[None, :] * k_cur_block, 1)
S_ij *= sm_scale
S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf"))
S_ij += tl.where(block_start_kv * BLOCK_KV + offsets_block < cur_kv_seq_len, 0, float("-inf"))
m = tl.max(S_ij, 0)
S_ij -= m
@@ -324,6 +330,7 @@ def flash_decoding_attention(
sm_scale: int = None,
kv_group_num: int = 1,
q_len: int = 1, # NOTE alibi flash decoding does not support q_len > 1 at this moment.
use_new_kcache_layout: bool = False,
):
"""
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
@@ -349,6 +356,7 @@ def flash_decoding_attention(
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens).
Defaults to 1.
use_new_kcache_layout (bool): Whether to use the new kcache layout. Defaults to False.
Returns:
Output tensor with shape [bsz * q_len, num_heads * head_dim]
@@ -400,13 +408,20 @@ def flash_decoding_attention(
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
grid = (
grid = lambda META: (
triton.next_power_of_2(bsz * q_len),
num_heads,
triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), META["BLOCK_KV"]),
)
if alibi_slopes is not None:
# TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
# the code (alibi kernel) will be refactored later to avoid code duplication, when
# the whole triton flow with new k cache layout has been supported and tested.
assert (
not use_new_kcache_layout
), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready"
_alibi_flash_decoding_fwd_kernel[grid](
q,
k_cache,
@@ -441,6 +456,19 @@ def flash_decoding_attention(
HEAD_DIM=head_dim,
)
else:
# For KCache and VCache with the same layout
x = head_dim
kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3)
# For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x]
if use_new_kcache_layout:
assert (
k_cache.dim() == 5
and k_cache.shape[1] == v_cache.shape[1]
and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3]
), f"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}"
x = k_cache.size(-1)
kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:]
_flash_decoding_fwd_kernel[grid](
q,
k_cache,
@@ -451,13 +479,21 @@ def flash_decoding_attention(
kv_seq_len,
q_len,
bsz,
kv_group_num,
x,
sm_scale,
q.stride(0),
q.stride(1),
q.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
kcsplit_x_stride,
kcs_stride,
kcd_stride,
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
mid_output.stride(0),
@@ -467,8 +503,6 @@ def flash_decoding_attention(
mid_output_lse.stride(0),
mid_output_lse.stride(1),
mid_output_lse.stride(2),
sm_scale,
KV_GROUPS=kv_group_num,
BLOCK_KV=block_size,
BLOCK_SIZE=block_size,
HEAD_DIM=head_dim,