mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)
* revise shape of kvcache (context attn kernel) * revise shape of kvcache (flash decoding kernel) * revise shape of kvcache (kvcache copy) and attn func * init of kvcache in kvcache manager * revise llama modeling * revise block size retrieval * use torch for rms_norm benchmarking * revise block size retrieval
This commit is contained in:
@@ -15,8 +15,8 @@ def _copy_to_kvcache_seqlen1_kernel(
|
||||
stride_kd,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
stride_cached,
|
||||
stride_cachebs,
|
||||
stride_cached,
|
||||
stride_bts,
|
||||
stride_btb,
|
||||
block_size,
|
||||
@@ -29,15 +29,15 @@ def _copy_to_kvcache_seqlen1_kernel(
|
||||
last_bt_block_idx = past_kv_seq_len // block_size
|
||||
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
|
||||
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
|
||||
offsets_in_last_block = (past_kv_seq_len % block_size) * stride_cachebs
|
||||
offsets_in_last_block = past_kv_seq_len % block_size
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||
kv = tl.load(KV + offsets_kv)
|
||||
offsets_kvcache = (
|
||||
block_id * stride_cacheb
|
||||
+ cur_kv_head_idx * stride_cacheh
|
||||
+ offsets_in_last_block * stride_cachebs
|
||||
+ offsets_dmodel * stride_cached
|
||||
+ offsets_in_last_block
|
||||
)
|
||||
tl.store(KVCache + offsets_kvcache, kv)
|
||||
return
|
||||
@@ -52,23 +52,18 @@ def copy_kv_to_blocked_cache(
|
||||
"""
|
||||
Copy keys or values to the blocked key/value cache during decoding stage.
|
||||
|
||||
Parameters:
|
||||
- k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||
- k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
|
||||
- kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||
- block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||
Args:
|
||||
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
|
||||
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||
"""
|
||||
assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
|
||||
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
|
||||
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||
if k.dim() == 4:
|
||||
assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
|
||||
bsz, _, num_kv_heads, head_dim = k.shape
|
||||
# [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
|
||||
k = k.squeeze(dim=1)
|
||||
elif k.dim() == 3:
|
||||
bsz, num_kv_heads, head_dim = k.shape
|
||||
else:
|
||||
raise ValueError(f"The key dim should be 3 or 4, but got {k.dim()}.")
|
||||
|
||||
k = k.squeeze(1) if k.dim() == 4 else k
|
||||
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
|
||||
bsz, num_kv_heads, head_dim = k.shape
|
||||
|
||||
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||
f"Got incompatible batch size (number of seqs):\n"
|
||||
@@ -77,7 +72,7 @@ def copy_kv_to_blocked_cache(
|
||||
)
|
||||
|
||||
# Modify if the shape of kv cahce is changed.
|
||||
block_size = k_cache.size(-1)
|
||||
block_size = k_cache.size(-2)
|
||||
|
||||
num_warps = 8 if head_dim > 128 else 4
|
||||
|
||||
|
Reference in New Issue
Block a user