Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
[kernel] Add flash decoding triton kernel for blocked kv cache (#5249)
* add flash decoding unpad triton kernel
* rename flash decoding kernel
* add kernel testing (draft)
* revise pytest
* support kv group (GQA)
* (trivial) fix api and pytest
* (trivial) func renaming
* (trivial) func/file renaming
* refactor pytest for attention
* (trivial) format and consistent vars of context/decode attn
* (trivial) remove test redundancy
Committed by: FrankLeeeee
Parent: fded91d049
Commit: 1513f20f4d
```diff
@@ -42,7 +42,7 @@ def _fwd_context_paged_attention_kernel(
     sm_scale,
     KV_GROUPS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
```
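Two of these constexpr arguments are worth noting: `HEAD_DIM` is the new name for what was `BLOCK_DMODEL` (the per-head embedding width), and `KV_GROUPS` carries the GQA grouping factor this PR adds. A minimal sketch of the grouping arithmetic, assuming `KV_GROUPS = num_heads // num_kv_heads` (consistent with the `num_kv_group` value passed at launch in the final hunk):

```python
# Illustrative only, not kernel code: in grouped-query attention (GQA),
# several query heads share one KV head. The assumed mapping is
# cur_kv_head_idx = cur_head_idx // KV_GROUPS.
num_heads, num_kv_heads = 8, 2
KV_GROUPS = num_heads // num_kv_heads  # here, 4 query heads per KV head
for cur_head_idx in range(num_heads):
    cur_kv_head_idx = cur_head_idx // KV_GROUPS
    print(f"query head {cur_head_idx} -> kv head {cur_kv_head_idx}")
```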
```diff
@@ -66,38 +66,38 @@ def _fwd_context_paged_attention_kernel(
     for i in range(0, cur_seq_idx):
         prev_seq_len_sum += tl.load(context_lengths + i)
 
-    q_offset = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
-    kv_offset = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
+    offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
+    offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
     Q_block_ptr = tl.make_block_ptr(
-        base=Q + q_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=Q + offset_q,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_qt, stride_qd),
         offsets=(block_start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        block_shape=(BLOCK_M, HEAD_DIM),
         order=(1, 0),
     )
     K_block_ptr = tl.make_block_ptr(
-        base=K + kv_offset,
-        shape=(BLOCK_DMODEL, cur_seq_len),
+        base=K + offset_kv,
+        shape=(HEAD_DIM, cur_seq_len),
         strides=(stride_kd, stride_kt),
         offsets=(0, 0),
-        block_shape=(BLOCK_DMODEL, BLOCK_N),
+        block_shape=(HEAD_DIM, BLOCK_N),
         order=(0, 1),
     )
     V_block_ptr = tl.make_block_ptr(
-        base=V + kv_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=V + offset_kv,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_vt, stride_vd),
         offsets=(0, 0),
-        block_shape=(BLOCK_N, BLOCK_DMODEL),
+        block_shape=(BLOCK_N, HEAD_DIM),
         order=(1, 0),
     )
     O_block_ptr = tl.make_block_ptr(
-        base=O + q_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=O + offset_q,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_ot, stride_od),
         offsets=(block_start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        block_shape=(BLOCK_M, HEAD_DIM),
         order=(1, 0),
     )
 
```
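This hunk is a pure rename (`q_offset` to `offset_q`, `BLOCK_DMODEL` to `HEAD_DIM`); the block-pointer construction itself is unchanged. For readers new to `tl.make_block_ptr`, here is a self-contained sketch (a hypothetical kernel, not from this PR) of the same tiling pattern used for Q and O, where rows are tokens and columns are head dimensions:

```python
import triton
import triton.language as tl


@triton.jit
def _tile_copy_kernel(X, Y, T, stride_t, stride_d, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr):
    # Each program handles one (BLOCK_M, HEAD_DIM) tile of a (T, HEAD_DIM)
    # tensor: the same view the attention kernel builds for Q and O.
    pid = tl.program_id(0)
    x_ptr = tl.make_block_ptr(
        base=X, shape=(T, HEAD_DIM), strides=(stride_t, stride_d),
        offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0),
    )
    tile = tl.load(x_ptr, boundary_check=(0,), padding_option="zero")  # guard rows past T
    y_ptr = tl.make_block_ptr(
        base=Y, shape=(T, HEAD_DIM), strides=(stride_t, stride_d),
        offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0),
    )
    tl.store(y_ptr, tile, boundary_check=(0,))
```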
```diff
@@ -108,13 +108,13 @@ def _fwd_context_paged_attention_kernel(
     # as we have BLOCK_M the same size as the block size.
     cur_block_table_idx = block_start_m
     cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb)
-    kvcache_offset = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
+    offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
 
     offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offsets_n = tl.arange(0, BLOCK_N)
     m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
 
     if block_start_m * BLOCK_M >= cur_seq_len:
         return
```
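`m_i`, `l_i`, and `acc` are the running maximum, softmax denominator, and weighted-V accumulator of the flash-attention online softmax; only `acc`'s shape expression changes here with the rename. A NumPy sketch of the update rule these three buffers implement (illustrative only, without masking or block-pointer details):

```python
import numpy as np


def online_softmax_attention(q, k_blocks, v_blocks, sm_scale):
    # Streaming softmax over K/V blocks, mirroring the kernel's m_i (running
    # max), l_i (running denominator), and acc (running weighted sum of V).
    num_q, head_dim = q.shape
    m_i = np.full(num_q, -np.inf)
    l_i = np.zeros(num_q)
    acc = np.zeros((num_q, head_dim))
    for k, v in zip(k_blocks, v_blocks):  # k, v: (BLOCK_N, head_dim) each
        s = (q @ k.T) * sm_scale                # scores for this block
        m_new = np.maximum(m_i, s.max(axis=1))  # updated running max
        alpha = np.exp(m_i - m_new)             # rescale factor for old stats
        p = np.exp(s - m_new[:, None])          # unnormalized block probabilities
        l_i = alpha * l_i + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v
        m_i = m_new
    return acc / l_i[:, None]                   # normalize once at the end
```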
```diff
@@ -152,43 +152,41 @@ def _fwd_context_paged_attention_kernel(
 
     if cur_head_idx % KV_GROUPS == 0:
         # Copy k to corresponding cache block
-        kd_offsets = tl.arange(0, BLOCK_DMODEL)
-        kt_offsets = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        k_offsets = K + kv_offset + kd_offsets[:, None] * stride_kd + kt_offsets[None, :] * stride_kt
-        k = tl.load(k_offsets, mask=kt_offsets[None, :] < cur_seq_len, other=0.0)
-        kcached_offsets = tl.arange(0, BLOCK_DMODEL)
-        kcachebs_offsets = tl.arange(0, BLOCK_SIZE)
-        kcache_offsets = (
+        offsets_dmodel = tl.arange(0, HEAD_DIM)
+        offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
+        k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
+        offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
+        offsets_kcache = (
             KCache
-            + kvcache_offset
-            + kcached_offsets[:, None] * stride_cached
-            + kcachebs_offsets[None, :] * stride_cachebs
+            + offset_kvcache
+            + offsets_dmodel[:, None] * stride_cached
+            + offsets_kcachebs[None, :] * stride_cachebs
         )
-        tl.store(kcache_offsets, k, mask=kcachebs_offsets[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
+        tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
         # Copy v to corresponding cache block
-        vd_offsets = kd_offsets
-        vt_offsets = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
-        v_offsets = V + kv_offset + vt_offsets[:, None] * stride_vt + vd_offsets[None, :] * stride_vd
-        v = tl.load(v_offsets, mask=vt_offsets[:, None] < cur_seq_len, other=0.0)
-        vcached_offsets = kcached_offsets
-        vcachebs_offsets = kcachebs_offsets
-        vcache_offsets = (
+        offsets_vd = offsets_dmodel
+        offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
+        offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
+        v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
+        offsets_vcachebs = offsets_kcachebs  # same block size range, just to notify here
+        offsets_vcache = (
             VCache
-            + kvcache_offset
-            + vcachebs_offsets[:, None] * stride_cachebs
-            + vcached_offsets[None, :] * stride_cached
+            + offset_kvcache
+            + offsets_vcachebs[:, None] * stride_cachebs
+            + offsets_dmodel[None, :] * stride_cached
         )
-        tl.store(vcache_offsets, v, mask=vcachebs_offsets[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
+        tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
 
     return
 
 
 def context_attention_unpadded(
-    q: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
-    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
-    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
-    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
+    q: torch.Tensor,  # [num_tokens, num_heads, head_dim]
+    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
+    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
+    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
+    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
     context_lengths: torch.Tensor,  # [num_seqs]
     block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence],
     block_size: int,
```
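The `cur_head_idx % KV_GROUPS == 0` branch makes exactly one query head per KV group write this block's K and V tokens into the paged cache. A plain PyTorch sketch of the equivalent cache write, assuming the `[num_blocks, num_kv_heads, head_dim, block_size]` layout documented in the signature (the helper name and per-sequence loop are hypothetical; the kernel does this per block, in registers):

```python
import torch


def copy_to_blocked_cache(k, k_cache, block_table, seq_len, block_size):
    # k: [num_tokens, num_kv_heads, head_dim] for one sequence
    # k_cache: [num_blocks, num_kv_heads, head_dim, block_size]
    for start in range(0, seq_len, block_size):
        n = min(block_size, seq_len - start)  # handle a partial last block
        block_id = int(block_table[start // block_size])
        # tokens move to the last axis to match the cache layout
        k_cache[block_id, :, :, :n] = k[start:start + n].permute(1, 2, 0)
    return k_cache
```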
```diff
@@ -254,7 +252,7 @@ def context_attention_unpadded(
         sm_scale,
         num_kv_group,
         block_size,
-        BLOCK_DMODEL=Lk,
+        HEAD_DIM=Lk,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
     )
```
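A hypothetical smoke test for the host-side wrapper, assuming only the shapes documented in the signature above and `-1` as block-table padding; the PR's real coverage lives in the refactored pytest suite:

```python
import torch
# `context_attention_unpadded` is assumed importable from the module this diff touches.

num_heads, num_kv_heads, head_dim = 8, 2, 64
block_size, num_blocks, max_blocks = 16, 8, 4
context_lengths = torch.tensor([16, 16], dtype=torch.int32, device="cuda")
num_tokens = int(context_lengths.sum())  # sequences are packed (unpadded)

q = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(num_tokens, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn_like(k)
k_cache = torch.zeros(num_blocks, num_kv_heads, head_dim, block_size, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
block_tables = torch.tensor([[0, -1, -1, -1], [1, -1, -1, -1]], dtype=torch.int32, device="cuda")

out = context_attention_unpadded(q, k, v, k_cache, v_cache, context_lengths, block_tables, block_size)
print(out.shape)  # expected: [num_tokens, num_heads, head_dim]
```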