Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 14:41:53 +00:00)
[kernel] Support New KCache Layout - Triton Kernel (#5677)
* kvmemcpy triton for new kcache layout
* revise tests for new kcache layout
* naive triton flash decoding - new kcache layout
* rotary triton kernel - new kcache layout
* remove redundancy - triton decoding
* remove redundancy - triton kvcache copy
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
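For orientation before the diff: the sketch below contrasts the existing blocked K-cache layout with the new one exercised by these tests. It is a minimal illustration assuming only what the test changes in this diff show (the v3 cache shape and the 16-byte chunking factor x); the variable names are illustrative, not taken from the kernel sources.

import torch

# Illustrative sizes; the tests parametrize these values.
dtype = torch.float16
num_blocks, num_kv_heads, block_size, head_dim = 8, 16, 32, 128

# Existing (v2) blocked K-cache layout, as built by generate_caches_and_block_tables_v2:
k_cache_v2 = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim, dtype=dtype)

# New (v3) layout, as built by generate_caches_and_block_tables_v3: head_dim is split
# into chunks of x elements, where x elements span 16 bytes of the cache dtype.
x = 16 // torch.tensor([], dtype=dtype).element_size()  # 8 for float16
k_cache_v3 = torch.zeros(num_blocks, num_kv_heads, head_dim // x, block_size, x, dtype=dtype)

Keeping each 16-byte slice of a head's key contiguous appears to be the point of the new layout; the V cache keeps its previous shape throughout the diff.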
@@ -10,6 +10,7 @@ from tests.test_infer.test_ops.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
     create_attention_mask,
     generate_caches_and_block_tables_v2,
+    generate_caches_and_block_tables_v3,
     torch_attn_ref,
 )
 from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
@@ -75,6 +76,7 @@ def prepare_data(
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("q_len", [1, 5])
 @pytest.mark.parametrize("use_alibi_slopes", [True, False])
+@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
 def test_flash_decoding(
     bsz: int,
     block_size: int,
@@ -84,7 +86,15 @@ def test_flash_decoding(
     same_context_len: bool,
     q_len: int,
     use_alibi_slopes: bool,
+    use_new_kcache_layout: bool,
 ):
+    if use_new_kcache_layout and use_alibi_slopes:
+        # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
+        # the code (alibi kernel) will be refactored later to avoid code duplication, when
+        # the whole triton flow with new k cache layout has been supported and tested.
+        # And tests for the alibi kernel using new kcache layout will be added then.
+        pytest.skip("Alibi kernel does not support new kcache layout yet.")
+
     torch.manual_seed(123)
     torch.cuda.empty_cache()
     torch.cuda.synchronize()
@@ -127,9 +137,14 @@ def test_flash_decoding(
         q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
     )

-    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
-        k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
-    )
+    if use_new_kcache_layout:
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
+            k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
+        )
+    else:
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
+            k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
+        )
     block_tables = block_tables.to(device=device)
     # The maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size
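A small aside on the unchanged context above: kv_max_split_num is a plain ceiling division of the longest KV length in the batch by the cache block size, i.e. the number of cache blocks the longest sequence occupies. A quick check with made-up numbers:

max_kv_len_in_b, block_size = 100, 32
kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size
assert kv_max_split_num == 4  # ceil(100 / 32) blocks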
@@ -165,6 +180,7 @@ def test_flash_decoding(
         sm_scale=sm_scale,
         kv_group_num=kv_group_num,
         q_len=q_len,
+        use_new_kcache_layout=use_new_kcache_layout,
     ) # [bsz * q_len, num_heads, head_dim]

     assert out_torch.shape == out_triton.shape
@@ -178,4 +194,4 @@ def test_flash_decoding(


 if __name__ == "__main__":
-    test_flash_decoding(16, 32, 32, 16, 1, True, 1, True)
+    test_flash_decoding(16, 32, 32, 16, 1, True, 1, use_alibi_slopes=False, use_new_kcache_layout=True)
@@ -4,7 +4,11 @@ from packaging import version

 from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
+from tests.test_infer.test_ops.triton.kernel_utils import (
+    generate_caches_and_block_tables_v2,
+    generate_caches_and_block_tables_v3,
+    mock_alloc_single_token,
+)

 try:
     import triton # noqa
@@ -30,6 +34,7 @@ def prepare_data(
     n=1,
     device="cuda",
     dtype=torch.float16,
+    use_new_kcache_layout=False,
 ):
     assert max_seq_len > n, "max_seq_len must be greater than n"

@@ -44,9 +49,14 @@ def prepare_data(
     kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
     k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)

-    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
-        k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
-    )
+    if use_new_kcache_layout:
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
+            k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
+        )
+    else:
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
+            k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
+        )
     block_tables = block_tables.to(device=device)

     new_k = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
@@ -66,8 +76,15 @@ def prepare_data(
 @pytest.mark.parametrize("num_kv_heads", [16])
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("n_tokens", [1, 5])
+@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
 def test_copy_kv_to_caches(
-    bsz: int, block_size: int, max_num_blocks_per_seq: int, num_kv_heads: int, same_context_len: bool, n_tokens: int
+    bsz: int,
+    block_size: int,
+    max_num_blocks_per_seq: int,
+    num_kv_heads: int,
+    same_context_len: bool,
+    n_tokens: int,
+    use_new_kcache_layout: bool,
 ):
     torch.manual_seed(123)
     torch.cuda.empty_cache()
@@ -89,6 +106,7 @@ def test_copy_kv_to_caches(
         n_tokens,
         device=device,
         dtype=dtype,
+        use_new_kcache_layout=use_new_kcache_layout,
     )
     k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1))
     v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1))
@@ -98,7 +116,9 @@ def test_copy_kv_to_caches(
     offsets_in_block = past_kv_seq_lengths % block_size

     # Copy k (or v) to k (or v) cache
-    copy_k_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens)
+    copy_k_to_blocked_cache(
+        new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens, use_new_kcache_layout=use_new_kcache_layout
+    )
     # Reshape target k from k cache to compare if matching with original tensor
     # Mainly to handle cases of n_tokens > 1
     k_target = []
@@ -110,26 +130,39 @@ def test_copy_kv_to_caches(
         while tokens_left > 0:
             tokens_to_fill = min(block_size - offset, tokens_left)
             curr_block_id = block_table[curr_kv_len // block_size]
-            k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
+            if use_new_kcache_layout:
+                k_target.append(k_cache[curr_block_id, :, :, offset : offset + tokens_to_fill, :])
+            else:
+                k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
             curr_kv_len += tokens_to_fill
             tokens_left -= tokens_to_fill
             offset = 0
-    k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous() # [bsz * n, num_kv_heads, head_dim]
+
+    if use_new_kcache_layout:
+        k_target = torch.concat(k_target, dim=2).permute(2, 0, 1, 3).contiguous()
+        k_target = k_target.reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
+    else:
+        k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous() # [bsz * n, num_kv_heads, head_dim]
     assert k_target.shape == k_source.shape
     assert torch.equal(k_target, k_source)

     if n_tokens == 1:
         # Copy k and v to k/v caches
         k_cache = k_cache_copy
-        copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
-        k_target = k_cache_copy[target_block_ids, :, offsets_in_block, :]
-        v_target = v_cache[target_block_ids, :, offsets_in_block, :]
+        copy_kv_to_blocked_cache(
+            new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables, use_new_kcache_layout=use_new_kcache_layout
+        )
+
+        if use_new_kcache_layout:
+            k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
+            k_target = k_target.contiguous().reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
+        else:
+            k_target = k_cache[target_block_ids, :, offsets_in_block, :]
+        v_target = v_cache[target_block_ids, :, offsets_in_block, :]
         assert k_target.shape == k_source.shape
        assert torch.equal(k_target, k_source)
         assert v_target.shape == v_source.shape
         assert torch.equal(v_target, v_source)


 if __name__ == "__main__":
-    test_copy_kv_to_caches(4, 32, 8, 16, True)
+    test_copy_kv_to_caches(4, 32, 8, 16, True, n_tokens=1)
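To make the new-layout indexing above easier to follow, here is a small pure-PyTorch sketch of where one token's key vector lives in each cache layout. The shapes and the chunking factor x are assumptions read off the test code in this diff; this is a reference illustration, not the Triton copy kernel itself.

import torch

num_blocks, num_kv_heads, block_size, head_dim, x = 4, 2, 16, 32, 8
key = torch.arange(num_kv_heads * head_dim, dtype=torch.float32).reshape(num_kv_heads, head_dim)
block_id, offset = 1, 5  # cache block and in-block slot assumed for this token

# Old (v2) layout: (num_blocks, num_kv_heads, block_size, head_dim)
k_cache_v2 = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim)
k_cache_v2[block_id, :, offset, :] = key

# New (v3) layout: (num_blocks, num_kv_heads, head_dim // x, block_size, x)
k_cache_v3 = torch.zeros(num_blocks, num_kv_heads, head_dim // x, block_size, x)
k_cache_v3[block_id, :, :, offset, :] = key.reshape(num_kv_heads, head_dim // x, x)

# Reading the slot back and flattening the chunk dims recovers the original vector,
# which is what the k_target reshapes in the test above check.
restored = k_cache_v3[block_id, :, :, offset, :].reshape(num_kv_heads, head_dim)
assert torch.equal(restored, key)
assert torch.equal(k_cache_v2[block_id, :, offset, :], key)

From the test's perspective, the only change is that head_dim is split into head_dim // x chunks of x elements with the in-block token slot sitting between them, hence the extra `:` in the new indexing and the permute/reshape when rebuilding k_target.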
@@ -4,7 +4,10 @@ from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

 from colossalai.kernel.triton import decoding_fused_rotary_embedding
-from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
+from tests.test_infer.test_ops.triton.kernel_utils import (
+    mock_alloc_block_table_and_kvcache_v2,
+    mock_alloc_block_table_and_kvcache_v3,
+)

 try:
     import triton # noqa
@@ -36,7 +39,8 @@ def torch_rotary_emb(x, cos, sin):
 @pytest.mark.parametrize("H", [32])
 @pytest.mark.parametrize("D", [64])
 @pytest.mark.parametrize("dtype", [torch.float32])
-def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
+@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
+def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
     TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
     # our crafted op equals to Transformers
     x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
@@ -57,28 +61,40 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
     k_shape = (TOTAL_TOKENS, H, D)
     k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
-    cos_shape = (TOTAL_TOKENS, D // 2)
-    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
-    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
     v = torch.randn_like(k)
-    v_cache = torch.zeros_like(k_cache)
-    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
-    block_tables = mock_alloc_block_table_and_kvcache_v2(
-        k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
-    )
     new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
     new_q = torch.randn_like(new_k)
     new_v = torch.randn_like(new_k)

+    cos_shape = (TOTAL_TOKENS, D // 2)
+    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+
+    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
+    v_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
+    v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda")
+
+    if use_new_kcache_layout:
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, D // x, block_size, x)
+        k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda")
+        block_tables = mock_alloc_block_table_and_kvcache_v3(
+            k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+        )
+    else:
+        k_cache = torch.zeros_like(v_cache)
+        block_tables = mock_alloc_block_table_and_kvcache_v2(
+            k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+        )
     kv_seq_lengths = past_kv_seq_lengths + 1
     block_tables = block_tables.to(device="cuda")
     q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])

-    decoding_fused_rotary_embedding(new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths)
+    decoding_fused_rotary_embedding(
+        new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout
+    )
     assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)


 if __name__ == "__main__":
-    test_rotary_emb(4, 64, 32, 64, torch.float32)
+    test_rotary_emb(4, 64, 32, 64, torch.float32, use_new_kcache_layout=True)
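As a closing note, a worked example of the x factor computed in the hunk above, assuming only torch's standard element sizes (every 16-byte slice of a head's key becomes one innermost chunk of the new layout):

import torch

D = 64  # head size used in this test's parametrization

# float16 elements are 2 bytes: x = 8, so each head splits into 8 chunks of 8 elements.
x_fp16 = 16 // torch.tensor([], dtype=torch.float16).element_size()
assert (x_fp16, D // x_fp16) == (8, 8)

# float32 elements are 4 bytes: x = 4, so each head splits into 16 chunks of 4 elements.
x_fp32 = 16 // torch.tensor([], dtype=torch.float32).element_size()
assert (x_fp32, D // x_fp32) == (4, 16)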