mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 06:30:41 +00:00
[kernel] Revise KVCache copy triton kernel API (#5273)
* [kernel/fix] revise kvcache copy kernel api * fix benchmark
This commit is contained in:
@@ -30,12 +30,12 @@ def prepare_data(
|
||||
dtype=torch.float16,
|
||||
):
|
||||
if same_context_len:
|
||||
# context_lengths in this test records the previous kv seq len
|
||||
# past_kv_seq_lengths in this test records the previous kv seq len
|
||||
# (not incorporating the current input whose seq len is 1)
|
||||
context_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||
past_kv_seq_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||
else:
|
||||
context_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
|
||||
num_tokens = torch.sum(context_lengths).item()
|
||||
past_kv_seq_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
|
||||
num_tokens = torch.sum(past_kv_seq_lengths).item()
|
||||
|
||||
kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
|
||||
kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
@@ -46,15 +46,18 @@ def prepare_data(
|
||||
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
|
||||
# Mock allocation on block tables as well as blocked kv caches
|
||||
block_tables = mock_alloc_block_table_and_kvcache(
|
||||
k, v, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size
|
||||
k, v, k_cache, v_cache, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
block_tables = block_tables.to(device=device)
|
||||
|
||||
new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
|
||||
# mock allocating blocks for the new k/v and update block tables
|
||||
mock_alloc_single_token(block_tables, context_lengths, block_size)
|
||||
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
||||
|
||||
return new_k, k_cache, context_lengths, block_tables
|
||||
# kv seq len = past kv seq len + seq len (1 during decoding stage)
|
||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||
|
||||
return new_k, k_cache, kv_seq_lengths, block_tables
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||
@@ -80,7 +83,7 @@ def test_copy_kv_to_caches(
|
||||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
new_k, k_cache, context_lengths, block_tables = prepare_data(
|
||||
new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
|
||||
bsz,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
@@ -91,25 +94,24 @@ def test_copy_kv_to_caches(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
|
||||
copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)
|
||||
|
||||
for seq_i in range(bsz):
|
||||
ki = new_k[seq_i]
|
||||
ki = ki.squeeze()
|
||||
context_len_i = context_lengths[seq_i]
|
||||
target_block_id = block_tables[seq_i, context_len_i // block_size]
|
||||
offsets_in_block = context_len_i % block_size
|
||||
past_kv_seq_len = kv_seq_lengths[seq_i] - 1
|
||||
target_block_id = block_tables[seq_i, past_kv_seq_len // block_size]
|
||||
offsets_in_block = past_kv_seq_len % block_size
|
||||
target = k_cache[target_block_id, :, :, offsets_in_block]
|
||||
orig = new_k[seq_i].squeeze(dim=0)
|
||||
assert torch.equal(orig, target)
|
||||
|
||||
|
||||
BATCH = 4
|
||||
BATCH = 16
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["PAST_KVLEN"],
|
||||
x_vals=[2**i - 1 for i in range(8, 13)],
|
||||
x_names=["KV_SEQ_LEN"],
|
||||
x_vals=[2**i for i in range(8, 13)],
|
||||
line_arg="provider",
|
||||
line_vals=["torch_copy_func", "triton_copy_func"],
|
||||
line_names=["torch_copy_func", "triton_copy_func"],
|
||||
@@ -127,7 +129,7 @@ def benchmark_kvcache_copy(
|
||||
bsz: int,
|
||||
block_size: int,
|
||||
max_seq_len: int,
|
||||
PAST_KVLEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens)
|
||||
KV_SEQ_LEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens)
|
||||
num_kv_heads: int,
|
||||
same_context_len: bool,
|
||||
):
|
||||
@@ -138,7 +140,7 @@ def benchmark_kvcache_copy(
|
||||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
assert PAST_KVLEN < max_seq_len, "Assigned maximum past kv length must be smaller or equal to maximum seq len"
|
||||
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
|
||||
|
||||
new_k, k_cache, context_lengths, block_tables = prepare_data(
|
||||
bsz,
|
||||
@@ -147,7 +149,7 @@ def benchmark_kvcache_copy(
|
||||
block_size,
|
||||
max_seq_len // block_size,
|
||||
same_context_len,
|
||||
PAST_KVLEN,
|
||||
KV_SEQ_LEN,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
@@ -164,5 +166,5 @@ def benchmark_kvcache_copy(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_copy_kv_to_caches(4, 32, 8, 16, False)
|
||||
# benchmark_kvcache_copy.run(save_path=".")
|
||||
test_copy_kv_to_caches(4, 32, 8, 16, True)
|
||||
# benchmark_kvcache_copy.run(save_path=".", print_data=True)
|
||||
|
Reference in New Issue
Block a user