mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[kernel] Revise KVCache copy triton kernel API (#5273)
* [kernel/fix] revise kvcache copy kernel api * fix benchmark
This commit is contained in:
@@ -25,11 +25,11 @@ def _copy_to_kvcache_seqlen1_kernel(
|
||||
cur_seq_idx = tl.program_id(0)
|
||||
cur_kv_head_idx = tl.program_id(1)
|
||||
|
||||
cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
|
||||
last_bt_block_idx = cur_kv_seq_len // block_size
|
||||
past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) - 1
|
||||
last_bt_block_idx = past_kv_seq_len // block_size
|
||||
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
|
||||
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
|
||||
offsets_in_last_block = (cur_kv_seq_len % block_size) * stride_cachebs
|
||||
offsets_in_last_block = (past_kv_seq_len % block_size) * stride_cachebs
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||
kv = tl.load(KV + offsets_kv)
|
||||
@@ -43,23 +43,30 @@ def _copy_to_kvcache_seqlen1_kernel(
|
||||
return
|
||||
|
||||
|
||||
# Used with blocked kv cache.
|
||||
# Copy k or v to block k/v cache during decoding stage
|
||||
def copy_kv_to_blocked_cache(
|
||||
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_dim], k or v during decoding stage
|
||||
k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size], blocked k or v cache (for now, the shapes of them are the same)
|
||||
context_lengths: torch.Tensor, # [bsz], past kv seq len (not incorporating the current kv of length 1)
|
||||
block_tables: torch.Tensor, # [bsz, max_blocks_per_sequence]
|
||||
k: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
kv_lengths: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Copy keys or values to the blocked key/value cache during decoding stage.
|
||||
|
||||
Parameters:
|
||||
- k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||
- k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
|
||||
- kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||
- block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||
"""
|
||||
assert k.dim() == 4, "Unsupported shape of k (supposed to be used for decoding stage)"
|
||||
assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
|
||||
assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
|
||||
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||
bsz, _, num_kv_heads, head_dim = k.shape
|
||||
assert context_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||
f"Got incompatible batch size (number of seqs):\n"
|
||||
f" Conext lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
|
||||
f"batch size {bsz}"
|
||||
f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
|
||||
f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}"
|
||||
)
|
||||
|
||||
# Modify if the shape of kv cahce is changed.
|
||||
@@ -74,7 +81,7 @@ def copy_kv_to_blocked_cache(
|
||||
k,
|
||||
k_cache,
|
||||
block_tables,
|
||||
context_lengths,
|
||||
kv_lengths,
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
|
Reference in New Issue
Block a user