mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy oper… (#5663)
* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator * refactor decode_kv_cache_memcpy * enable alibi in pagedattention
This commit is contained in:
@@ -2,7 +2,11 @@ import torch
|
||||
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import (
|
||||
mock_alloc_block_table_and_kvcache_v2,
|
||||
mock_alloc_block_table_and_kvcache_v3,
|
||||
mock_alloc_single_token,
|
||||
)
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
@@ -68,11 +72,17 @@ def benchmark_rotary_emb(
|
||||
cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
|
||||
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
new_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
|
||||
new_k_cache = torch.zeros(size=new_cache_shape, dtype=dtype, device="cuda")
|
||||
|
||||
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
||||
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
_ = mock_alloc_block_table_and_kvcache_v3(
|
||||
k, v, new_k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
|
||||
new_q = torch.randn_like(new_k)
|
||||
new_v = torch.randn_like(new_k)
|
||||
@@ -94,12 +104,12 @@ def benchmark_rotary_emb(
|
||||
)
|
||||
elif provider == "no_fused_cuda_rotary_emb_func":
|
||||
fn = lambda: [
|
||||
inference_ops.rotary_embedding(new_q, new_k, cos, sin),
|
||||
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables),
|
||||
inference_ops.rotary_embedding(new_q, new_k, cos, sin, True),
|
||||
inference_ops.decode_kv_cache_memcpy(new_k, new_v, new_k_cache, v_cache, kv_seq_lengths, block_tables),
|
||||
]
|
||||
elif provider == "fused_cuda_rotary_emb_func":
|
||||
fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
|
||||
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
|
||||
new_q, new_k, new_v, cos, sin, new_k_cache, v_cache, kv_seq_lengths, block_tables, True
|
||||
)
|
||||
else:
|
||||
raise ValueError("Undefined provider")
|
||||
|
Reference in New Issue
Block a user