[kernel] Add KV cache copy kernel during decoding (#5261)

* add kv copy triton kernel during decoding stage

* add pytest and fix kernel

* fix test utilities

* revise kernel config

* add benchmark for kvcache copy
This commit is contained in:
Yuanheng Zhao
2024-01-15 17:37:20 +08:00
committed by GitHub
parent 1ded7e81ef
commit fa85e02b3b
5 changed files with 288 additions and 2 deletions

View File

@@ -31,7 +31,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
1, 2, 0
)
elif type == "decoding":
assert len(source[0]) == 1, "seq_len should be equal to 1 when decoding."
assert source.size(1) == 1, "seq_len should be equal to 1 when decoding."
source = source.squeeze(1)
slot_idx = (lengths + block_size - 1) % block_size
for i in range(bsz):
@@ -314,4 +314,4 @@ class PagedAttention:
):
return self.pad_decoding_forward(
q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1), k_cache, v_cache, lengths, block_tables
)
)