[kernel] Support New KCache Layout - Triton Kernel (#5677)

* kvmemcpy triton for new kcache layout

* revise tests for new kcache layout

* naive triton flash decoding - new kcache layout

* rotary triton kernel - new kcache layout

* remove redundancy - triton decoding

* remove redundancy - triton kvcache copy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Yuanheng Zhao
Date: 2024-05-03 17:20:45 +08:00
Committed by: GitHub
Parent: 9df016fc45
Commit: 537a3cbc4d

10 changed files with 428 additions and 206 deletions
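
For context, the "new kcache layout" these kernels target tiles the head dimension of the blocked K cache into 16-byte vectors, similar to the x-split key-cache layout used by vLLM. Below is a minimal sketch of the shape arithmetic; only the shape formula, the `x` computation, and `num_kv_heads=16` come from the diff, while all other sizes are hypothetical:

    import torch

    # The benchmark builds the new-layout K cache as
    #   (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
    # where x is the number of elements that fit in one 16-byte vectorized access.
    dtype = torch.float16
    x = 16 // torch.tensor([], dtype=dtype).element_size()  # 16 // 2 = 8 for fp16

    num_blocks = 64    # hypothetical stand-in for BATCH_SIZE * max_num_blocks_per_seq
    num_kv_heads = 16  # matches args={"num_kv_heads": 16} in the benchmark config
    head_dim = 128     # hypothetical
    block_size = 16    # hypothetical

    kcache_shape = (num_blocks, num_kv_heads, head_dim // x, block_size, x)
    k_cache = torch.zeros(size=kcache_shape, dtype=dtype)  # (64, 16, 16, 16, 8)
    # Splitting head_dim into (head_dim // x, x) keeps each innermost read a
    # contiguous 16-byte chunk; total storage is unchanged.
    assert k_cache.numel() == num_blocks * num_kv_heads * block_size * head_dim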


@@ -24,18 +24,20 @@ configs = [
         x_vals=[2**i for i in range(4, 11)],
         line_arg="provider",
         line_vals=[
-            "no_fused_triton_rotary_emb_func",
-            "fused_triton_rotary_emb_func",
-            "no_fused_cuda_rotary_emb_func",
-            "fused_cuda_rotary_emb_func",
+            "triton_rotary_emb_func",
+            "triton_fused_rotary_emb_func",
+            "triton_fused_rotary_emb_func_new_kcache_layout",
+            "cuda_rotary_emb_func",
+            "cuda_fused_rotary_emb_func",
         ],
         line_names=[
-            "no_fused_triton_rotary_emb_func",
-            "fused_triton_rotary_emb_func",
-            "no_fused_cuda_rotary_emb_func",
-            "fused_cuda_rotary_emb_func",
+            "triton_rotary_emb_func",
+            "triton_fused_rotary_emb_func",
+            "triton_fused_rotary_emb_func(new layout)",
+            "cuda_rotary_emb_func",
+            "cuda_fused_rotary_emb_func",
         ],
-        styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")],
+        styles=[("red", "-"), ("blue", "-"), ("purple", "-"), ("green", "-"), ("yellow", "-")],
         ylabel="ms",
         plot_name=f"rotary_emb-batch-{BATCH}",
         args={"num_kv_heads": 16},
@@ -91,31 +93,44 @@ def benchmark_rotary_emb(
     kv_seq_lengths = past_kv_seq_lengths + 1
     block_tables = block_tables.to(device="cuda")
 
-    if provider == "no_fused_triton_rotary_emb_func":
+    quantiles = [0.5, 0.2, 0.8]
+    if provider == "triton_rotary_emb_func":
         fn = lambda: [
             rotary_embedding(new_q, new_k, cos, sin),
             copy_kv_to_blocked_cache(
                 new_k, new_v, k_cache, v_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables
             ),
         ]
-    elif provider == "fused_triton_rotary_emb_func":
+    elif provider == "triton_fused_rotary_emb_func":
         fn = lambda: decoding_fused_rotary_embedding(
             new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths
         )
-    elif provider == "no_fused_cuda_rotary_emb_func":
+    elif provider == "triton_fused_rotary_emb_func_new_kcache_layout":
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
+        k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda")
+        block_tables = mock_alloc_block_table_and_kvcache_v3(
+            k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+        )
+        mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
+        block_tables = block_tables.to(device="cuda")
+        fn = lambda: decoding_fused_rotary_embedding(
+            new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout=True
+        )
+    elif provider == "cuda_rotary_emb_func":
         fn = lambda: [
             inference_ops.rotary_embedding(new_q, new_k, cos, sin, True),
             inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables),
         ]
-    elif provider == "fused_cuda_rotary_emb_func":
+    elif provider == "cuda_fused_rotary_emb_func":
         fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
             new_q, new_k, new_v, cos, sin, new_k_cache, v_cache, kv_seq_lengths, block_tables, True
         )
     else:
         raise ValueError("Undefined provider")
 
-    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
-    return ms
+    ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)
+    return ms, min_ms, max_ms
 
 
 if __name__ == "__main__":
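
The final change swaps the single-number `do_bench` call for quantile reporting: with `quantiles=[0.5, 0.2, 0.8]`, Triton returns the median latency plus the 20th and 80th percentiles, which is why `benchmark_rotary_emb` now returns a 3-tuple. A self-contained demo of that return convention (the matmul is just a stand-in workload):

    import torch
    import triton.testing

    a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

    # With `quantiles` set, do_bench returns one measurement per requested
    # quantile instead of a single reduced value.
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: a @ b, warmup=25, rep=100, quantiles=[0.5, 0.2, 0.8]
    )
    print(f"median {ms:.4f} ms   p20 {min_ms:.4f} ms   p80 {max_ms:.4f} ms")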