Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] Kernel Fusion, fused copy kv cache into rotary embedding (#5336)
* revise rotary embedding
* remove useless print
* adapt
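In short, the decode-phase test now drives one fused Triton launch instead of a rotary step followed by a separate KV-cache copy. The two call forms below are taken directly from the test diff; the argument names are the test's, not a documented API:

    # before: rotary embedding only; the cache copy was a separate step
    rotary_embedding(q, k, cos, sin)

    # after: the fused kernel also scatters the rotated k of the newest token
    # into the blocked (paged) k_cache, addressed via block_tables / kv_seq_lengths
    rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths)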
@@ -4,6 +4,7 @@ from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

 from colossalai.kernel.triton import rotary_embedding
+from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2

 try:
     import triton  # noqa
@@ -47,6 +48,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     assert torch.allclose(embd_x0, embd_stimulated_x)

     # create data
+    block_size = 32
+    max_num_blocks_per_seq = 4
     q_shape = (TOTAL_TOKENS, H, D)
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
     k_shape = (TOTAL_TOKENS, H, D)
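For context, torch_rotary_emb is the eager PyTorch baseline the assertions compare against. A minimal sketch of such a helper, assuming the split-halves rotary convention that matches the cos/sin tensors of shape (TOTAL_TOKENS, D // 2) created in the test (not necessarily the repository's exact code):

    import torch

    def torch_rotary_emb(x, cos, sin):
        # x: (tokens, heads, head_dim); cos/sin: (tokens, head_dim // 2)
        half = x.shape[-1] // 2
        x0, x1 = x[..., :half], x[..., half:]
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head axis
        return torch.cat((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)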
@@ -54,13 +57,35 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     cos_shape = (TOTAL_TOKENS, D // 2)
     cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
     sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
+    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
+    v = torch.randn_like(k)
+    v_cache = torch.zeros_like(k_cache)
+    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
+    block_tables = mock_alloc_block_table_and_kvcache_v2(
+        k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+    )
+    new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
+    new_q = torch.randn_like(new_k)
+    kv_seq_lengths = past_kv_seq_lengths + 1
+    block_tables = block_tables.to(device="cuda")
+    q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
+    k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
-    q_ref = torch_rotary_emb(q, cos, sin)
-    k_ref = torch_rotary_emb(k, cos, sin)

-    rotary_embedding(q, k, cos, sin)
+    rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths)

+    assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)
+    assert torch.allclose(new_k, k_ref, atol=1e-4, rtol=1e-4)
-    assert torch.allclose(q, q_ref, atol=1e-4, rtol=1e-4)
-    assert torch.allclose(k, k_ref, atol=1e-4, rtol=1e-4)
+    # check one by one
+    for seq_i in range(BATCH_SIZE):
+        ki = new_k[seq_i]
+        ki = ki.squeeze()
+        past_kv_seq_len = kv_seq_lengths[seq_i] - 1
+        target_block_id = block_tables[seq_i, past_kv_seq_len // block_size]
+        offsets_in_block = past_kv_seq_len % block_size
+        target = k_cache[target_block_id, :, offsets_in_block, :]
+        orig = new_k[seq_i].squeeze(dim=0)
+        assert torch.equal(orig, target)


 BATCH = 16
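The "check one by one" loop above spells out the expected cache layout: k_cache is laid out as (num_blocks, H, block_size, D), and the newest token of sequence i lands in block block_tables[i, pos // block_size] at row pos % block_size, where pos = kv_seq_lengths[i] - 1. A rough eager-mode reference for the fused call, reusing the torch_rotary_emb sketch above (an illustration of the semantics the test checks, not the Triton kernel itself):

    def fused_rotary_with_cache_copy_ref(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths, block_size):
        # decode phase: exactly one new token per sequence
        bsz = new_q.shape[0]
        new_q.copy_(torch_rotary_emb(new_q, cos[:bsz], sin[:bsz]))
        new_k.copy_(torch_rotary_emb(new_k, cos[:bsz], sin[:bsz]))
        for i in range(bsz):
            pos = int(kv_seq_lengths[i]) - 1                # slot index of the newest token
            block_id = block_tables[i, pos // block_size]   # physical block holding that slot
            offset = pos % block_size                       # row inside the block
            k_cache[block_id, :, offset, :] = new_k[i]      # copy the rotated k into the paged cache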
@@ -53,10 +53,10 @@ def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
     assert torch.allclose(cos, cos_ref)
     assert torch.allclose(sin, sin_ref)
     # decoding
-    ncos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
+    ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
     cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False)
     assert torch.allclose(cos, ncos_ref)
-    assert torch.allclose(sin, sin_ref)
+    assert torch.allclose(sin, nsin_ref)


 configs = [
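This second hunk only renames the decode-phase sine reference from sin_ref to nsin_ref so it no longer reuses (and silently overwrites) the prefill-phase sin_ref, matching the existing ncos_ref naming. For orientation, the decode branch of such a reference plausibly just gathers one cos/sin row per sequence at the newest token's position; a hypothetical sketch, not the repository's get_cos_sin:

    def get_cos_sin_decode_ref(lengths, cos_cache, sin_cache, dtype):
        idx = (lengths - 1).long()                 # position of the token being decoded
        return cos_cache[idx].to(dtype), sin_cache[idx].to(dtype)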