[Inference]Fused kv copy into rotary calculation (#5383)

* revise rotary embedding

* remove useless print

* adapt

* fix

* add

* fix

* modeling

* fix

* fix

* fix

* fused kv copy

* fused copy

* colossalai/kernel/triton/no_pad_rotary_embedding.py

* del padding llama

* del
This commit is contained in:
Jianghai
2024-02-21 11:31:48 +08:00
committed by GitHub
parent b21aac5bae
commit 730103819d
8 changed files with 391 additions and 498 deletions

View File

@@ -13,7 +13,7 @@ if HAS_TRITON:
from .fused_rotary_embedding import fused_rotary_embedding
from .gptq_triton import gptq_fused_linear_triton
from .kvcache_copy import copy_kv_to_blocked_cache
from .no_pad_rotary_embedding import rotary_embedding
from .no_pad_rotary_embedding import decoding_fused_rotary_embedding, rotary_embedding
from .rms_layernorm import rms_layernorm
from .rotary_cache_copy import get_xine_cache
from .softmax import softmax
@@ -28,4 +28,5 @@ if HAS_TRITON:
"rotary_embedding",
"fused_rotary_embedding",
"get_xine_cache",
"decoding_fused_rotary_embedding",
]

View File

@@ -45,21 +45,21 @@ def _copy_to_kvcache_seqlen1_kernel(
k = tl.load(K + offsets_kv)
v = tl.load(V + offsets_kv)
offsets_kvcache = (
offsets_kcache = (
block_id * stride_cachekb
+ cur_kv_head_idx * stride_cachekh
+ offsets_in_last_block * stride_cachekbs
+ offsets_dmodel * stride_cachekd
)
offsets_kvcache = (
offsets_vcache = (
block_id * stride_cachevb
+ cur_kv_head_idx * stride_cachevh
+ offsets_in_last_block * stride_cachevbs
+ offsets_dmodel * stride_cachevd
)
tl.store(KCache + offsets_kvcache, k)
tl.store(VCache + offsets_kvcache, v)
tl.store(KCache + offsets_kcache, k)
tl.store(VCache + offsets_vcache, v)
return

View File

@@ -222,11 +222,11 @@ def fused_rotary_embedding_kernel(
out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :] # total_tokens, head_num, head_dim
past_kv_seq_len = tl.load(context_lengths + tokens_range) - 1
past_kv_seq_len = tl.load(context_lengths + tokens_range, mask=(tokens_range < q_total_tokens)) - 1
last_block_idx = past_kv_seq_len // block_size
block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride)
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(tokens_range < q_total_tokens))
offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
kv_range0 = (
@@ -274,6 +274,241 @@ def fused_rotary_embedding_kernel(
)
@triton.jit
def fused_rotary_embedding_kernel_v2(
q,
k,
cos,
sin,
kv_cache,
BLOCK_TABLES,
context_lengths,
q_token_stride,
q_head_stride,
k_token_stride,
k_head_stride,
head_dim_stride,
cos_token_stride,
cos_stride,
cacheb_stride,
cacheh_stride,
cachebs_stride,
cached_stride,
bts_stride,
btb_stride,
block_size,
q_total_tokens,
Q_HEAD_NUM: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
block_head_index = tl.program_id(0)
if block_head_index >= Q_HEAD_NUM:
return
block_token_index = tl.program_id(1)
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
off_q0 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range0 * head_dim_stride
off_q1 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range1 * head_dim_stride
off_k0 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range0 * head_dim_stride
off_k1 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range1 * head_dim_stride
loaded_q0 = tl.load(
q + off_q0,
)
loaded_q1 = tl.load(
q + off_q1,
)
loaded_k0 = tl.load(
k + off_k0,
)
loaded_k1 = tl.load(
k + off_k1,
)
off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride
loaded_cos = tl.load(cos + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
loaded_sin = tl.load(sin + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim
past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1
last_block_idx = past_kv_seq_len // block_size
block_table_ptr = BLOCK_TABLES + block_token_index * bts_stride
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(block_token_index < q_total_tokens))
offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
kv_range0 = (
block_ids * cacheb_stride
+ block_head_index * cacheh_stride
+ offsets_in_last_block
+ dim_range0 * cached_stride
)
kv_range1 = (
block_ids * cacheb_stride
+ block_head_index * cacheh_stride
+ offsets_in_last_block
+ dim_range1 * cached_stride
)
tl.store(
kv_cache + kv_range0,
out_k0,
)
tl.store(
kv_cache + kv_range1,
out_k1,
)
# concat
tl.store(
q + off_q0,
out_q0,
)
tl.store(
q + off_q1,
out_q1,
)
@triton.jit
def decoding_fused_rotary_embedding_kernel(
q,
k,
v,
cos,
sin,
k_cache,
v_cache,
BLOCK_TABLES,
context_lengths,
q_token_stride,
q_head_stride,
k_token_stride,
k_head_stride,
head_dim_stride,
cos_token_stride,
cos_stride,
cache_b_stride,
cache_h_stride,
cache_bs_stride,
cache_d_stride,
bts_stride,
btb_stride,
block_size,
Q_HEAD_NUM: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
block_head_index = tl.program_id(0)
if block_head_index >= Q_HEAD_NUM:
return
block_token_index = tl.program_id(1)
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
total_dim_range = tl.arange(0, HEAD_DIM)
q_off_base = block_token_index * q_token_stride + block_head_index * q_head_stride
off_q0 = q_off_base + dim_range0 * head_dim_stride
off_q1 = q_off_base + dim_range1 * head_dim_stride
off_base = block_token_index * k_token_stride + block_head_index * k_head_stride
off_k0 = off_base + dim_range0 * head_dim_stride
off_k1 = off_base + dim_range1 * head_dim_stride
off_v = off_base + total_dim_range * head_dim_stride
loaded_q0 = tl.load(
q + off_q0,
)
loaded_q1 = tl.load(
q + off_q1,
)
loaded_k0 = tl.load(
k + off_k0,
)
loaded_k1 = tl.load(
k + off_k1,
)
loaded_v = tl.load(
v + off_v,
)
off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride
loaded_cos = tl.load(cos + off_cos_sin)
loaded_sin = tl.load(sin + off_cos_sin)
out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim
past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1
last_block_idx = past_kv_seq_len // block_size
block_ids = tl.load(BLOCK_TABLES + block_token_index * bts_stride + last_block_idx * btb_stride)
offsets_in_last_block = past_kv_seq_len % block_size
k_range0 = (
block_ids * cache_b_stride
+ block_head_index * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ dim_range0 * cache_d_stride
)
k_range1 = (
block_ids * cache_b_stride
+ block_head_index * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ dim_range1 * cache_d_stride
)
v_range = (
block_ids * cache_b_stride
+ block_head_index * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ total_dim_range * cache_d_stride
)
tl.store(
v_cache + v_range,
loaded_v,
)
tl.store(
k_cache + k_range0,
out_k0,
)
tl.store(
k_cache + k_range1,
out_k1,
)
# concat
tl.store(
q + off_q0,
out_q0,
)
tl.store(
q + off_q1,
out_q1,
)
def rotary_embedding(
q: torch.Tensor,
k: torch.Tensor,
@@ -297,12 +532,13 @@ def rotary_embedding(
assert q.size(0) == k.size(0)
BLOCK_HEAD = 4
BLOCK_TOKENS = 4
grid = lambda META: (triton.cdiv(q_head_num, META["BLOCK_HEAD"]), triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]))
if head_dim >= 256:
if head_dim >= 1024:
num_warps = 32
elif head_dim >= 128:
elif head_dim >= 512:
num_warps = 16
elif head_dim >= 256:
num_warps = 8
else:
num_warps = 4
@@ -318,6 +554,10 @@ def rotary_embedding(
cos_token_stride = cos.stride(0)
cos_stride = cos.stride(1)
if k_cache == None:
grid = lambda META: (
triton.cdiv(q_head_num, META["BLOCK_HEAD"]),
triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]),
)
rotary_embedding_kernel[grid](
q,
k,
@@ -339,7 +579,8 @@ def rotary_embedding(
num_warps=num_warps,
)
else:
fused_rotary_embedding_kernel[grid](
grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
fused_rotary_embedding_kernel_v2[grid](
q,
k,
cos,
@@ -363,10 +604,85 @@ def rotary_embedding(
k_cache.size(-2),
q_total_tokens,
Q_HEAD_NUM=q_head_num,
K_HEAD_NUM=k_head_num,
HEAD_DIM=head_dim,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_TOKENS=BLOCK_TOKENS,
num_warps=num_warps,
)
return
def decoding_fused_rotary_embedding(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
k_cache: Optional[torch.Tensor] = None,
v_cache: Optional[torch.Tensor] = None,
block_tables: Optional[torch.Tensor] = None,
kv_lengths: Optional[torch.Tensor] = None,
):
"""
Args:
q: query tensor, [total_tokens, head_num, head_dim]
k: key tensor, [total_tokens, head_num, head_dim]
v: value tensor, [total tokens, head_num, head_dim]
cos: cosine for rotary embedding, [max_position_len, head_dim]
sin: sine for rotary embedding, [max_position_len, head_dim]
k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
v_cache (torch.Tensor): Blocked value cache. [num_blocks, num_kv_heads, block_size, head_dim]
kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz]
block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence]
"""
q_total_tokens, q_head_num, head_dim = q.shape
assert q.size(0) == k.size(0) == v.size(0)
assert q.size(1) == k.size(1) == v.size(1)
assert k_cache.size(-1) == v_cache.size(-1)
if head_dim >= 1024:
num_warps = 32
elif head_dim >= 512:
num_warps = 16
elif head_dim >= 256:
num_warps = 8
else:
num_warps = 4
q_token_stride = q.stride(0)
q_head_stride = q.stride(1)
head_dim_stride = q.stride(2)
k_token_stride = k.stride(0)
k_head_stride = k.stride(1)
cos_token_stride = cos.stride(0)
cos_stride = cos.stride(1)
grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
decoding_fused_rotary_embedding_kernel[grid](
q,
k,
v,
cos,
sin,
k_cache,
v_cache,
block_tables,
kv_lengths,
q_token_stride,
q_head_stride,
k_token_stride,
k_head_stride,
head_dim_stride,
cos_token_stride,
cos_stride,
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
k_cache.size(-2),
Q_HEAD_NUM=q_head_num,
HEAD_DIM=head_dim,
num_warps=num_warps,
)
return