[kernel] Support New KCache Layout - Triton Kernel (#5677)

* kvmemcpy triton for new kcache layout

* revise tests for new kcache layout

* naive triton flash decoding - new kcache layout

* rotary triton kernel - new kcache layout

* remove redundancy - triton decoding

* remove redundancy - triton kvcache copy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Yuanheng Zhao
Date: 2024-05-03 17:20:45 +08:00 (committed by GitHub)
Commit: 537a3cbc4d (parent 9df016fc45)
10 changed files with 428 additions and 206 deletions
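For reference, the "new kcache layout" exercised throughout these tests packs the key head dimension into 16-byte groups: with x = 16 // element_size(dtype), a blocked k cache goes from the v2 shape [num_blocks, num_kv_heads, block_size, head_dim] to the v3 shape [num_blocks, num_kv_heads, head_dim // x, block_size, x], while the v cache keeps the v2 shape. A minimal sketch of how the two layouts relate, using standalone tensors rather than the repo's cache helpers (the names and the view/permute below are illustrative, not the kernel code):

```python
import torch

# Illustrative shapes, close to the test defaults.
num_blocks, num_kv_heads, block_size, head_dim = 64, 16, 32, 128
dtype = torch.float16
x = 16 // torch.tensor([], dtype=dtype).element_size()  # elements per 16-byte group (8 for fp16)

# v2 layout: one [block_size, head_dim] tile per (block, head).
k_cache_v2 = torch.randn(num_blocks, num_kv_heads, block_size, head_dim).to(dtype)

# v3 ("new") layout: head_dim is split into head_dim // x groups of x contiguous elements.
k_cache_v3 = (
    k_cache_v2.view(num_blocks, num_kv_heads, block_size, head_dim // x, x)
    .permute(0, 1, 3, 2, 4)
    .contiguous()
)
assert k_cache_v3.shape == (num_blocks, num_kv_heads, head_dim // x, block_size, x)
```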

View File

@@ -10,6 +10,7 @@ from tests.test_infer.test_ops.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
     create_attention_mask,
     generate_caches_and_block_tables_v2,
+    generate_caches_and_block_tables_v3,
     torch_attn_ref,
 )
 from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
@@ -75,6 +76,7 @@ def prepare_data(
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("q_len", [1, 5])
 @pytest.mark.parametrize("use_alibi_slopes", [True, False])
+@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
 def test_flash_decoding(
     bsz: int,
     block_size: int,
@@ -84,7 +86,15 @@ def test_flash_decoding(
     same_context_len: bool,
     q_len: int,
     use_alibi_slopes: bool,
+    use_new_kcache_layout: bool,
 ):
+    if use_new_kcache_layout and use_alibi_slopes:
+        # TODO(yuanheng-zhao): Since the alibi kernel is pretty similar to the original one,
+        # the code (alibi kernel) will be refactored later to avoid code duplication, when
+        # the whole triton flow with new k cache layout has been supported and tested.
+        # And tests for the alibi kernel using new kcache layout will be added then.
+        pytest.skip("Alibi kernel does not support new kcache layout yet.")
+
     torch.manual_seed(123)
     torch.cuda.empty_cache()
     torch.cuda.synchronize()
@@ -127,9 +137,14 @@ def test_flash_decoding(
         q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
     )
 
-    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
-        k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
-    )
+    if use_new_kcache_layout:
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
+            k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
+        )
+    else:
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
+            k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
+        )
     block_tables = block_tables.to(device=device)
     # The maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size
@@ -165,6 +180,7 @@ def test_flash_decoding(
         sm_scale=sm_scale,
         kv_group_num=kv_group_num,
         q_len=q_len,
+        use_new_kcache_layout=use_new_kcache_layout,
     )  # [bsz * q_len, num_heads, head_dim]
 
     assert out_torch.shape == out_triton.shape
@@ -178,4 +194,4 @@ def test_flash_decoding(
 
 
 if __name__ == "__main__":
-    test_flash_decoding(16, 32, 32, 16, 1, True, 1, True)
+    test_flash_decoding(16, 32, 32, 16, 1, True, 1, use_alibi_slopes=False, use_new_kcache_layout=True)
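The kv split bound in the decoding test above is a plain ceiling division over cache blocks; a quick worked example with a hypothetical longest context (the 100 below is illustrative, not a test value):

```python
# ceil(100 / 32) = 4: the decoding kernel needs at most 4 kv chunks per sequence.
max_kv_len_in_b, block_size = 100, 32
kv_max_split_num = (max_kv_len_in_b + block_size - 1) // block_size
assert kv_max_split_num == 4
```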

View File

@@ -4,7 +4,11 @@ from packaging import version
 
 from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
+from tests.test_infer.test_ops.triton.kernel_utils import (
+    generate_caches_and_block_tables_v2,
+    generate_caches_and_block_tables_v3,
+    mock_alloc_single_token,
+)
 
 try:
     import triton  # noqa
@@ -30,6 +34,7 @@ def prepare_data(
     n=1,
     device="cuda",
     dtype=torch.float16,
+    use_new_kcache_layout=False,
 ):
     assert max_seq_len > n, "max_seq_len must be greater than n"
 
@@ -44,9 +49,14 @@ def prepare_data(
     kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
     k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)
 
-    k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
-        k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
-    )
+    if use_new_kcache_layout:
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
+            k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
+        )
+    else:
+        k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
+            k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
+        )
     block_tables = block_tables.to(device=device)
 
     new_k = torch.randn((bsz, n, num_kv_heads, head_dim), dtype=dtype, device=device)
@@ -66,8 +76,15 @@ def prepare_data(
 @pytest.mark.parametrize("num_kv_heads", [16])
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("n_tokens", [1, 5])
+@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
 def test_copy_kv_to_caches(
-    bsz: int, block_size: int, max_num_blocks_per_seq: int, num_kv_heads: int, same_context_len: bool, n_tokens: int
+    bsz: int,
+    block_size: int,
+    max_num_blocks_per_seq: int,
+    num_kv_heads: int,
+    same_context_len: bool,
+    n_tokens: int,
+    use_new_kcache_layout: bool,
 ):
     torch.manual_seed(123)
     torch.cuda.empty_cache()
@@ -89,6 +106,7 @@ def test_copy_kv_to_caches(
         n_tokens,
         device=device,
         dtype=dtype,
+        use_new_kcache_layout=use_new_kcache_layout,
     )
     k_source = new_k.view(-1, new_k.size(-2), new_k.size(-1))
     v_source = new_v.view(-1, new_v.size(-2), new_v.size(-1))
@@ -98,7 +116,9 @@ def test_copy_kv_to_caches(
 
     offsets_in_block = past_kv_seq_lengths % block_size
     # Copy k (or v) to k (or v) cache
-    copy_k_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens)
+    copy_k_to_blocked_cache(
+        new_k, k_cache, kv_seq_lengths, block_tables, n=n_tokens, use_new_kcache_layout=use_new_kcache_layout
+    )
     # Reshape target k from k cache to compare if matching with original tensor
     # Mainly to handle cases of n_tokens > 1
     k_target = []
@@ -110,26 +130,39 @@ def test_copy_kv_to_caches(
         while tokens_left > 0:
             tokens_to_fill = min(block_size - offset, tokens_left)
             curr_block_id = block_table[curr_kv_len // block_size]
-            k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
+            if use_new_kcache_layout:
+                k_target.append(k_cache[curr_block_id, :, :, offset : offset + tokens_to_fill, :])
+            else:
+                k_target.append(k_cache[curr_block_id, :, offset : offset + tokens_to_fill, :])
             curr_kv_len += tokens_to_fill
             tokens_left -= tokens_to_fill
             offset = 0
 
-    k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous()  # [bsz * n, num_kv_heads, head_dim]
+    if use_new_kcache_layout:
+        k_target = torch.concat(k_target, dim=2).permute(2, 0, 1, 3).contiguous()
+        k_target = k_target.reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
+    else:
+        k_target = torch.concat(k_target, dim=1).transpose(0, 1).contiguous()  # [bsz * n, num_kv_heads, head_dim]
     assert k_target.shape == k_source.shape
     assert torch.equal(k_target, k_source)
 
     if n_tokens == 1:
         # Copy k and v to k/v caches
         k_cache = k_cache_copy
-        copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
-        k_target = k_cache_copy[target_block_ids, :, offsets_in_block, :]
-        v_target = v_cache[target_block_ids, :, offsets_in_block, :]
+        copy_kv_to_blocked_cache(
+            new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables, use_new_kcache_layout=use_new_kcache_layout
+        )
+        if use_new_kcache_layout:
+            k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
+            k_target = k_target.contiguous().reshape(bsz * n_tokens, num_kv_heads, HEAD_DIM)
+        else:
+            k_target = k_cache[target_block_ids, :, offsets_in_block, :]
        assert k_target.shape == k_source.shape
        assert torch.equal(k_target, k_source)
+        v_target = v_cache[target_block_ids, :, offsets_in_block, :]
        assert v_target.shape == v_source.shape
        assert torch.equal(v_target, v_source)
 
 
 if __name__ == "__main__":
-    test_copy_kv_to_caches(4, 32, 8, 16, True)
+    test_copy_kv_to_caches(4, 32, 8, 16, True, n_tokens=1)
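The readback the copy test performs for the new layout (index the block, keep the offset column, then flatten the [head_dim // x, x] pair back to head_dim) can be shown in isolation; a small sketch with standalone tensors, not the repo's cache helpers (names and shapes are illustrative):

```python
import torch

num_blocks, num_kv_heads, block_size, head_dim = 8, 16, 32, 128
dtype = torch.float16
x = 16 // torch.tensor([], dtype=dtype).element_size()

# A k cache already in the new layout: [num_blocks, num_kv_heads, head_dim // x, block_size, x].
k_cache = torch.randn(num_blocks, num_kv_heads, head_dim // x, block_size, x).to(dtype)

block_id, head_id, offset = 3, 5, 7
# One token's key vector at `offset` inside `block_id` for `head_id`.
k_token = k_cache[block_id, head_id, :, offset, :].reshape(head_dim)

# The whole block for that head, back in the familiar [block_size, head_dim] order.
k_block = k_cache[block_id, head_id].permute(1, 0, 2).reshape(block_size, head_dim)
assert torch.equal(k_block[offset], k_token)
```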

View File

@@ -4,7 +4,10 @@ from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
 
 from colossalai.kernel.triton import decoding_fused_rotary_embedding
-from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
+from tests.test_infer.test_ops.triton.kernel_utils import (
+    mock_alloc_block_table_and_kvcache_v2,
+    mock_alloc_block_table_and_kvcache_v3,
+)
 
 try:
     import triton  # noqa
@@ -36,7 +39,8 @@ def torch_rotary_emb(x, cos, sin):
 @pytest.mark.parametrize("H", [32])
 @pytest.mark.parametrize("D", [64])
 @pytest.mark.parametrize("dtype", [torch.float32])
-def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
+@pytest.mark.parametrize("use_new_kcache_layout", [True, False])
+def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
     TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
     # our crafted op equals to Transformers
     x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
@@ -57,28 +61,40 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
     k_shape = (TOTAL_TOKENS, H, D)
     k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
-    cos_shape = (TOTAL_TOKENS, D // 2)
-    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
-    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
     v = torch.randn_like(k)
-    v_cache = torch.zeros_like(k_cache)
-    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
-    block_tables = mock_alloc_block_table_and_kvcache_v2(
-        k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
-    )
     new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
     new_q = torch.randn_like(new_k)
     new_v = torch.randn_like(new_k)
 
+    cos_shape = (TOTAL_TOKENS, D // 2)
+    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+
+    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
+    v_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
+    v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda")
+    if use_new_kcache_layout:
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        kcache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, D // x, block_size, x)
+        k_cache = torch.zeros(size=kcache_shape, dtype=dtype, device="cuda")
+        block_tables = mock_alloc_block_table_and_kvcache_v3(
+            k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+        )
+    else:
+        k_cache = torch.zeros_like(v_cache)
+        block_tables = mock_alloc_block_table_and_kvcache_v2(
+            k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+        )
+
     kv_seq_lengths = past_kv_seq_lengths + 1
     block_tables = block_tables.to(device="cuda")
 
     q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
-    decoding_fused_rotary_embedding(new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths)
+    decoding_fused_rotary_embedding(
+        new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths, use_new_kcache_layout
+    )
     assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)
 
 
 if __name__ == "__main__":
-    test_rotary_emb(4, 64, 32, 64, torch.float32)
+    test_rotary_emb(4, 64, 32, 64, torch.float32, use_new_kcache_layout=True)
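The x used for the new k-cache shape above depends only on the element size, so it varies with dtype; for the float32 case this test runs with, x is 4 and D // x is 16. A quick check:

```python
import torch

# x = number of elements that fit in a 16-byte group.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    print(dtype, x)  # torch.float32 -> 4, torch.float16 -> 8, torch.bfloat16 -> 8
```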