[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)

* [fix] GQA calling of flash decoding triton

* fix kv cache alloc shape

* fix rotary triton - GQA

* fix sequence max length assigning

* Sequence max length logic

* fix scheduling and spec-dec

* skip without import error

* fix pytest - skip without ImportError

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Yuanheng Zhao
2024-04-23 13:09:55 +08:00
committed by GitHub
parent ccf72797e3
commit 5d4c1fe8f5
9 changed files with 183 additions and 194 deletions

View File

@@ -36,97 +36,91 @@ def rotary_embedding_kernel(
cos_stride,
q_total_tokens,
Q_HEAD_NUM: tl.constexpr,
K_HEAD_NUM: tl.constexpr,
KV_GROUP_NUM: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_HEAD: tl.constexpr,
BLOCK_TOKENS: tl.constexpr,
BLOCK_TOKENS: tl.constexpr, # token range length
):
block_head_index = tl.program_id(0)
block_token_index = tl.program_id(1)
tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
cur_head_idx = tl.program_id(0)
cur_token_block_idx = tl.program_id(1)
tokens_range = cur_token_block_idx * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
off_q0 = (
tokens_range[:, None, None] * q_token_stride
+ head_range[None, :, None] * q_head_stride
+ cur_head_idx * q_head_stride
+ dim_range0[None, None, :] * head_dim_stride
)
off_q1 = (
tokens_range[:, None, None] * q_token_stride
+ head_range[None, :, None] * q_head_stride
+ cur_head_idx * q_head_stride
+ dim_range1[None, None, :] * head_dim_stride
)
off_k0 = (
tokens_range[:, None, None] * k_token_stride
+ head_range[None, :, None] * k_head_stride
+ dim_range0[None, None, :] * head_dim_stride
)
off_k1 = (
tokens_range[:, None, None] * k_token_stride
+ head_range[None, :, None] * k_head_stride
+ dim_range1[None, None, :] * head_dim_stride
)
loaded_q0 = tl.load(
q + off_q0,
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
other=0.0,
)
loaded_q1 = tl.load(
q + off_q1,
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
other=0.0,
)
loaded_k0 = tl.load(
k + off_k0,
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
other=0.0,
)
loaded_k1 = tl.load(
k + off_k1,
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
other=0.0,
)
off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :]
out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :]
out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]
# concat
tl.store(
q + off_q0,
out_q0,
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
)
tl.store(
q + off_q1,
out_q1,
mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
)
tl.store(
k + off_k0,
out_k0,
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
)
tl.store(
k + off_k1,
out_k1,
mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
)
handle_k = cur_head_idx % KV_GROUP_NUM == 0
if handle_k:
k_head_idx = cur_head_idx // KV_GROUP_NUM
off_k0 = (
tokens_range[:, None, None] * k_token_stride
+ k_head_idx * k_head_stride
+ dim_range0[None, None, :] * head_dim_stride
)
off_k1 = (
tokens_range[:, None, None] * k_token_stride
+ k_head_idx * k_head_stride
+ dim_range1[None, None, :] * head_dim_stride
)
loaded_k0 = tl.load(
k + off_k0,
mask=(tokens_range[:, None, None] < q_total_tokens),
other=0.0,
)
loaded_k1 = tl.load(
k + off_k1,
mask=(tokens_range[:, None, None] < q_total_tokens),
other=0.0,
)
out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]
tl.store(
k + off_k0,
out_k0,
mask=(tokens_range[:, None, None] < q_total_tokens),
)
tl.store(
k + off_k1,
out_k1,
mask=(tokens_range[:, None, None] < q_total_tokens),
)
@triton.jit
def fused_rotary_embedding_kernel(
@@ -405,108 +399,74 @@ def decoding_fused_rotary_embedding_kernel(
bts_stride,
btb_stride,
block_size,
Q_HEAD_NUM: tl.constexpr,
KV_GROUP_NUM: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
block_head_index = tl.program_id(0)
if block_head_index >= Q_HEAD_NUM:
return
block_token_index = tl.program_id(1)
cur_head_idx = tl.program_id(0)
cur_token_idx = tl.program_id(1)
dim_range = tl.arange(0, HEAD_DIM)
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
total_dim_range = tl.arange(0, HEAD_DIM)
q_off_base = block_token_index * q_token_stride + block_head_index * q_head_stride
off_q0 = q_off_base + dim_range0 * head_dim_stride
off_q1 = q_off_base + dim_range1 * head_dim_stride
off_base = block_token_index * k_token_stride + block_head_index * k_head_stride
off_k0 = off_base + dim_range0 * head_dim_stride
off_k1 = off_base + dim_range1 * head_dim_stride
off_v = off_base + total_dim_range * head_dim_stride
loaded_q0 = tl.load(
q + off_q0,
)
loaded_q1 = tl.load(
q + off_q1,
)
loaded_k0 = tl.load(
k + off_k0,
)
loaded_k1 = tl.load(
k + off_k1,
)
loaded_v = tl.load(
v + off_v,
)
off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride
off_q = cur_token_idx * q_token_stride + cur_head_idx * q_head_stride
off_q0 = off_q + dim_range0 * head_dim_stride
off_q1 = off_q + dim_range1 * head_dim_stride
loaded_q0 = tl.load(q + off_q0)
loaded_q1 = tl.load(q + off_q1)
off_cos_sin = cur_token_idx * cos_token_stride + dim_range0 * cos_stride
loaded_cos = tl.load(cos + off_cos_sin)
loaded_sin = tl.load(sin + off_cos_sin)
out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
tl.store(q + off_q0, out_q0)
tl.store(q + off_q1, out_q1)
out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim
handle_k = cur_head_idx % KV_GROUP_NUM == 0
if handle_k:
cur_k_head_idx = cur_head_idx // KV_GROUP_NUM
off_kv = cur_token_idx * k_token_stride + cur_k_head_idx * k_head_stride
off_k0 = off_kv + dim_range0 * head_dim_stride
off_k1 = off_kv + dim_range1 * head_dim_stride
loaded_k0 = tl.load(k + off_k0)
loaded_k1 = tl.load(k + off_k1)
past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1
out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos
last_block_idx = past_kv_seq_len // block_size
block_ids = tl.load(BLOCK_TABLES + block_token_index * bts_stride + last_block_idx * btb_stride)
offsets_in_last_block = past_kv_seq_len % block_size
# NOTE The precondition here is that it's only for unpadded inputs during decoding stage,
# and so that we could directly use the token index as the sequence index
past_kv_seq_len = tl.load(context_lengths + cur_token_idx) - 1
k_range0 = (
block_ids * cache_b_stride
+ block_head_index * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ dim_range0 * cache_d_stride
)
k_range1 = (
block_ids * cache_b_stride
+ block_head_index * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ dim_range1 * cache_d_stride
)
v_range = (
block_ids * cache_b_stride
+ block_head_index * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ total_dim_range * cache_d_stride
)
last_block_idx = past_kv_seq_len // block_size
block_ids = tl.load(BLOCK_TABLES + cur_token_idx * bts_stride + last_block_idx * btb_stride)
offsets_in_last_block = past_kv_seq_len % block_size
k_range0 = (
block_ids * cache_b_stride
+ cur_k_head_idx * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ dim_range0 * cache_d_stride
)
k_range1 = (
block_ids * cache_b_stride
+ cur_k_head_idx * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ dim_range1 * cache_d_stride
)
tl.store(k_cache + k_range0, out_k0)
tl.store(k_cache + k_range1, out_k1)
tl.store(
v_cache + v_range,
loaded_v,
)
tl.store(
k_cache + k_range0,
out_k0,
)
tl.store(
k_cache + k_range1,
out_k1,
)
# concat
tl.store(
q + off_q0,
out_q0,
)
tl.store(
q + off_q1,
out_q1,
)
off_v = off_kv + dim_range * head_dim_stride
loaded_v = tl.load(v + off_v)
v_range = (
block_ids * cache_b_stride
+ cur_k_head_idx * cache_h_stride
+ offsets_in_last_block * cache_bs_stride
+ dim_range * cache_d_stride
)
tl.store(v_cache + v_range, loaded_v)
def rotary_embedding(
@@ -521,7 +481,7 @@ def rotary_embedding(
"""
Args:
q: query tensor, [total_tokens, head_num, head_dim]
k: key tensor, [total_tokens, head_num, head_dim]
k: key tensor, [total_tokens, kv_head_num, head_dim]
cos: cosine for rotary embedding, [max_position_len, head_dim]
sin: sine for rotary embedding, [max_position_len, head_dim]
k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
@@ -530,32 +490,26 @@ def rotary_embedding(
"""
q_total_tokens, q_head_num, head_dim = q.shape
assert q.size(0) == k.size(0)
BLOCK_HEAD = 4
BLOCK_TOKENS = 4
if head_dim >= 1024:
num_warps = 32
elif head_dim >= 512:
if head_dim >= 512:
num_warps = 16
elif head_dim >= 256:
num_warps = 8
else:
num_warps = 4
q_token_stride = q.stride(0)
q_head_stride = q.stride(1)
head_dim_stride = q.stride(2)
k_head_num = k.size(1)
q_token_stride, q_head_stride, head_dim_stride = q.stride()
k_token_stride, k_head_stride, _ = k.stride()
cos_token_stride, cos_stride = cos.stride()
k_token_stride = k.stride(0)
k_head_stride = k.stride(1)
assert q_head_num % k_head_num == 0
kv_group_num = q_head_num // k_head_num
k_head_num = q.shape[1]
cos_token_stride = cos.stride(0)
cos_stride = cos.stride(1)
if k_cache == None:
grid = lambda META: (
triton.cdiv(q_head_num, META["BLOCK_HEAD"]),
q_head_num,
triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]),
)
rotary_embedding_kernel[grid](
@@ -572,9 +526,8 @@ def rotary_embedding(
cos_stride,
q_total_tokens,
Q_HEAD_NUM=q_head_num,
K_HEAD_NUM=k_head_num,
KV_GROUP_NUM=kv_group_num,
HEAD_DIM=head_dim,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_TOKENS=BLOCK_TOKENS,
num_warps=num_warps,
)
@@ -624,23 +577,21 @@ def decoding_fused_rotary_embedding(
"""
Args:
q: query tensor, [total_tokens, head_num, head_dim]
k: key tensor, [total_tokens, head_num, head_dim]
v: value tensor, [total tokens, head_num, head_dim]
k: key tensor, [total_tokens, kv_head_num, head_dim]
v: value tensor, [total tokens, kv_head_num, head_dim]
cos: cosine for rotary embedding, [max_position_len, head_dim]
sin: sine for rotary embedding, [max_position_len, head_dim]
k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]
v_cache (torch.Tensor): Blocked value cache. [num_blocks, num_kv_heads, block_size, head_dim]
k_cache (torch.Tensor): Blocked key cache. [num_blocks, kv_head_num, block_size, head_dim]
v_cache (torch.Tensor): Blocked value cache. [num_blocks, kv_head_num, block_size, head_dim]
kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz]
block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence]
"""
q_total_tokens, q_head_num, head_dim = q.shape
assert q.size(0) == k.size(0) == v.size(0)
assert q.size(1) == k.size(1) == v.size(1)
assert k.size(1) == v.size(1)
assert k_cache.size(-1) == v_cache.size(-1)
if head_dim >= 1024:
num_warps = 32
elif head_dim >= 512:
if head_dim >= 512:
num_warps = 16
elif head_dim >= 256:
num_warps = 8
@@ -653,10 +604,12 @@ def decoding_fused_rotary_embedding(
k_token_stride = k.stride(0)
k_head_stride = k.stride(1)
k_head_num = k.size(1)
kv_group_num = q_head_num // k_head_num
cos_token_stride = cos.stride(0)
cos_stride = cos.stride(1)
grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
grid = (q_head_num, q_total_tokens)
decoding_fused_rotary_embedding_kernel[grid](
q,
k,
@@ -681,7 +634,7 @@ def decoding_fused_rotary_embedding(
block_tables.stride(0),
block_tables.stride(1),
k_cache.size(-2),
Q_HEAD_NUM=q_head_num,
KV_GROUP_NUM=kv_group_num,
HEAD_DIM=head_dim,
num_warps=num_warps,
)