mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[kernel] Support New KCache Layout - Triton Kernel (#5677)
* kvmemcpy triton for new kcache layout * revise tests for new kcache layout * naive triton flash decoding - new kcache layout * rotary triton kernel - new kcache layout * remove redundancy - triton decoding * remove redundancy - triton kvcache copy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -14,7 +14,7 @@ except ImportError:
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
HEAD_DIM = 4
|
||||
HEAD_DIM = 128
|
||||
BATCH = 16
|
||||
BLOCK_SIZE = 32
|
||||
SAME_LEN = True
|
||||
@@ -25,9 +25,9 @@ configs = [
|
||||
x_names=["KV_SEQ_LEN"],
|
||||
x_vals=[2**i for i in range(8, 13)],
|
||||
line_arg="provider",
|
||||
line_vals=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
|
||||
line_names=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
|
||||
styles=[("red", "-"), ("blue", "-"), ("green", "-")],
|
||||
line_vals=["torch_copy_func", "triton_copy_func", "triton_new_kcache_layout", "cuda_copy_func"],
|
||||
line_names=["torch_copy_func", "triton_copy_func", "triton_new_kcache_layout", "cuda_copy_func"],
|
||||
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}",
|
||||
args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True},
|
||||
@@ -45,7 +45,7 @@ def benchmark_kvcache_copy(
|
||||
num_kv_heads: int,
|
||||
same_context_len: bool,
|
||||
):
|
||||
dtype = torch.float32
|
||||
dtype = torch.float16
|
||||
device = get_current_device()
|
||||
|
||||
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
|
||||
@@ -63,11 +63,18 @@ def benchmark_kvcache_copy(
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
|
||||
if provider == "torch_copy_func":
|
||||
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
|
||||
elif provider == "triton_copy_func":
|
||||
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
|
||||
elif provider == "triton_new_kcache_layout":
|
||||
# NOTE New kcache layout (num_blocks, num_kv_heads, head_dim // x, block_size, x) to be applied
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
k_cache_shape = (bsz * max_seq_len // block_size, num_kv_heads, HEAD_DIM // x, block_size, x)
|
||||
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device) # update k_cache layout
|
||||
fn = lambda: copy_kv_to_blocked_cache(
|
||||
new_k, new_v, k_cache, v_cache, context_lengths, block_tables, use_new_kcache_layout=True
|
||||
)
|
||||
elif provider == "cuda_copy_func":
|
||||
_, _, k_cache, _, _, _, _, _, _ = prepare_data_new_kcache_layout(
|
||||
bsz, num_kv_heads, block_size, max_seq_len // block_size, context_lengths - 1, device, dtype
|
||||
|
Reference in New Issue
Block a user