Mirror of https://github.com/hpcaitech/ColossalAI.git
[kernel] Support New KCache Layout - Triton Kernel (#5677)
* kvmemcpy triton for new kcache layout
* revise tests for new kcache layout
* naive triton flash decoding - new kcache layout
* rotary triton kernel - new kcache layout
* remove redundancy - triton decoding
* remove redundancy - triton kvcache copy
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -11,20 +11,29 @@ import triton.language as tl
 def _flash_decoding_fwd_kernel(
     Q,  # [batch_size * q_len, head_num, head_dim]
     KCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
-    VCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
+    VCache,  # [num_blocks, num_kv_heads, block_size, head_dim],
+    # or [num_blocks, num_kv_heads, head_dim//x, block_size, x], depends on strides provided
     block_tables,  # [batch_size, max_blocks_per_sequence]
     mid_o,  # [batch_size * q_len, head_num, kv_split_num, head_dim]
     mid_o_lse,  # [batch_size * q_len, head_num, kv_split_num]
     kv_seq_len,  # [batch_size]
     q_len,
     batch_size,
+    kv_group_num,
+    x,
+    sm_scale,
     stride_qt,
     stride_qh,
     stride_qd,
-    stride_cacheb,
-    stride_cacheh,
-    stride_cachebs,
-    stride_cached,
+    stride_kcb,
+    stride_kch,
+    stride_kcsplit_x,
+    stride_kcs,
+    stride_kcd,
+    stride_vcb,
+    stride_vch,
+    stride_vcs,
+    stride_vcd,
     stride_bts,
     stride_btb,
     stride_mid_ot,
@@ -34,8 +43,6 @@ def _flash_decoding_fwd_kernel(
     stride_mid_o_lset,
     stride_mid_o_lseh,
     stride_mid_o_lseb,
-    sm_scale,
-    KV_GROUPS: tl.constexpr,
     BLOCK_KV: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HEAD_DIM: tl.constexpr,
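For reference, the two KCache layouts the kernel can now address, as a minimal PyTorch sketch (not part of the commit; the sizes and the choice of x below are illustrative only):

import torch

# Illustrative sizes; the commit fixes no particular values.
num_blocks, num_kv_heads, block_size, head_dim, x = 16, 4, 32, 128, 8

# Original layout, shared with VCache: [num_blocks, num_kv_heads, block_size, head_dim]
k_cache_old = torch.empty(num_blocks, num_kv_heads, block_size, head_dim)

# New layout: [num_blocks, num_kv_heads, head_dim // x, block_size, x]
k_cache_new = torch.empty(num_blocks, num_kv_heads, head_dim // x, block_size, x)

# The trailing strides are what reaches the kernel as
# (stride_kcsplit_x, stride_kcs, stride_kcd); (block_size * x, x, 1) when contiguous.
print(k_cache_new.stride()[-3:])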
@@ -57,10 +64,9 @@ def _flash_decoding_fwd_kernel(
     cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off
     if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
         return
 
     offsets_dmodel = tl.arange(0, HEAD_DIM)
-    offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
-    q = tl.load(Q + offsets_q)
+    offsets_block = tl.arange(0, BLOCK_SIZE)
     # block table for the current sequence
     block_table_ptr = block_tables + cur_seq_idx * stride_bts
     # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
@@ -71,25 +77,25 @@ def _flash_decoding_fwd_kernel(
     )
     tl.device_assert(cur_occupied_size >= 0)
 
-    cur_kv_head_idx = cur_head_idx // KV_GROUPS
-    offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
-    K_block_ptr = tl.make_block_ptr(
-        base=KCache + offset_kvcache,
-        shape=(cur_occupied_size, HEAD_DIM),
-        strides=(stride_cachebs, stride_cached),
-        offsets=(0, 0),
-        block_shape=(BLOCK_SIZE, HEAD_DIM),
-        order=(0, 1),
-    )
+    offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
+    q = tl.load(Q + offsets_q)
+    cur_kv_head_idx = cur_head_idx // kv_group_num
+    offset_kvcache = cur_block_id * stride_kcb + cur_kv_head_idx * stride_kch
+    offsets_k = (
+        offset_kvcache
+        + (offsets_dmodel[None, :] // x) * stride_kcsplit_x
+        + (offsets_dmodel[None, :] % x) * stride_kcd
+        + offsets_block[:, None] * stride_kcs
+    )
+    k_cur_block = tl.load(KCache + offsets_k)
     V_block_ptr = tl.make_block_ptr(
         base=VCache + offset_kvcache,
         shape=(cur_occupied_size, HEAD_DIM),
-        strides=(stride_cachebs, stride_cached),
+        strides=(stride_vcs, stride_vcd),
         offsets=(0, 0),
         block_shape=(BLOCK_SIZE, HEAD_DIM),
         order=(0, 1),
     )
-    k_cur_block = tl.load(K_block_ptr)
     v_cur_block = tl.load(V_block_ptr)
     acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
     # use block size of the paged/blocked kv cache
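The offsets_k arithmetic above can be sanity-checked outside Triton. In a minimal sketch (assuming one contiguous cache block and an illustrative x), dimension d of token s lives at (d // x, s, d % x) in the split layout, so the three stride terms recover the plain [block_size, head_dim] tile:

import torch

head_dim, block_size, x = 128, 32, 8
k_std = torch.randn(block_size, head_dim)  # one cache block, standard layout
# Re-pack into the new layout [head_dim // x, block_size, x]
k_new = k_std.view(block_size, head_dim // x, x).permute(1, 0, 2).contiguous()

split_x_stride, s_stride, d_stride = k_new.stride()  # (block_size * x, x, 1)
d = torch.arange(head_dim)
s = torch.arange(block_size)
# Same formula as offsets_k in the kernel, minus the block/head base offset
offsets = (d[None, :] // x) * split_x_stride + (d[None, :] % x) * d_stride + s[:, None] * s_stride
assert torch.equal(k_new.reshape(-1)[offsets], k_std)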
@@ -100,7 +106,7 @@ def _flash_decoding_fwd_kernel(
     # Refer to https://github.com/openai/triton/discussions/895
     S_ij += tl.sum(q[None, :] * k_cur_block, 1)
     S_ij *= sm_scale
-    S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf"))
+    S_ij += tl.where(block_start_kv * BLOCK_KV + offsets_block < cur_kv_seq_len, 0, float("-inf"))
 
     m = tl.max(S_ij, 0)
     S_ij -= m
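This masked, max-shifted block computes one flash-decoding partial. As a plain-PyTorch sketch of the same math for a single (token, head) pair (the function and its name are ours, not the kernel's):

import torch

def split_kv_partial(q, k_block, v_block, sm_scale, valid_len):
    # q: [head_dim]; k_block, v_block: [block_size, head_dim]
    s = (q[None, :] * k_block).sum(1) * sm_scale  # S_ij, scaled
    s[valid_len:] = float("-inf")                 # mask slots past cur_kv_seq_len
    m = s.max()                                   # the kernel's m
    p = torch.exp(s - m)                          # shifted exponentials
    acc = (p[:, None] * v_block).sum(0)           # un-normalized partial output
    l = p.sum()
    # one mid_o entry and its log-sum-exp, later reduced across kv splits
    return acc / l, m + torch.log(l)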
@@ -324,6 +330,7 @@ def flash_decoding_attention(
     sm_scale: int = None,
     kv_group_num: int = 1,
     q_len: int = 1,  # NOTE alibi flash decoding does not support q_len > 1 at this moment.
+    use_new_kcache_layout: bool = False,
 ):
     """
     Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
@@ -349,6 +356,7 @@ def flash_decoding_attention(
         num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
         q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens).
             Defaults to 1.
+        use_new_kcache_layout (bool): Whether to use the new kcache layout. Defaults to False.
 
     Returns:
         Output tensor with shape [bsz * q_len, num_heads * head_dim]
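A hedged usage sketch of the extended entry point, with keyword arguments per the signature above; every shape and value here is made up for illustration, and the optional workspace arguments are left at their defaults:

import torch
from colossalai.kernel.triton import flash_decoding_attention

bsz, num_heads, num_kv_heads, head_dim = 4, 8, 8, 128
block_size, max_blocks_per_seq, num_blocks, x = 32, 8, 64, 8

q = torch.randn(bsz, num_heads, head_dim, dtype=torch.float16, device="cuda")
# KCache in the new layout: [num_blocks, num_kv_heads, head_dim // x, block_size, x]
k_cache = torch.randn(num_blocks, num_kv_heads, head_dim // x, block_size, x, dtype=torch.float16, device="cuda")
# VCache keeps the original layout
v_cache = torch.randn(num_blocks, num_kv_heads, block_size, head_dim, dtype=torch.float16, device="cuda")
kv_seq_len = torch.randint(1, block_size * max_blocks_per_seq, (bsz,), dtype=torch.int32, device="cuda")
block_tables = torch.arange(bsz * max_blocks_per_seq, dtype=torch.int32, device="cuda").reshape(bsz, max_blocks_per_seq)

out = flash_decoding_attention(
    q, k_cache, v_cache, kv_seq_len, block_tables, block_size,
    kv_group_num=num_heads // num_kv_heads,
    use_new_kcache_layout=True,
)  # -> [bsz * q_len, num_heads * head_dim]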
@@ -400,13 +408,20 @@ def flash_decoding_attention(
 
     # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
     # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
-    grid = (
+    grid = lambda META: (
         triton.next_power_of_2(bsz * q_len),
         num_heads,
-        triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
+        triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), META["BLOCK_KV"]),
     )
 
     if alibi_slopes is not None:
+        # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
+        # the code (alibi kernel) will be refactored later to avoid code duplication, when
+        # the whole triton flow with new k cache layout has been supported and tested.
+        assert (
+            not use_new_kcache_layout
+        ), "Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready"
+
         _alibi_flash_decoding_fwd_kernel[grid](
             q,
             k_cache,
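Switching from a fixed tuple to grid = lambda META lets the third grid dimension follow the kernel's resolved BLOCK_KV meta-parameter instead of a value fixed at the call site. A standalone sketch of the pattern (not the launch itself):

import triton

def make_grid(bsz, q_len, num_heads, max_seq_len_in_batch):
    # Triton calls this with the resolved constexpr meta-parameters, so the
    # kv-split count tracks whatever BLOCK_KV the kernel is compiled with.
    return lambda META: (
        triton.next_power_of_2(bsz * q_len),
        num_heads,
        triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), META["BLOCK_KV"]),
    )

print(make_grid(4, 1, 8, 1000)({"BLOCK_KV": 256}))  # (4, 8, 4)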
@@ -441,6 +456,19 @@ def flash_decoding_attention(
             HEAD_DIM=head_dim,
         )
     else:
+        # For KCache and VCache with the same layout
+        x = head_dim
+        kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3)
+        # For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x]
+        if use_new_kcache_layout:
+            assert (
+                k_cache.dim() == 5
+                and k_cache.shape[1] == v_cache.shape[1]
+                and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3]
+            ), f"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}"
+            x = k_cache.size(-1)
+            kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:]
+
         _flash_decoding_fwd_kernel[grid](
             q,
             k_cache,
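These fallback values let a single kernel serve both layouts: with x = head_dim and kcsplit_x_stride = 0, the split-x terms of offsets_k collapse to plain 4-d addressing, because d // head_dim == 0 and d % head_dim == d for every d < head_dim. A quick check of that identity:

head_dim = 128
x, kcsplit_x_stride, kcd_stride = head_dim, 0, 1  # fallback values from the branch above
for d in range(head_dim):
    assert (d // x) * kcsplit_x_stride + (d % x) * kcd_stride == d * kcd_stride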
@@ -451,13 +479,21 @@ def flash_decoding_attention(
             kv_seq_len,
             q_len,
             bsz,
+            kv_group_num,
+            x,
+            sm_scale,
             q.stride(0),
             q.stride(1),
             q.stride(2),
             k_cache.stride(0),
             k_cache.stride(1),
-            k_cache.stride(2),
-            k_cache.stride(3),
+            kcsplit_x_stride,
+            kcs_stride,
+            kcd_stride,
+            v_cache.stride(0),
+            v_cache.stride(1),
+            v_cache.stride(2),
+            v_cache.stride(3),
             block_tables.stride(0),
             block_tables.stride(1),
             mid_output.stride(0),
@@ -467,8 +503,6 @@ def flash_decoding_attention(
             mid_output_lse.stride(0),
             mid_output_lse.stride(1),
             mid_output_lse.stride(2),
-            sm_scale,
-            KV_GROUPS=kv_group_num,
             BLOCK_KV=block_size,
             BLOCK_SIZE=block_size,
             HEAD_DIM=head_dim,