mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[Inference/opt] Fused KVCahce Memcopy (#5374)
* fused kv memcopy * add TODO in test_kvcache_copy.py
This commit is contained in:
@@ -6,17 +6,26 @@ import triton.language as tl
|
||||
# Triton 2.1.0
|
||||
@triton.jit
|
||||
def _copy_to_kvcache_seqlen1_kernel(
|
||||
KV, # K or V
|
||||
KVCache, # KCache or VCache
|
||||
K, # K
|
||||
V, # V
|
||||
KCache, # KCache
|
||||
VCache, # VCache
|
||||
BLOCK_TABLES,
|
||||
context_lengths,
|
||||
stride_kt,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
stride_cachebs,
|
||||
stride_cached,
|
||||
stride_vt,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_cachekb,
|
||||
stride_cachekh,
|
||||
stride_cachekbs,
|
||||
stride_cachekd,
|
||||
stride_cachevb,
|
||||
stride_cachevh,
|
||||
stride_cachevbs,
|
||||
stride_cachevd,
|
||||
stride_bts,
|
||||
stride_btb,
|
||||
block_size,
|
||||
@@ -32,20 +41,33 @@ def _copy_to_kvcache_seqlen1_kernel(
|
||||
offsets_in_last_block = past_kv_seq_len % block_size
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||
kv = tl.load(KV + offsets_kv)
|
||||
|
||||
k = tl.load(K + offsets_kv)
|
||||
v = tl.load(V + offsets_kv)
|
||||
|
||||
offsets_kvcache = (
|
||||
block_id * stride_cacheb
|
||||
+ cur_kv_head_idx * stride_cacheh
|
||||
+ offsets_in_last_block * stride_cachebs
|
||||
+ offsets_dmodel * stride_cached
|
||||
block_id * stride_cachekb
|
||||
+ cur_kv_head_idx * stride_cachekh
|
||||
+ offsets_in_last_block * stride_cachekbs
|
||||
+ offsets_dmodel * stride_cachekd
|
||||
)
|
||||
tl.store(KVCache + offsets_kvcache, kv)
|
||||
offsets_kvcache = (
|
||||
block_id * stride_cachevb
|
||||
+ cur_kv_head_idx * stride_cachevh
|
||||
+ offsets_in_last_block * stride_cachevbs
|
||||
+ offsets_dmodel * stride_cachevd
|
||||
)
|
||||
|
||||
tl.store(KCache + offsets_kvcache, k)
|
||||
tl.store(VCache + offsets_kvcache, v)
|
||||
return
|
||||
|
||||
|
||||
def copy_kv_to_blocked_cache(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
kv_lengths: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
):
|
||||
@@ -53,16 +75,23 @@ def copy_kv_to_blocked_cache(
|
||||
Copy keys or values to the blocked key/value cache during decoding stage.
|
||||
|
||||
Args:
|
||||
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
|
||||
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys during decoding with seq len 1.
|
||||
v (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Values during decoding with seq len 1.
|
||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key cache.
|
||||
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache.
|
||||
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||
"""
|
||||
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
|
||||
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||
|
||||
k = k.squeeze(1) if k.dim() == 4 else k
|
||||
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
|
||||
|
||||
assert v.size(-1) == v_cache.size(-1), "Incompatible head dim"
|
||||
assert v.dtype == v_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||
v = v.squeeze(1) if v.dim() == 4 else v
|
||||
assert v.dim() == 3, f"Incompatible v dim {v.dim()}"
|
||||
|
||||
bsz, num_kv_heads, head_dim = k.shape
|
||||
|
||||
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||
@@ -75,20 +104,28 @@ def copy_kv_to_blocked_cache(
|
||||
block_size = k_cache.size(-2)
|
||||
|
||||
num_warps = 8 if head_dim > 128 else 4
|
||||
|
||||
grid = (bsz, num_kv_heads)
|
||||
_copy_to_kvcache_seqlen1_kernel[grid](
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
kv_lengths,
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
v_cache.stride(3),
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
block_size,
|
||||
|
Reference in New Issue
Block a user