[kernel] Support New KCache Layout - Triton Kernel (#5677)

* kvmemcpy triton for new kcache layout

* revise tests for new kcache layout

* naive triton flash decoding - new kcache layout

* rotary triton kernel - new kcache layout

* remove redundancy - triton decoding

* remove redundancy - triton kvcache copy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Yuanheng Zhao
2024-05-03 17:20:45 +08:00
committed by GitHub
parent 9df016fc45
commit 537a3cbc4d
10 changed files with 428 additions and 206 deletions

View File

@@ -6,6 +6,7 @@ from tests.test_infer.test_ops.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
torch_attn_ref,
)
from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data
@@ -29,9 +30,9 @@ configs = [
x_vals=[2**i for i in range(8, 14)],
# x_vals=[x for x in range(256, 8192, 256)],
line_arg="provider",
line_vals=["torch", "triton"],
line_names=["Torch", "Triton"],
styles=[("red", "-"), ("blue", "-")],
line_vals=["torch", "triton", "triton_new_kcache_layout"],
line_names=["Torch", "Triton", "Triton New KCache Layout"],
styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
ylabel="ms",
plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}",
args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
@@ -62,6 +63,14 @@ def bench_kernel(
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
)
max_seq_len_in_b = kv_lengths.max().item() # for random lengths
# the maximum block length splitted on kv should be the kv cache block size
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
sm_scale = 1.0 / (HEAD_DIM**0.5)
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
mid_output = torch.empty(
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch":
@@ -81,19 +90,11 @@ def bench_kernel(
HEAD_DIM,
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
if provider == "triton":
elif provider == "triton":
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
# the maximum block length splitted on kv should be the kv cache block size
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
mid_output = torch.empty(
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
)
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
sm_scale = 1.0 / (HEAD_DIM**0.5)
fn = lambda: flash_decoding_attention(
# Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
# refer to attention forward in modeling.
@@ -111,6 +112,29 @@ def bench_kernel(
kv_group_num=kv_group_num,
) # [bsz, 1, num_heads, head_dim]
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
elif provider == "triton_new_kcache_layout":
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
)
block_tables = block_tables.to(device=device)
fn = lambda: flash_decoding_attention(
# Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
# refer to attention forward in modeling.
q.squeeze(2),
k_cache,
v_cache,
kv_lengths,
block_tables,
block_size,
max_seq_len_in_b,
output,
mid_output,
mid_output_lse,
sm_scale=sm_scale,
kv_group_num=kv_group_num,
use_new_kcache_layout=True,
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
return ms, min_ms, max_ms

View File

@@ -24,18 +24,20 @@ configs = [
x_vals=[2**i for i in range(4, 11)],
line_arg="provider",
line_vals=[
"no_fused_triton_rotary_emb_func",
"fused_triton_rotary_emb_func",
"no_fused_cuda_rotary_emb_func",
"fused_cuda_rotary_emb_func",
"triton_rotary_emb_func",
"triton_fused_rotary_emb_func",
"triton_fused_rotary_emb_func_new_kcache_layout",
"cuda_rotary_emb_func",
"cuda_fused_rotary_emb_func",
],
line_names=[
"no_fused_triton_rotary_emb_func",
"fused_triton_rotary_emb_func",
"no_fused_cuda_rotary_emb_func",
"fused_cuda_rotary_emb_func",
"triton_rotary_emb_func",
"triton_fused_rotary_emb_func",
"triton_fused_rotary_emb_func(new layout)",
"cuda_rotary_emb_func",
"cuda_fused_rotary_emb_func",
],
styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")],
styles=[("red", "-"), ("blue", "-"), ("purple", "-"), ("green", "-"), ("yellow", "-")],
ylabel="ms",
plot_name=f"rotary_emb-batch-{BATCH}",
args={"num_kv_heads": 16},
@@ -91,31 +93,44 @@ def benchmark_rotary_emb(
kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda")
if provider == "no_fused_triton_rotary_emb_func":
quantiles = [0.5, 0.2, 0.8]
if provider == "triton_rotary_emb_func":
fn = lambda: [
rotary_embedding(new_q, new_k, cos, sin),
copy_kv_to_blocked_cache(
new_k, new_v, k_cache, v_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables
),
]
elif provider == "fused_triton_rotary_emb_func":
elif provider == "triton_fused_rotary_emb_func":
fn = lambda: decoding_fused_rotary_embedding(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths
)
elif provider == "no_fused_cuda_rotary_emb_func":
elif provider == "triton_fused_rotary_emb_func_new_kcache_layout":
x = 16 // torch.tensor([], dtype=dtype).element_size()
kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda")
block_tables = mock_alloc_block_table_and_kvcache_v3(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
)
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
block_tables = block_tables.to(device="cuda")
fn = lambda: decoding_fused_rotary_embedding(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout=True
)
elif provider == "cuda_rotary_emb_func":
fn = lambda: [
inference_ops.rotary_embedding(new_q, new_k, cos, sin, True),
inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables),
]
elif provider == "fused_cuda_rotary_emb_func":
elif provider == "cuda_fused_rotary_emb_func":
fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
new_q, new_k, new_v, cos, sin, new_k_cache, v_cache, kv_seq_lengths, block_tables, True
)
else:
raise ValueError("Undefined provider")
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)
return ms, min_ms, max_ms
if __name__ == "__main__":

View File

@@ -14,7 +14,7 @@ except ImportError:
inference_ops = InferenceOpsLoader().load()
HEAD_DIM = 4
HEAD_DIM = 128
BATCH = 16
BLOCK_SIZE = 32
SAME_LEN = True
@@ -25,9 +25,9 @@ configs = [
x_names=["KV_SEQ_LEN"],
x_vals=[2**i for i in range(8, 13)],
line_arg="provider",
line_vals=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
line_names=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
styles=[("red", "-"), ("blue", "-"), ("green", "-")],
line_vals=["torch_copy_func", "triton_copy_func", "triton_new_kcache_layout", "cuda_copy_func"],
line_names=["torch_copy_func", "triton_copy_func", "triton_new_kcache_layout", "cuda_copy_func"],
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
ylabel="ms",
plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}",
args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True},
@@ -45,7 +45,7 @@ def benchmark_kvcache_copy(
num_kv_heads: int,
same_context_len: bool,
):
dtype = torch.float32
dtype = torch.float16
device = get_current_device()
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
@@ -63,11 +63,18 @@ def benchmark_kvcache_copy(
)
quantiles = [0.5, 0.2, 0.8]
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
if provider == "torch_copy_func":
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
elif provider == "triton_copy_func":
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
elif provider == "triton_new_kcache_layout":
# NOTE New kcache layout (num_blocks, num_kv_heads, head_dim // x, block_size, x) to be applied
x = 16 // torch.tensor([], dtype=dtype).element_size()
k_cache_shape = (bsz * max_seq_len // block_size, num_kv_heads, HEAD_DIM // x, block_size, x)
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device) # update k_cache layout
fn = lambda: copy_kv_to_blocked_cache(
new_k, new_v, k_cache, v_cache, context_lengths, block_tables, use_new_kcache_layout=True
)
elif provider == "cuda_copy_func":
_, _, k_cache, _, _, _, _, _, _ = prepare_data_new_kcache_layout(
bsz, num_kv_heads, block_size, max_seq_len // block_size, context_lengths - 1, device, dtype