mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[Inference] Adapt to Fused rotary (#5348)
* revise rotary embedding * remove useless print * adapt * fix * add * fix * modeling * fix * fix * fix
This commit is contained in:
@@ -222,11 +222,11 @@ def fused_rotary_embedding_kernel(
|
||||
out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
|
||||
out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :] # total_tokens, head_num, head_dim
|
||||
|
||||
past_kv_seq_len = tl.load(context_lengths + tokens_range) - 1
|
||||
past_kv_seq_len = tl.load(context_lengths + tokens_range, mask=(tokens_range < q_total_tokens)) - 1
|
||||
|
||||
last_block_idx = past_kv_seq_len // block_size
|
||||
block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride
|
||||
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride)
|
||||
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(tokens_range < q_total_tokens))
|
||||
offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
|
||||
|
||||
kv_range0 = (
|
||||
@@ -274,6 +274,122 @@ def fused_rotary_embedding_kernel(
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_rotary_embedding_kernel_v2(
|
||||
q,
|
||||
k,
|
||||
cos,
|
||||
sin,
|
||||
kv_cache,
|
||||
BLOCK_TABLES,
|
||||
context_lengths,
|
||||
q_token_stride,
|
||||
q_head_stride,
|
||||
k_token_stride,
|
||||
k_head_stride,
|
||||
head_dim_stride,
|
||||
cos_token_stride,
|
||||
cos_stride,
|
||||
cacheb_stride,
|
||||
cacheh_stride,
|
||||
cachebs_stride,
|
||||
cached_stride,
|
||||
bts_stride,
|
||||
btb_stride,
|
||||
block_size,
|
||||
q_total_tokens,
|
||||
Q_HEAD_NUM: tl.constexpr,
|
||||
K_HEAD_NUM: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
):
|
||||
block_head_index = tl.program_id(0)
|
||||
if block_head_index >= Q_HEAD_NUM:
|
||||
return
|
||||
block_token_index = tl.program_id(1)
|
||||
|
||||
dim_range0 = tl.arange(0, HEAD_DIM // 2)
|
||||
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
|
||||
|
||||
off_q0 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range0 * head_dim_stride
|
||||
off_q1 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range1 * head_dim_stride
|
||||
off_k0 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range0 * head_dim_stride
|
||||
off_k1 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range1 * head_dim_stride
|
||||
|
||||
loaded_q0 = tl.load(
|
||||
q + off_q0,
|
||||
)
|
||||
loaded_q1 = tl.load(
|
||||
q + off_q1,
|
||||
)
|
||||
|
||||
loaded_k0 = tl.load(
|
||||
k + off_k0,
|
||||
)
|
||||
|
||||
loaded_k1 = tl.load(
|
||||
k + off_k1,
|
||||
)
|
||||
|
||||
off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride
|
||||
|
||||
loaded_cos = tl.load(cos + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
|
||||
loaded_sin = tl.load(sin + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
|
||||
|
||||
out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
|
||||
out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
|
||||
|
||||
out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
|
||||
out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim
|
||||
|
||||
past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1
|
||||
|
||||
last_block_idx = past_kv_seq_len // block_size
|
||||
block_table_ptr = BLOCK_TABLES + block_token_index * bts_stride
|
||||
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(block_token_index < q_total_tokens))
|
||||
offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
|
||||
|
||||
kv_range0 = (
|
||||
block_ids * cacheb_stride
|
||||
+ block_head_index * cacheh_stride
|
||||
+ offsets_in_last_block
|
||||
+ dim_range0 * cached_stride
|
||||
)
|
||||
kv_range1 = (
|
||||
block_ids * cacheb_stride
|
||||
+ block_head_index * cacheh_stride
|
||||
+ offsets_in_last_block
|
||||
+ dim_range1 * cached_stride
|
||||
)
|
||||
|
||||
tl.store(
|
||||
kv_cache + kv_range0,
|
||||
out_k0,
|
||||
)
|
||||
tl.store(
|
||||
kv_cache + kv_range1,
|
||||
out_k1,
|
||||
)
|
||||
|
||||
# concat
|
||||
tl.store(
|
||||
q + off_q0,
|
||||
out_q0,
|
||||
)
|
||||
tl.store(
|
||||
q + off_q1,
|
||||
out_q1,
|
||||
)
|
||||
tl.store(
|
||||
k + off_k0,
|
||||
out_k0,
|
||||
)
|
||||
tl.store(
|
||||
k + off_k1,
|
||||
out_k1,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def rotary_embedding(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
@@ -297,12 +413,13 @@ def rotary_embedding(
|
||||
assert q.size(0) == k.size(0)
|
||||
BLOCK_HEAD = 4
|
||||
BLOCK_TOKENS = 4
|
||||
grid = lambda META: (triton.cdiv(q_head_num, META["BLOCK_HEAD"]), triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]))
|
||||
|
||||
if head_dim >= 256:
|
||||
if head_dim >= 1024:
|
||||
num_warps = 32
|
||||
elif head_dim >= 128:
|
||||
elif head_dim >= 512:
|
||||
num_warps = 16
|
||||
elif head_dim >= 256:
|
||||
num_warps = 8
|
||||
else:
|
||||
num_warps = 4
|
||||
|
||||
@@ -318,6 +435,10 @@ def rotary_embedding(
|
||||
cos_token_stride = cos.stride(0)
|
||||
cos_stride = cos.stride(1)
|
||||
if k_cache == None:
|
||||
grid = lambda META: (
|
||||
triton.cdiv(q_head_num, META["BLOCK_HEAD"]),
|
||||
triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]),
|
||||
)
|
||||
rotary_embedding_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
@@ -339,7 +460,8 @@ def rotary_embedding(
|
||||
num_warps=num_warps,
|
||||
)
|
||||
else:
|
||||
fused_rotary_embedding_kernel[grid](
|
||||
grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
|
||||
fused_rotary_embedding_kernel_v2[grid](
|
||||
q,
|
||||
k,
|
||||
cos,
|
||||
@@ -365,8 +487,6 @@ def rotary_embedding(
|
||||
Q_HEAD_NUM=q_head_num,
|
||||
K_HEAD_NUM=k_head_num,
|
||||
HEAD_DIM=head_dim,
|
||||
BLOCK_HEAD=BLOCK_HEAD,
|
||||
BLOCK_TOKENS=BLOCK_TOKENS,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return
|
||||
|
Reference in New Issue
Block a user