mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[Inference/opt] Fused KVCahce Memcopy (#5374)
* fused kv memcopy * add TODO in test_kvcache_copy.py
This commit is contained in:
@@ -44,18 +44,19 @@ def prepare_data(
|
||||
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
|
||||
|
||||
k_cache, _, block_tables = generate_caches_and_block_tables_v2(
|
||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
||||
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
|
||||
)
|
||||
block_tables = block_tables.to(device=device)
|
||||
|
||||
new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
||||
new_v = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
||||
# mock allocating blocks for the new k/v and update block tables
|
||||
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
||||
# kv seq len = past kv seq len + seq len (1 during decoding stage)
|
||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||
|
||||
return new_k, k_cache, kv_seq_lengths, block_tables
|
||||
return new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||
@@ -80,7 +81,7 @@ def test_copy_kv_to_caches(
|
||||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
|
||||
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
|
||||
bsz,
|
||||
num_kv_heads,
|
||||
HEAD_DIM,
|
||||
@@ -93,16 +94,20 @@ def test_copy_kv_to_caches(
|
||||
)
|
||||
# k_cache_torch = k_cache.clone().detach()
|
||||
# copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
|
||||
copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)
|
||||
copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
|
||||
|
||||
past_kv_seq_len = kv_seq_lengths - 1
|
||||
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
||||
offsets_in_block = past_kv_seq_len % block_size
|
||||
target = k_cache[target_block_ids, :, offsets_in_block, :]
|
||||
source = new_k.squeeze()
|
||||
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
|
||||
k_source = new_k.squeeze()
|
||||
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
|
||||
v_source = new_v.squeeze()
|
||||
|
||||
assert target.shape == source.shape
|
||||
assert torch.equal(target, source)
|
||||
assert k_target.shape == k_source.shape
|
||||
assert torch.equal(k_target, k_source)
|
||||
assert v_target.shape == v_source.shape
|
||||
assert torch.equal(v_target, v_source)
|
||||
# target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
|
||||
# assert target_torch.shape == source.shape
|
||||
# assert torch.equal(target_torch, source)
|
||||
@@ -143,7 +148,7 @@ def benchmark_kvcache_copy(
|
||||
|
||||
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
|
||||
|
||||
new_k, k_cache, context_lengths, block_tables = prepare_data(
|
||||
new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data(
|
||||
bsz,
|
||||
num_kv_heads,
|
||||
HEAD_DIM,
|
||||
@@ -156,10 +161,11 @@ def benchmark_kvcache_copy(
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
|
||||
if provider == "torch_copy_func":
|
||||
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
|
||||
if provider == "triton_copy_func":
|
||||
fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
|
||||
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||
return ms, min_ms, max_ms
|
||||
|
Reference in New Issue
Block a user