Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
fix (#5311)
@@ -2,6 +2,22 @@ import torch
 import triton
 import triton.language as tl

+"""
+# Base autotune if needed
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=4),
+        triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":8,},num_warps=8),
+        triton.Config({'BLOCK_HEAD':8,"BLOCK_TOKENS":8,},num_warps=8),
+        triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=16),
+        triton.Config({'BLOCK_HEAD':4,"BLOCK_TOKENS":4,},num_warps=32),
+        triton.Config({'BLOCK_HEAD':16,"BLOCK_TOKENS":16,},num_warps=4),
+        triton.Config({'BLOCK_HEAD':8,"BLOCK_TOKENS":16,},num_warps=8),
+    ],
+    key=['HEAD_DIM','q_total_tokens','Q_HEAD_NUM']
+)
+"""


 @triton.jit
 def rotary_embedding_kernel(
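A note on the block introduced above: it is committed commented out. If enabled, @triton.autotune would benchmark each triton.Config on first launch and cache the fastest one per value of the `key` arguments, replacing the hand-picked BLOCK_HEAD / BLOCK_TOKENS / num_warps in the wrapper further down. A minimal self-contained sketch of the mechanism, using a toy kernel that is not part of this commit:

import torch
import triton
import triton.language as tl

# Toy illustration only: each triton.Config supplies values for constexpr
# meta-parameters; the `key` argument names trigger re-tuning when their
# runtime values change.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 128}, num_warps=4),
        triton.Config({"BLOCK": 256}, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def _scale_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)

x = torch.randn(10_000, device="cuda")
out = torch.empty_like(x)
# BLOCK is omitted at launch: the autotuner chooses it from `configs`,
# so the grid must be a callable that reads the chosen value from META.
grid = lambda META: (triton.cdiv(x.numel(), META["BLOCK"]),)
_scale_kernel[grid](x, out, x.numel())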
@@ -26,43 +42,53 @@ def rotary_embedding_kernel(
     block_head_index = tl.program_id(0)
     block_token_index = tl.program_id(1)

-    rotary_data = q
-    HEAD_NUM = Q_HEAD_NUM
-    head_stride = q_head_stride
-    token_stride = q_token_stride
-
-    if block_token_index * BLOCK_TOKENS >= q_total_tokens:
-        block_token_index = block_token_index - tl.cdiv(q_total_tokens, BLOCK_TOKENS)
-        rotary_data = k
-        HEAD_NUM = K_HEAD_NUM
-        head_stride = k_head_stride
-        token_stride = k_token_stride
-
     tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
     head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)

     dim_range0 = tl.arange(0, HEAD_DIM // 2)
     dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

-    off_data0 = (
-        tokens_range[:, None, None] * token_stride
-        + head_range[None, :, None] * head_stride
+    off_q0 = (
+        tokens_range[:, None, None] * q_token_stride
+        + head_range[None, :, None] * q_head_stride
         + dim_range0[None, None, :] * head_dim_stride
     )
-    off_data1 = (
-        tokens_range[:, None, None] * token_stride
-        + head_range[None, :, None] * head_stride
+    off_q1 = (
+        tokens_range[:, None, None] * q_token_stride
+        + head_range[None, :, None] * q_head_stride
         + dim_range1[None, None, :] * head_dim_stride
     )
+    off_k0 = (
+        tokens_range[:, None, None] * k_token_stride
+        + head_range[None, :, None] * k_head_stride
+        + dim_range0[None, None, :] * head_dim_stride
+    )
+    off_k1 = (
+        tokens_range[:, None, None] * k_token_stride
+        + head_range[None, :, None] * k_head_stride
+        + dim_range1[None, None, :] * head_dim_stride
+    )

-    loaded_data0 = tl.load(
-        rotary_data + off_data0,
-        mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+    loaded_q0 = tl.load(
+        q + off_q0,
+        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
         other=0.0,
     )
-    loaded_data1 = tl.load(
-        rotary_data + off_data1,
-        mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+    loaded_q1 = tl.load(
+        q + off_q1,
+        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
         other=0.0,
     )

+    loaded_k0 = tl.load(
+        k + off_k0,
+        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        other=0.0,
+    )
+
+    loaded_k1 = tl.load(
+        k + off_k1,
+        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        other=0.0,
+    )
+
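Two things happen in this hunk. First, the old single-code-path trick is dropped: previously the token grid was doubled and any program whose block index ran past q_total_tokens re-pointed rotary_data at k, so one load/store path served both tensors; the rewrite computes q and k offsets side by side in every program instead. Second, each offset tensor is built by broadcasting three 1-D ranges (tokens, heads, half of the head dimension) into one 3-D block of flat element offsets. A host-side PyTorch sketch of that indexing, with made-up sizes, illustrative only:

import torch

# The kernel's  off_q0 = tokens[:, None, None] * q_token_stride
#                      + heads[None, :, None]  * q_head_stride
#                      + dims[None, None, :]   * head_dim_stride
# broadcasts to a (BLOCK_TOKENS, BLOCK_HEAD, HEAD_DIM // 2) block of offsets.
BLOCK_TOKENS, BLOCK_HEAD, HEAD_DIM = 4, 4, 8
q = torch.arange(16 * 8 * HEAD_DIM).reshape(16, 8, HEAD_DIM)  # (tokens, heads, dim)
tokens = torch.arange(BLOCK_TOKENS)
heads = torch.arange(BLOCK_HEAD)
dims = torch.arange(HEAD_DIM // 2)
off0 = (tokens[:, None, None] * q.stride(0)
        + heads[None, :, None] * q.stride(1)
        + dims[None, None, :] * q.stride(2))
# Gathering with these flat offsets matches direct slicing:
assert torch.equal(q.flatten()[off0], q[:BLOCK_TOKENS, :BLOCK_HEAD, :HEAD_DIM // 2])

Inside the kernel the same arithmetic runs on tl.arange ranges, and the mask clips partial blocks at the head and token boundaries.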
@@ -71,19 +97,32 @@ def rotary_embedding_kernel(
     loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
     loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)

-    out0 = loaded_data0 * loaded_cos[:, None, :] - loaded_data1 * loaded_sin[:, None, :]
-    out1 = loaded_data0 * loaded_sin[:, None, :] + loaded_data1 * loaded_cos[:, None, :]
+    out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :]
+    out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :]
+
+    out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
+    out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]

     # concat
     tl.store(
-        rotary_data + off_data0,
-        out0,
-        mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        q + off_q0,
+        out_q0,
+        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
     )
     tl.store(
-        rotary_data + off_data1,
-        out1,
-        mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+        q + off_q1,
+        out_q1,
+        mask=((head_range[None, :, None] < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
     )
+    tl.store(
+        k + off_k0,
+        out_k0,
+        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+    )
+    tl.store(
+        k + off_k1,
+        out_k1,
+        mask=((head_range[None, :, None] < K_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
+    )

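The out_q* / out_k* lines are the standard half-split rotary rotation: each pair (x0, x1), taken from the lower and upper halves of the head dimension, maps to (x0*cos - x1*sin, x0*sin + x1*cos). A plain-PyTorch reference of the same math, useful for checking the kernel's in-place result but not taken from the repository:

import torch

def rotary_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (tokens, heads, head_dim); cos/sin: (tokens, head_dim // 2).
    d = x.shape[-1] // 2
    x0, x1 = x[..., :d], x[..., d:]
    c, s = cos[:, None, :], sin[:, None, :]  # broadcast over the head axis
    out0 = x0 * c - x1 * s   # matches out_q0 / out_k0 above
    out1 = x0 * s + x1 * c   # matches out_q1 / out_k1 above
    return torch.cat([out0, out1], dim=-1)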
@@ -105,11 +144,13 @@ def rotary_embedding(
     q_total_tokens, q_head_num, head_dim = q.shape
     assert q.size(0) == k.size(0)
     BLOCK_HEAD = 4
-    BLOCK_TOKENS = 8
-    grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS))
+    BLOCK_TOKENS = 4
+    grid = lambda META: (triton.cdiv(q_head_num, META["BLOCK_HEAD"]), triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]))

-    if head_dim >= 128:
-        num_warps = 8
+    if head_dim >= 256:
+        num_warps = 32
+    elif head_dim >= 128:
+        num_warps = 16
     else:
         num_warps = 4

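The switch from a grid tuple to a callable matters if the commented-out autotuner is ever enabled: Triton calls a callable grid with the final meta-parameters, so the launch geometry follows whatever BLOCK_HEAD / BLOCK_TOKENS the chosen config supplies. With hand-fixed values the two styles coincide, as this small sketch (illustrative sizes only) shows:

import triton

q_head_num, q_total_tokens = 32, 1024   # example sizes
BLOCK_HEAD, BLOCK_TOKENS = 4, 4
grid_static = (triton.cdiv(q_head_num, BLOCK_HEAD), triton.cdiv(q_total_tokens, BLOCK_TOKENS))
grid_dynamic = lambda META: (
    triton.cdiv(q_head_num, META["BLOCK_HEAD"]),
    triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]),
)
# With explicit BLOCK_HEAD=... / BLOCK_TOKENS=... at launch, META holds those
# values; with @triton.autotune, it holds the winning config's values.
assert grid_static == grid_dynamic({"BLOCK_HEAD": BLOCK_HEAD, "BLOCK_TOKENS": BLOCK_TOKENS})

Note also that the grid no longer doubles the token dimension, since the kernel now handles q and k in the same program instead of dispatching the second half of the grid to k.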
@@ -144,7 +185,6 @@ def rotary_embedding(
         BLOCK_HEAD=BLOCK_HEAD,
         BLOCK_TOKENS=BLOCK_TOKENS,
         num_warps=num_warps,
-        num_stages=1,
     )

     return
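For context, a hedged usage sketch of the updated wrapper; the argument order and the in-place semantics are read off this diff rather than verified against the full file:

import torch

# Sizes are examples only.
q_total_tokens, q_head_num, k_head_num, head_dim = 1024, 32, 32, 128
q = torch.randn(q_total_tokens, q_head_num, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(q_total_tokens, k_head_num, head_dim, device="cuda", dtype=torch.float16)
# One cos/sin value per token position and per rotated pair of dims.
cos = torch.randn(q_total_tokens, head_dim // 2, device="cuda", dtype=torch.float16)
sin = torch.randn(q_total_tokens, head_dim // 2, device="cuda", dtype=torch.float16)

# rotary_embedding is the wrapper defined in this file; q and k are rotated
# in place, so keep copies if the originals are still needed.
rotary_embedding(q, k, cos, sin)

rotary_ref from the earlier sketch can serve as the ground truth when unit-testing this call.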