mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-19 18:00:43 +00:00
[Inference] Kernel Fusion, fused copy kv cache into rotary embedding (#5336)
* revise rotary embedding * remove useless print * adapt
This commit is contained in:
parent
1336838a91
commit
df0aa49585
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
@ -126,12 +128,161 @@ def rotary_embedding_kernel(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def fused_rotary_embedding_kernel(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
kv_cache,
|
||||||
|
BLOCK_TABLES,
|
||||||
|
context_lengths,
|
||||||
|
q_token_stride,
|
||||||
|
q_head_stride,
|
||||||
|
k_token_stride,
|
||||||
|
k_head_stride,
|
||||||
|
head_dim_stride,
|
||||||
|
cos_token_stride,
|
||||||
|
cos_stride,
|
||||||
|
cacheb_stride,
|
||||||
|
cacheh_stride,
|
||||||
|
cachebs_stride,
|
||||||
|
cached_stride,
|
||||||
|
bts_stride,
|
||||||
|
btb_stride,
|
||||||
|
block_size,
|
||||||
|
q_total_tokens,
|
||||||
|
Q_HEAD_NUM: tl.constexpr,
|
||||||
|
K_HEAD_NUM: tl.constexpr,
|
||||||
|
HEAD_DIM: tl.constexpr,
|
||||||
|
BLOCK_HEAD: tl.constexpr,
|
||||||
|
BLOCK_TOKENS: tl.constexpr,
|
||||||
|
):
|
||||||
|
block_head_index = tl.program_id(0)
|
||||||
|
block_token_index = tl.program_id(1)
|
||||||
|
|
||||||
|
tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
|
||||||
|
head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
|
||||||
|
|
||||||
|
dim_range0 = tl.arange(0, HEAD_DIM // 2)
|
||||||
|
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
|
||||||
|
|
||||||
|
off_q0 = (
|
||||||
|
tokens_range[:, None, None] * q_token_stride
|
||||||
|
+ head_range[None, :, None] * q_head_stride
|
||||||
|
+ dim_range0[None, None, :] * head_dim_stride
|
||||||
|
)
|
||||||
|
off_q1 = (
|
||||||
|
tokens_range[:, None, None] * q_token_stride
|
||||||
|
+ head_range[None, :, None] * q_head_stride
|
||||||
|
+ dim_range1[None, None, :] * head_dim_stride
|
||||||
|
)
|
||||||
|
off_k0 = (
|
||||||
|
tokens_range[:, None, None] * k_token_stride
|
||||||
|
+ head_range[None, :, None] * k_head_stride
|
||||||
|
+ dim_range0[None, None, :] * head_dim_stride
|
||||||
|
)
|
||||||
|
off_k1 = (
|
||||||
|
tokens_range[:, None, None] * k_token_stride
|
||||||
|
+ head_range[None, :, None] * k_head_stride
|
||||||
|
+ dim_range1[None, None, :] * head_dim_stride
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_q0 = tl.load(
|
||||||
|
q + off_q0,
|
||||||
|
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
loaded_q1 = tl.load(
|
||||||
|
q + off_q1,
|
||||||
|
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_k0 = tl.load(
|
||||||
|
k + off_k0,
|
||||||
|
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_k1 = tl.load(
|
||||||
|
k + off_k1,
|
||||||
|
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
|
||||||
|
|
||||||
|
loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
|
||||||
|
loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
|
||||||
|
|
||||||
|
out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :]
|
||||||
|
out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :]
|
||||||
|
|
||||||
|
out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
|
||||||
|
out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :] # total_tokens, head_num, head_dim
|
||||||
|
|
||||||
|
past_kv_seq_len = tl.load(context_lengths + tokens_range) - 1
|
||||||
|
|
||||||
|
last_block_idx = past_kv_seq_len // block_size
|
||||||
|
block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride
|
||||||
|
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride)
|
||||||
|
offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
|
||||||
|
|
||||||
|
kv_range0 = (
|
||||||
|
block_ids[:, None, None, None] * cacheb_stride
|
||||||
|
+ head_range[None, :, None, None] * cacheh_stride
|
||||||
|
+ offsets_in_last_block[:, None, None, None]
|
||||||
|
+ dim_range0[None, None, None, :] * cached_stride
|
||||||
|
)
|
||||||
|
kv_range1 = (
|
||||||
|
block_ids[:, None, None, None] * cacheb_stride
|
||||||
|
+ head_range[None, :, None, None] * cacheh_stride
|
||||||
|
+ offsets_in_last_block[:, None, None, None]
|
||||||
|
+ dim_range1[None, None, None, :] * cached_stride
|
||||||
|
)
|
||||||
|
|
||||||
|
tl.store(
|
||||||
|
kv_cache + kv_range0,
|
||||||
|
out_k0[:, :, None, :],
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
kv_cache + kv_range1,
|
||||||
|
out_k1[:, :, None, :],
|
||||||
|
)
|
||||||
|
|
||||||
|
# concat
|
||||||
|
tl.store(
|
||||||
|
q + off_q0,
|
||||||
|
out_q0,
|
||||||
|
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
q + off_q1,
|
||||||
|
out_q1,
|
||||||
|
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
k + off_k0,
|
||||||
|
out_k0,
|
||||||
|
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
k + off_k1,
|
||||||
|
out_k1,
|
||||||
|
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def rotary_embedding(
|
def rotary_embedding(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
cos: torch.Tensor,
|
cos: torch.Tensor,
|
||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
|
k_cache: Optional[torch.Tensor] = None,
|
||||||
|
block_tables: Optional[torch.Tensor] = None,
|
||||||
|
kv_lengths: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -139,7 +290,9 @@ def rotary_embedding(
|
|||||||
k: key tensor, [total_tokens, head_num, head_dim]
|
k: key tensor, [total_tokens, head_num, head_dim]
|
||||||
cos: cosine for rotary embedding, [max_position_len, head_dim]
|
cos: cosine for rotary embedding, [max_position_len, head_dim]
|
||||||
sin: sine for rotary embedding, [max_position_len, head_dim]
|
sin: sine for rotary embedding, [max_position_len, head_dim]
|
||||||
lengths [num_seqs]
|
k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
|
||||||
|
kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz]
|
||||||
|
block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence]
|
||||||
"""
|
"""
|
||||||
q_total_tokens, q_head_num, head_dim = q.shape
|
q_total_tokens, q_head_num, head_dim = q.shape
|
||||||
assert q.size(0) == k.size(0)
|
assert q.size(0) == k.size(0)
|
||||||
@ -165,26 +318,56 @@ def rotary_embedding(
|
|||||||
|
|
||||||
cos_token_stride = cos.stride(0)
|
cos_token_stride = cos.stride(0)
|
||||||
cos_stride = cos.stride(1)
|
cos_stride = cos.stride(1)
|
||||||
|
if k_cache == None:
|
||||||
rotary_embedding_kernel[grid](
|
rotary_embedding_kernel[grid](
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
cos,
|
cos,
|
||||||
sin,
|
sin,
|
||||||
q_token_stride,
|
q_token_stride,
|
||||||
q_head_stride,
|
q_head_stride,
|
||||||
k_token_stride,
|
k_token_stride,
|
||||||
k_head_stride,
|
k_head_stride,
|
||||||
head_dim_stride,
|
head_dim_stride,
|
||||||
cos_token_stride,
|
cos_token_stride,
|
||||||
cos_stride,
|
cos_stride,
|
||||||
q_total_tokens,
|
q_total_tokens,
|
||||||
Q_HEAD_NUM=q_head_num,
|
Q_HEAD_NUM=q_head_num,
|
||||||
K_HEAD_NUM=k_head_num,
|
K_HEAD_NUM=k_head_num,
|
||||||
HEAD_DIM=head_dim,
|
HEAD_DIM=head_dim,
|
||||||
BLOCK_HEAD=BLOCK_HEAD,
|
BLOCK_HEAD=BLOCK_HEAD,
|
||||||
BLOCK_TOKENS=BLOCK_TOKENS,
|
BLOCK_TOKENS=BLOCK_TOKENS,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
fused_rotary_embedding_kernel[grid](
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
k_cache,
|
||||||
|
block_tables,
|
||||||
|
kv_lengths,
|
||||||
|
q_token_stride,
|
||||||
|
q_head_stride,
|
||||||
|
k_token_stride,
|
||||||
|
k_head_stride,
|
||||||
|
head_dim_stride,
|
||||||
|
cos_token_stride,
|
||||||
|
cos_stride,
|
||||||
|
k_cache.stride(0),
|
||||||
|
k_cache.stride(1),
|
||||||
|
k_cache.stride(2),
|
||||||
|
k_cache.stride(3),
|
||||||
|
block_tables.stride(0),
|
||||||
|
block_tables.stride(1),
|
||||||
|
k_cache.size(-2),
|
||||||
|
q_total_tokens,
|
||||||
|
Q_HEAD_NUM=q_head_num,
|
||||||
|
K_HEAD_NUM=k_head_num,
|
||||||
|
HEAD_DIM=head_dim,
|
||||||
|
BLOCK_HEAD=BLOCK_HEAD,
|
||||||
|
BLOCK_TOKENS=BLOCK_TOKENS,
|
||||||
|
num_warps=num_warps,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
@ -4,6 +4,7 @@ from packaging import version
|
|||||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||||
|
|
||||||
from colossalai.kernel.triton import rotary_embedding
|
from colossalai.kernel.triton import rotary_embedding
|
||||||
|
from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import triton # noqa
|
import triton # noqa
|
||||||
@ -47,6 +48,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
|||||||
assert torch.allclose(embd_x0, embd_stimulated_x)
|
assert torch.allclose(embd_x0, embd_stimulated_x)
|
||||||
|
|
||||||
# create data
|
# create data
|
||||||
|
block_size = 32
|
||||||
|
max_num_blocks_per_seq = 4
|
||||||
q_shape = (TOTAL_TOKENS, H, D)
|
q_shape = (TOTAL_TOKENS, H, D)
|
||||||
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
||||||
k_shape = (TOTAL_TOKENS, H, D)
|
k_shape = (TOTAL_TOKENS, H, D)
|
||||||
@ -54,13 +57,35 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
|||||||
cos_shape = (TOTAL_TOKENS, D // 2)
|
cos_shape = (TOTAL_TOKENS, D // 2)
|
||||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, H, block_size, D)
|
||||||
|
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||||
|
v = torch.randn_like(k)
|
||||||
|
v_cache = torch.zeros_like(k_cache)
|
||||||
|
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
||||||
|
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||||
|
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
||||||
|
)
|
||||||
|
new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
|
||||||
|
new_q = torch.randn_like(new_k)
|
||||||
|
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||||
|
block_tables = block_tables.to(device="cuda")
|
||||||
|
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
||||||
|
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
||||||
|
|
||||||
q_ref = torch_rotary_emb(q, cos, sin)
|
rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths)
|
||||||
k_ref = torch_rotary_emb(k, cos, sin)
|
assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)
|
||||||
rotary_embedding(q, k, cos, sin)
|
assert torch.allclose(new_k, k_ref, atol=1e-4, rtol=1e-4)
|
||||||
|
|
||||||
assert torch.allclose(q, q_ref, atol=1e-4, rtol=1e-4)
|
# check one by one
|
||||||
assert torch.allclose(k, k_ref, atol=1e-4, rtol=1e-4)
|
for seq_i in range(BATCH_SIZE):
|
||||||
|
ki = new_k[seq_i]
|
||||||
|
ki = ki.squeeze()
|
||||||
|
past_kv_seq_len = kv_seq_lengths[seq_i] - 1
|
||||||
|
target_block_id = block_tables[seq_i, past_kv_seq_len // block_size]
|
||||||
|
offsets_in_block = past_kv_seq_len % block_size
|
||||||
|
target = k_cache[target_block_id, :, offsets_in_block, :]
|
||||||
|
orig = new_k[seq_i].squeeze(dim=0)
|
||||||
|
assert torch.equal(orig, target)
|
||||||
|
|
||||||
|
|
||||||
BATCH = 16
|
BATCH = 16
|
||||||
|
@ -53,10 +53,10 @@ def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
|
|||||||
assert torch.allclose(cos, cos_ref)
|
assert torch.allclose(cos, cos_ref)
|
||||||
assert torch.allclose(sin, sin_ref)
|
assert torch.allclose(sin, sin_ref)
|
||||||
# decoding
|
# decoding
|
||||||
ncos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
|
ncos_ref, nsin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
|
||||||
cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False)
|
cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False)
|
||||||
assert torch.allclose(cos, ncos_ref)
|
assert torch.allclose(cos, ncos_ref)
|
||||||
assert torch.allclose(sin, sin_ref)
|
assert torch.allclose(sin, nsin_ref)
|
||||||
|
|
||||||
|
|
||||||
configs = [
|
configs = [
|
||||||
|
Loading…
Reference in New Issue
Block a user