[Inference/opt] Fused KVCache Memcopy (#5374)

* fused kv memcopy

* add TODO in test_kvcache_copy.py
This commit is contained in:
yuehuayingxueluo
2024-02-07 17:15:42 +08:00
committed by GitHub
parent 58740b5f68
commit 6fb4bcbb24
4 changed files with 75 additions and 30 deletions

View File

@@ -6,17 +6,26 @@ import triton.language as tl
# Triton 2.1.0
@triton.jit
def _copy_to_kvcache_seqlen1_kernel(
KV, # K or V
KVCache, # KCache or VCache
K, # K
V, # V
KCache, # KCache
VCache, # VCache
BLOCK_TABLES,
context_lengths,
stride_kt,
stride_kh,
stride_kd,
stride_cacheb,
stride_cacheh,
stride_cachebs,
stride_cached,
stride_vt,
stride_vh,
stride_vd,
stride_cachekb,
stride_cachekh,
stride_cachekbs,
stride_cachekd,
stride_cachevb,
stride_cachevh,
stride_cachevbs,
stride_cachevd,
stride_bts,
stride_btb,
block_size,
@@ -32,20 +41,33 @@ def _copy_to_kvcache_seqlen1_kernel(
offsets_in_last_block = past_kv_seq_len % block_size
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
kv = tl.load(KV + offsets_kv)
k = tl.load(K + offsets_kv)
v = tl.load(V + offsets_kv)
offsets_kvcache = (
block_id * stride_cacheb
+ cur_kv_head_idx * stride_cacheh
+ offsets_in_last_block * stride_cachebs
+ offsets_dmodel * stride_cached
block_id * stride_cachekb
+ cur_kv_head_idx * stride_cachekh
+ offsets_in_last_block * stride_cachekbs
+ offsets_dmodel * stride_cachekd
)
tl.store(KVCache + offsets_kvcache, kv)
offsets_kvcache = (
block_id * stride_cachevb
+ cur_kv_head_idx * stride_cachevh
+ offsets_in_last_block * stride_cachevbs
+ offsets_dmodel * stride_cachevd
)
tl.store(KCache + offsets_kvcache, k)
tl.store(VCache + offsets_kvcache, v)
return
def copy_kv_to_blocked_cache(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
kv_lengths: torch.Tensor,
block_tables: torch.Tensor,
):
@@ -53,16 +75,23 @@ def copy_kv_to_blocked_cache(
Copy keys and values to the blocked key/value caches during decoding stage.
Args:
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys during decoding with seq len 1.
v (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Values during decoding with seq len 1.
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key cache.
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache.
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
"""
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
k = k.squeeze(1) if k.dim() == 4 else k
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
assert v.size(-1) == v_cache.size(-1), "Incompatible head dim"
assert v.dtype == v_cache.dtype, "Expected consistent dtype for tensor and cache."
v = v.squeeze(1) if v.dim() == 4 else v
assert v.dim() == 3, f"Incompatible v dim {v.dim()}"
bsz, num_kv_heads, head_dim = k.shape
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
@@ -75,20 +104,28 @@ def copy_kv_to_blocked_cache(
block_size = k_cache.size(-2)
num_warps = 8 if head_dim > 128 else 4
grid = (bsz, num_kv_heads)
_copy_to_kvcache_seqlen1_kernel[grid](
k,
v,
k_cache,
v_cache,
block_tables,
kv_lengths,
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
block_size,