[Inference]Add fused rotary kernel and get cos cache kernel (#5302)

* add fused rotary and get cos cache func

* staged

* fix bugs

* fix bugs
This commit is contained in:
Jianghai
2024-01-24 16:20:42 +08:00
committed by GitHub
parent 3da9993b0d
commit c647e00e3c
6 changed files with 477 additions and 5 deletions

View File

@@ -11,11 +11,12 @@ if HAS_TRITON:
from .context_attn_unpad import context_attention_unpadded
from .flash_decoding import flash_decoding_attention
from .flash_decoding_utils import FDIntermTensors
from .rms_layernorm import rms_layernorm
from .fused_rotary_embedding import fused_rotary_embedding
from .gptq_triton import gptq_fused_linear_triton
from .kvcache_copy import copy_kv_to_blocked_cache
from .no_pad_rotary_embedding import rotary_embedding
from .rms_layernorm import rms_layernorm
from .rotary_cache_copy import get_xine_cache
from .softmax import softmax
__all__ = [
@@ -27,4 +28,6 @@ if HAS_TRITON:
"gptq_fused_linear_triton",
"rotary_embedding",
"FDIntermTensors",
"fused_rotary_embedding",
"get_xine_cache",
]

View File

@@ -0,0 +1,182 @@
import torch
import triton
import triton.language as tl
@triton.jit
def fused_rotary_emb(
q,
k,
cos_cache,
sin_cache,
cumsum_lengths,
q_token_stride,
q_head_stride,
k_token_stride,
k_head_stride,
head_dim_stride,
cos_token_stride,
cos_dim_stride,
q_total_tokens,
Q_HEAD_NUM: tl.constexpr,
K_HEAD_NUM: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_HEAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
N_ELEMENTS: tl.constexpr,
):
block_head_index = tl.program_id(0)
block_group_index = tl.program_id(1)
group_token_index = tl.program_id(2)
idx = block_group_index * BLOCK_SIZE + group_token_index
# original seq_idx and pos
cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS))
ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0))
cos = tl.load(
cos_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride
) # [1,HEAD_DIM//2]
sin = tl.load(sin_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride)
cur_head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
off_q0 = (
idx * q_token_stride
+ cur_head_range[None, :, None] * q_head_stride
+ dim_range0[None, None, :] * head_dim_stride
)
off_q1 = (
idx * q_token_stride
+ cur_head_range[None, :, None] * q_head_stride
+ dim_range1[None, None, :] * head_dim_stride
)
off_k0 = (
idx * k_token_stride
+ cur_head_range[None, :, None] * k_head_stride
+ dim_range0[None, None, :] * head_dim_stride
)
off_k1 = (
idx * q_token_stride
+ cur_head_range[None, :, None] * k_head_stride
+ dim_range1[None, None, :] * head_dim_stride
)
q_0 = tl.load(
q + off_q0,
mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
other=0.0,
)
q_1 = tl.load(
q + off_q1,
mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
other=0.0,
)
k_0 = tl.load(
k + off_k0,
mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
other=0.0,
)
k_1 = tl.load(
k + off_k1,
mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
other=0.0,
)
out_q0 = q_0 * cos - q_1 * sin
out_q1 = k_0 * sin + k_1 * cos
out_k0 = q_0 * cos - q_1 * sin
out_k1 = k_0 * sin + k_1 * cos
# concat
tl.store(
q + off_q0,
out_q0,
mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
)
tl.store(
q + off_q1,
out_q1,
mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
)
tl.store(
k + off_k0,
out_k0,
mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
)
tl.store(
k + off_k1,
out_k1,
mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
)
@torch.no_grad()
def fused_rotary_embedding(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
lengths,
):
"""
Args:
q: query tensor, [total_tokens, head_num, head_dim]
k: key tensor, [total_tokens, head_num, head_dim]
cos: cosine for rotary embedding, [max_position_len, head_dim]
sin: sine for rotary embedding, [max_position_len, head_dim]
lengths [num_seqs]
"""
q_total_tokens, q_head_num, head_dim = q.shape
assert q.size(0) == k.size(0)
BLOCK_HEAD = 4
BLOCK_SIZE = 16
cumsum_lens = torch.cumsum(lengths, dim=0)
grid = (triton.cdiv(q_head_num, BLOCK_HEAD), triton.cdiv(q_total_tokens, BLOCK_SIZE), BLOCK_SIZE)
if head_dim >= 128:
num_warps = 8
else:
num_warps = 4
q_token_stride = q.stride(0)
q_head_stride = q.stride(1)
head_dim_stride = q.stride(2)
k_token_stride = k.stride(0)
k_head_stride = k.stride(1)
k_head_num = q.shape[1]
cos_token_stride = cos.stride(0)
cos_dim_stride = cos.stride(1)
fused_rotary_emb[grid](
q,
k,
cos,
sin,
cumsum_lens,
q_token_stride,
q_head_stride,
k_token_stride,
k_head_stride,
head_dim_stride,
cos_token_stride,
cos_dim_stride,
q_total_tokens,
Q_HEAD_NUM=q_head_num,
K_HEAD_NUM=k_head_num,
HEAD_DIM=head_dim,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_SIZE=BLOCK_SIZE,
N_ELEMENTS=triton.next_power_of_2(q_total_tokens),
num_warps=num_warps,
)

View File

@@ -98,11 +98,12 @@ def rotary_embedding(
Args:
q: query tensor, [total_tokens, head_num, head_dim]
k: key tensor, [total_tokens, head_num, head_dim]
cos: cosine for rotary embedding, [total_tokens, head_dim]
sin: sine for rotary embedding, [total_tokens, head_dim]
cos: cosine for rotary embedding, [max_position_len, head_dim]
sin: sine for rotary embedding, [max_position_len, head_dim]
lengths [num_seqs]
"""
q_total_tokens, q_head_num, head_dim = q.shape
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
assert q.size(0) == k.size(0)
BLOCK_HEAD = 4
BLOCK_TOKENS = 8
grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS))

View File

@@ -0,0 +1,110 @@
import torch
import triton
import triton.language as tl
@triton.jit
def prefill_cache_kernel(
CaChe,
cumsum_lengths,
output,
cache_stride,
hidden_stride,
total_length,
HIDDEN_DIM: tl.constexpr,
N_ELEMENTS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
idx0 = tl.program_id(axis=0)
idx1 = tl.program_id(axis=1)
idx = idx0 * BLOCK_SIZE + idx1
# original seq_idx and pos
cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS))
ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0))
_cache = tl.load(CaChe + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride)
tl.store(output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, _cache, mask=idx < total_length)
@triton.jit
def decoding_cache_kernel(
CaChe,
lengths,
output,
cache_stride,
hidden_stride,
HIDDEN_DIM: tl.constexpr,
NUM_SEQS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
ori_seq_idx = tl.load(lengths + idx, mask=(idx < NUM_SEQS), other=None) # [BLOCK_SIZE,]
_cache = tl.load(CaChe + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride)
tl.store(
output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),
_cache,
mask=idx[:, None] < NUM_SEQS,
)
@torch.no_grad()
def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool = False):
"""
Transform cos/sin cache into no pad sequence, with two different modes.
Args:
lengths: shape(num_seqs,), stores lenghth of each sequence.
cache: shape(max_rotary_position(e.g.2048), head_dim), cos/sin cache constrcuted in model.
is_prompts: bool, mark if in prefill mode.
For prefill mode:
cos/sin cache for each sequence is equal to its length.
For decoding mode:
cos/sin cache is only needed for the last token.
"""
_, hidden_dim = cache.shape
num_seqs = lengths.numel()
BLOCK_SIZE = 16
if hidden_dim >= 128:
num_warps = 8
else:
num_warps = 4
cache_stride = cache.stride(0)
hidden_stride = cache.stride(1)
if is_prompts:
total_length = lengths.sum().item()
cumsum_lens = torch.cumsum(lengths, dim=0)
output = torch.empty((total_length, hidden_dim), dtype=cache.dtype, device=cache.device)
grid = (triton.cdiv(total_length, BLOCK_SIZE), BLOCK_SIZE)
prefill_cache_kernel[grid](
cache,
cumsum_lens,
output,
cache_stride,
hidden_stride,
total_length,
HIDDEN_DIM=hidden_dim,
N_ELEMENTS=triton.next_power_of_2(num_seqs),
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
else:
# BUG: get memory access error whe using a deepcopy lengths to replace lengths
nlengths = torch.as_tensor(lengths) - 1
output = torch.empty((num_seqs, hidden_dim), dtype=cache.dtype, device=cache.device)
grid = (triton.cdiv(num_seqs, BLOCK_SIZE),)
decoding_cache_kernel[grid](
cache,
nlengths,
output,
cache_stride,
hidden_stride,
HIDDEN_DIM=hidden_dim,
NUM_SEQS=num_seqs,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return output