[Inference/opt] Fused KVCache Memcopy (#5374)

* fuse the K and V cache memcopy into a single Triton kernel call

* add TODO in test_kvcache_copy.py
Author: yuehuayingxueluo
Date: 2024-02-07 17:15:42 +08:00
Committed by: GitHub
Parent: 58740b5f68
Commit: 6fb4bcbb24
4 changed files with 75 additions and 30 deletions

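The change below fuses the K and V cache writes of the decoding step: copy_kv_to_blocked_cache now takes the new K and V tokens plus both caches, so one Triton launch replaces two separate copies. As a rough sketch of the semantics the kernel implements — a pure-PyTorch reference inferred from the test's verification logic below, where the cache layout [num_blocks, num_kv_heads, block_size, head_dim] and the explicit block_size argument are assumptions of this sketch, not the kernel's real interface:

import torch

def fused_kv_copy_reference(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables, block_size):
    """Pure-PyTorch sketch of the fused copy (not the actual Triton kernel).

    new_k / new_v:     [bsz, 1, num_kv_heads, head_dim] -- one new token per sequence
    k_cache / v_cache: [num_blocks, num_kv_heads, block_size, head_dim] (assumed layout)
    kv_seq_lengths:    [bsz], sequence lengths *including* the new token
    block_tables:      [bsz, max_blocks_per_seq], logical -> physical block ids
    """
    past_len = kv_seq_lengths - 1  # write position of the new token
    batch = torch.arange(block_tables.size(0), device=block_tables.device)
    block_ids = block_tables[batch, past_len // block_size]  # physical block per sequence
    offsets = past_len % block_size                          # slot inside that block
    # the fused kernel performs both of these writes in a single launch
    k_cache[block_ids, :, offsets, :] = new_k.squeeze(1)
    v_cache[block_ids, :, offsets, :] = new_v.squeeze(1)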

@@ -44,18 +44,19 @@ def prepare_data(
     kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
     k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
-    k_cache, _, block_tables = generate_caches_and_block_tables_v2(
+    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
         k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
     )
     block_tables = block_tables.to(device=device)
     new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
+    new_v = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
     # mock allocating blocks for the new k/v and update block tables
     mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
     # kv seq len = past kv seq len + seq len (1 during decoding stage)
     kv_seq_lengths = past_kv_seq_lengths + 1
-    return new_k, k_cache, kv_seq_lengths, block_tables
+    return new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables

 @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
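For reference, the shapes prepare_data now hands back (dims read off the code above; the cache layout is inferred from the indexing in the test below and should be treated as an assumption):

# new_k, new_v:     [bsz, 1, num_kv_heads, head_dim]  -- one decoded token per sequence
# k_cache, v_cache: [num_blocks, num_kv_heads, block_size, head_dim]
# kv_seq_lengths:   [bsz]  -- past length + 1 for the new token
# block_tables:     [bsz, max_num_blocks_per_seq]  -- logical -> physical block ids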
@@ -80,7 +81,7 @@ def test_copy_kv_to_caches(
     dtype = torch.float16
     device = get_current_device()
-    new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
+    new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
         bsz,
         num_kv_heads,
         HEAD_DIM,
@@ -93,16 +94,20 @@ def test_copy_kv_to_caches(
     )
     # k_cache_torch = k_cache.clone().detach()
     # copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
-    copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)
+    copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)

     past_kv_seq_len = kv_seq_lengths - 1
     target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
     offsets_in_block = past_kv_seq_len % block_size
-    target = k_cache[target_block_ids, :, offsets_in_block, :]
-    source = new_k.squeeze()
+    k_target = k_cache[target_block_ids, :, offsets_in_block, :]
+    k_source = new_k.squeeze()
+    v_target = v_cache[target_block_ids, :, offsets_in_block, :]
+    v_source = new_v.squeeze()
-    assert target.shape == source.shape
-    assert torch.equal(target, source)
+    assert k_target.shape == k_source.shape
+    assert torch.equal(k_target, k_source)
+    assert v_target.shape == v_source.shape
+    assert torch.equal(v_target, v_source)

     # target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
     # assert target_torch.shape == source.shape
     # assert torch.equal(target_torch, source)
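The asserts read the freshly written token back out of the paged cache: token position p of a sequence lives in logical block p // block_size, which block_tables maps to a physical block, at slot p % block_size within it. A toy walkthrough with made-up numbers:

import torch

block_size = 4
past_len = torch.tensor([6])                  # checking the 7th token (0-indexed position 6)
block_tables = torch.tensor([[3, 0, 7, -1]])  # logical -> physical blocks for one sequence

logical_block = past_len // block_size                         # tensor([1])
physical_block = block_tables[torch.arange(1), logical_block]  # tensor([0])
slot = past_len % block_size                                   # tensor([2])
# the token therefore sits at k_cache[0, :, 2, :] (and likewise in v_cache)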
@@ -143,7 +148,7 @@ def benchmark_kvcache_copy(
     assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
-    new_k, k_cache, context_lengths, block_tables = prepare_data(
+    new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data(
         bsz,
         num_kv_heads,
         HEAD_DIM,
@@ -156,10 +161,11 @@ def benchmark_kvcache_copy(
     )
     quantiles = [0.5, 0.2, 0.8]

+    # TODO copy_to_cache needs to support copying both k and v at the same time in the future.
     if provider == "torch_copy_func":
         fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
     if provider == "triton_copy_func":
-        fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
+        fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)

     ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
     return ms, min_ms, max_ms
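Note the asymmetry the TODO calls out: the torch baseline still copies only K, while the Triton path now writes both K and V, so the benchmark undercounts the baseline's work. Until copy_to_cache handles both tensors, a fairer unfused baseline could simply issue the existing copy twice — a hypothetical pairing, not part of this commit:

# Hypothetical two-launch baseline using the existing single-tensor
# copy_to_cache, per the TODO above:
fn = lambda: (
    copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding"),
    copy_to_cache(new_v, v_cache, lengths=context_lengths, block_tables=block_tables, type="decoding"),
)

With quantiles=[0.5, 0.2, 0.8], triton.testing.do_bench returns the median latency plus the 20th- and 80th-percentile values, all in milliseconds, which is what the function passes back to the benchmark harness.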