Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,5 +1,6 @@
 try:
     import triton
+
     HAS_TRITON = True

     from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
@@ -11,8 +12,14 @@ try:
     from .token_attention_kernel import token_attention_fwd

     __all__ = [
-        "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
-        "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
+        "llama_context_attn_fwd",
+        "bloom_context_attn_fwd",
+        "softmax",
+        "layer_norm",
+        "rmsnorm_forward",
+        "copy_kv_cache_to_dest",
+        "rotary_embedding_fwd",
+        "token_attention_fwd",
     ]

 except ImportError:
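The exports above only exist when Triton imports successfully, so callers are expected to check `HAS_TRITON` first. A minimal, hedged sketch of that pattern (the `colossalai.kernel.triton` package path is an assumption, not shown in this diff):

```python
# Hedged sketch, not part of the commit: guard on HAS_TRITON before using the
# exported kernels; the import path below is assumed.
from colossalai.kernel.triton import HAS_TRITON

if HAS_TRITON:
    from colossalai.kernel.triton import softmax, token_attention_fwd
else:
    print("Triton is not available; fall back to a PyTorch implementation.")
```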
@@ -1,8 +1,11 @@
-import torch
 import math
+
+import torch
+
 try:
     import triton
     import triton.language as tl
+
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
@@ -10,28 +13,42 @@ except ImportError:


 if HAS_TRITON:
-    '''
-    this function is modified from
-    https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
-    '''
+    """
+    this function is modified from
+    https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
+    """

     @triton.jit
     def _context_flash_attention_kernel(
-        Q, K, V, sm_scale,
-        B_Start_Loc, B_Seqlen,
-        TMP,
+        Q,
+        K,
+        V,
+        sm_scale,
+        B_Start_Loc,
+        B_Seqlen,
+        TMP,
         alibi_ptr,
         Out,
-        stride_qbs, stride_qh, stride_qd,
-        stride_kbs, stride_kh, stride_kd,
-        stride_vbs, stride_vh, stride_vd,
-        stride_obs, stride_oh, stride_od,
-        stride_tmp_b, stride_tmp_h, stride_tmp_s,
+        stride_qbs,
+        stride_qh,
+        stride_qd,
+        stride_kbs,
+        stride_kh,
+        stride_kd,
+        stride_vbs,
+        stride_vh,
+        stride_vd,
+        stride_obs,
+        stride_oh,
+        stride_od,
+        stride_tmp_b,
+        stride_tmp_h,
+        stride_tmp_s,
         # suggtest set-up 64, 128, 256, 512
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
     ):
         batch_id = tl.program_id(0)
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)
@@ -40,13 +57,18 @@ if HAS_TRITON:
         offs_n = tl.arange(0, BLOCK_N)
         offs_d = tl.arange(0, BLOCK_DMODEL)
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

-        # get batch info
+        # get batch info
         cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
         cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
         block_start_loc = BLOCK_M * start_m

-        load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
+        load_p_ptrs = (
+            Q
+            + (cur_batch_start_index + offs_m[:, None]) * stride_qbs
+            + cur_head * stride_qh
+            + offs_d[None, :] * stride_qd
+        )
         q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)

         k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
@@ -56,7 +78,7 @@ if HAS_TRITON:
         m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
         l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
         acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
-
+
         if alibi_ptr is not None:
             alibi_m = tl.load(alibi_ptr + cur_head)

@@ -64,8 +86,11 @@ if HAS_TRITON:

         for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
-            k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
-                        mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
+            k = tl.load(
+                k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
+                mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
+                other=0.0,
+            )

             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
@@ -95,21 +120,25 @@ if HAS_TRITON:
             acc_scale = tl.load(t_ptrs)
             acc = acc * acc_scale[:, None]
             # update acc
-            v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
-                        mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
+            v = tl.load(
+                v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
+                mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
+                other=0.0,
+            )

             p = p.to(v.dtype)
             acc += tl.dot(p, v)
             # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new

-        off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
+        off_o = (
+            (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
+        )
         out_ptrs = Out + off_o
         tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
         return

-
     @torch.no_grad()
     def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
         BLOCK = 128
@@ -129,17 +158,31 @@ if HAS_TRITON:
         tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)

         _context_flash_attention_kernel[grid](
-            q, k, v, sm_scale,
-            b_start_loc, b_seq_len,
+            q,
+            k,
+            v,
+            sm_scale,
+            b_start_loc,
+            b_seq_len,
             tmp,
             alibi,
             o,
-            q.stride(0), q.stride(1), q.stride(2),
-            k.stride(0), k.stride(1), k.stride(2),
-            v.stride(0), v.stride(1), v.stride(2),
-            o.stride(0), o.stride(1), o.stride(2),
-            tmp.stride(0), tmp.stride(1), tmp.stride(2),
-            # manually setting this blcok num, we can use tuning config to futher speed-up
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            o.stride(0),
+            o.stride(1),
+            o.stride(2),
+            tmp.stride(0),
+            tmp.stride(1),
+            tmp.stride(2),
+            # manually setting this blcok num, we can use tuning config to futher speed-up
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_N=BLOCK,
@@ -147,7 +190,7 @@ if HAS_TRITON:
             num_stages=1,
         )
         return

     @torch.no_grad()
     def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         BLOCK = 128
@@ -166,19 +209,34 @@ if HAS_TRITON:
         num_warps = 4 if Lk <= 64 else 8
         # num_warps = 4
         _context_flash_attention_kernel[grid](
-            q, k, v, sm_scale, b_start_loc, b_seq_len,
+            q,
+            k,
+            v,
+            sm_scale,
+            b_start_loc,
+            b_seq_len,
             tmp,
             None,
             o,
-            q.stride(0), q.stride(1), q.stride(2),
-            k.stride(0), k.stride(1), k.stride(2),
-            v.stride(0), v.stride(1), v.stride(2),
-            o.stride(0), o.stride(1), o.stride(2),
-            tmp.stride(0), tmp.stride(1), tmp.stride(2),
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            o.stride(0),
+            o.stride(1),
+            o.stride(2),
+            tmp.stride(0),
+            tmp.stride(1),
+            tmp.stride(2),
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_N=BLOCK,
             num_warps=num_warps,
             num_stages=1,
         )
         return
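For orientation, a hedged sketch of how `llama_context_attn_fwd` might be called. The packed `(total_tokens, num_heads, head_dim)` layout is inferred from the stride arguments above; the concrete sizes, dtype, and import path are assumptions.

```python
# Hedged sketch, not part of the commit; requires triton and a CUDA device.
import torch
from colossalai.kernel.triton import llama_context_attn_fwd  # package path assumed

seq_lens = [13, 7]                                # two variable-length prompts
total_tokens, num_heads, head_dim = sum(seq_lens), 8, 64

q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
k, v, o = torch.randn_like(q), torch.randn_like(q), torch.empty_like(q)

b_seq_len = torch.tensor(seq_lens, device="cuda", dtype=torch.int32)
b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda", dtype=torch.int32)

llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max(seq_lens))
```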
@@ -3,25 +3,28 @@ import torch
 try:
     import triton
     import triton.language as tl
+
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
     print("please install triton from https://github.com/openai/triton")

 if HAS_TRITON:
+
     @triton.jit
     def _fwd_copy_kv_cache_dest(
-        kv_cache_ptr, dest_index_ptr,
+        kv_cache_ptr,
+        dest_index_ptr,
         out,
-        stride_k_bs,
-        stride_k_h,
+        stride_k_bs,
+        stride_k_h,
         stride_k_d,
-        stride_o_bs,
-        stride_o_h,
+        stride_o_bs,
+        stride_o_h,
         stride_o_d,
         head_num,
         BLOCK_DMODEL: tl.constexpr,
-        BLOCK_HEAD: tl.constexpr
+        BLOCK_HEAD: tl.constexpr,
     ):
         cur_index = tl.program_id(0)
         offs_h = tl.arange(0, BLOCK_HEAD)
@@ -31,15 +34,14 @@ if HAS_TRITON:

         cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
         k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets
-

         o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
         o_ptrs = out + dest_index * stride_o_bs + o_offsets

         k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
         tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
         return

     @torch.no_grad()
     def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
         seq_len = dest_index_ptr.shape[0]
@@ -47,16 +49,18 @@ if HAS_TRITON:
         head_dim = k_ptr.shape[2]
         assert head_num == out.shape[1], "head_num should be the same for k_ptr and out"
         assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"

         num_warps = 2

         _fwd_copy_kv_cache_dest[(seq_len,)](
-            k_ptr, dest_index_ptr, out,
-            k_ptr.stride(0),
-            k_ptr.stride(1),
+            k_ptr,
+            dest_index_ptr,
+            out,
+            k_ptr.stride(0),
+            k_ptr.stride(1),
             k_ptr.stride(2),
-            out.stride(0),
-            out.stride(1),
+            out.stride(0),
+            out.stride(1),
             out.stride(2),
             head_num,
             BLOCK_DMODEL=head_dim,
@@ -65,5 +69,3 @@ if HAS_TRITON:
             num_stages=2,
         )
         return
-
-
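A hedged usage sketch for `copy_kv_cache_to_dest`: one Triton program per new token copies `k_ptr[i]` into `out[dest_index_ptr[i]]`. Shapes follow the assertions in the wrapper above; the sizes and import path are assumptions.

```python
# Hedged sketch, not part of the commit; requires triton and a CUDA device.
import torch
from colossalai.kernel.triton import copy_kv_cache_to_dest  # package path assumed

num_new_tokens, head_num, head_dim, cache_slots = 4, 8, 64, 1024
k_new = torch.randn(num_new_tokens, head_num, head_dim, device="cuda", dtype=torch.float16)
kv_cache = torch.zeros(cache_slots, head_num, head_dim, device="cuda", dtype=torch.float16)
dest_index = torch.tensor([10, 11, 12, 13], device="cuda")

copy_kv_cache_to_dest(k_new, dest_index, kv_cache)  # kv_cache[10:14] now holds k_new
```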
@@ -3,6 +3,7 @@ import torch
 try:
     import triton
     import triton.language as tl
+
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
@@ -14,13 +15,13 @@ if HAS_TRITON:

     @triton.jit
     def _layer_norm_fwd_fused(
-        X, # pointer to the input
-        Y, # pointer to the output
-        W, # pointer to the weights
-        B, # pointer to the biases
-        stride, # how much to increase the pointer when moving by 1 row
-        N, # number of columns in X
-        eps, # epsilon to avoid division by zero
+        X,  # pointer to the input
+        Y,  # pointer to the output
+        W,  # pointer to the weights
+        B,  # pointer to the biases
+        stride,  # how much to increase the pointer when moving by 1 row
+        N,  # number of columns in X
+        eps,  # epsilon to avoid division by zero
         BLOCK_SIZE: tl.constexpr,
     ):
         # Map the program id to the row of X and Y it should compute.
@@ -32,15 +33,15 @@ if HAS_TRITON:
         _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
         for off in range(0, N, BLOCK_SIZE):
             cols = off + tl.arange(0, BLOCK_SIZE)
-            a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+            a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
             _mean += a
         mean = tl.sum(_mean, axis=0) / N
         # Compute variance
         _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
         for off in range(0, N, BLOCK_SIZE):
             cols = off + tl.arange(0, BLOCK_SIZE)
-            x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
-            x = tl.where(cols < N, x - mean, 0.)
+            x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+            x = tl.where(cols < N, x - mean, 0.0)
             _var += x * x
         var = tl.sum(_var, axis=0) / N
         rstd = 1 / tl.sqrt(var + eps)
@@ -50,7 +51,7 @@ if HAS_TRITON:
             mask = cols < N
             w = tl.load(W + cols, mask=mask)
             b = tl.load(B + cols, mask=mask)
-            x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
+            x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
             x_hat = (x - mean) * rstd
             y = x_hat * w + b
             # Write output
@@ -71,13 +72,7 @@ if HAS_TRITON:
         # heuristics for number of warps
         num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
         # enqueue kernel
-        _layer_norm_fwd_fused[(M,)](x_arg,
-                                    y,
-                                    weight,
-                                    bias,
-                                    x_arg.stride(0),
-                                    N,
-                                    eps,
-                                    BLOCK_SIZE=BLOCK_SIZE,
-                                    num_warps=num_warps)
+        _layer_norm_fwd_fused[(M,)](
+            x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+        )
         return y
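As a correctness reference for the fused kernel above (per-row mean and variance computed in fp32, then `x_hat * w + b`), a plain PyTorch sketch; it is not part of the commit.

```python
import torch

def layer_norm_reference(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float) -> torch.Tensor:
    # Mirrors the kernel math: mean = sum(x)/N, var = sum((x - mean)^2)/N, rstd = 1/sqrt(var + eps).
    x32 = x.float()
    mean = x32.mean(dim=-1, keepdim=True)
    var = x32.var(dim=-1, unbiased=False, keepdim=True)
    return ((x32 - mean) * torch.rsqrt(var + eps) * weight + bias).to(x.dtype)
```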
@@ -1,7 +1,7 @@
 import torch
 try:
     import triton
     import triton.language as tl
+
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
@@ -9,9 +9,10 @@ except ImportError:


 if HAS_TRITON:
-    '''
+    """
     this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
-    '''
+    """
+
     @triton.jit
     def qkv_gemm_4d_kernel(
         a_ptr,
@@ -34,12 +35,12 @@ if HAS_TRITON:
         stride_cn,
         scale,
         # Meta-parameters
-        BLOCK_SIZE_M : tl.constexpr = 64,
-        BLOCK_SIZE_N : tl.constexpr = 32,
-        BLOCK_SIZE_K : tl.constexpr = 32,
-        GROUP_SIZE_M : tl.constexpr = 8,
+        BLOCK_SIZE_M: tl.constexpr = 64,
+        BLOCK_SIZE_N: tl.constexpr = 32,
+        BLOCK_SIZE_K: tl.constexpr = 32,
+        GROUP_SIZE_M: tl.constexpr = 8,
     ):
-        r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer,
+        r"""A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer,
         where score_matrix is softmax(Q*V^T/sqrt(hidden_size))
         Args:
             a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K)
@@ -53,21 +54,21 @@ if HAS_TRITON:
             stride_bh(tl.constexpr): stride for h-dimention for tensor array B
             stride_bk(tl.constexpr): stride for k-dimention for tensor array B
             stride_bn(tl.constexpr): stride for n-dimention for tensor array B
-            stride_cb(tl.constexpr): stride for bs-dimention for tensor array output
+            stride_cb(tl.constexpr): stride for bs-dimention for tensor array output
             stride_ch(tl.constexpr): stride for h-dimention for tensor array output
             stride_cm(tl.constexpr): stride for m-dimention for tensor array output
             stride_cn(tl.constexpr): stride for n-dimention for tensor array output
             BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a
             BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b
             BLOCK_SIZE_K : tiling size for K-dimension of a and b
-            GROUP_SIZE_M : group size for reducing cache miss, more details:
+            GROUP_SIZE_M : group size for reducing cache miss, more details:
         """

         num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
         num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-        batch = tl.program_id(axis = 0)
-        head = tl.program_id(axis = 1)
-        pid = tl.program_id(axis = 2)
+        batch = tl.program_id(axis=0)
+        head = tl.program_id(axis=1)
+        pid = tl.program_id(axis=2)

         # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
         num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -77,33 +78,38 @@ if HAS_TRITON:
         pid_m = first_pid_m + (pid % group_size_m)
         pid_n = (pid % num_pid_in_group) // group_size_m

         offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
         offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
         offs_k = tl.arange(0, BLOCK_SIZE_K)
-        a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah +
-                  (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak))
-        b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh +
-                  (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))
+        a_ptrs = (
+            a_ptr + batch * stride_ab + head * stride_ah + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+        )
+        b_ptrs = (
+            b_ptr + batch * stride_bb + head * stride_bh + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+        )

         accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
         for k in range(0, K, BLOCK_SIZE_K):
             a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)
             b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)
-            a = tl.load(a_ptrs, mask=a_mask, other=0.)
-            b = tl.load(b_ptrs, mask=b_mask, other=0.)
+            a = tl.load(a_ptrs, mask=a_mask, other=0.0)
+            b = tl.load(b_ptrs, mask=b_mask, other=0.0)
             accumulator += tl.dot(a, b)
             a_ptrs += BLOCK_SIZE_K * stride_ak
             b_ptrs += BLOCK_SIZE_K * stride_bk

         accumulator = accumulator.to(c_ptr.dtype.element_ty)
         if scale > 0:
             accumulator = accumulator * scale.to(c_ptr.dtype.element_ty)

         offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
         offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-        c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
-                  stride_cn * offs_accumu_n[None, :])
+        c_ptrs = (
+            c_ptr
+            + batch * stride_cb
+            + head * stride_ch
+            + stride_cm * offs_accumu_m[:, None]
+            + stride_cn * offs_accumu_n[None, :]
+        )
         accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N)
         tl.store(c_ptrs, accumulator, mask=accumulator_mask)
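The docstring above describes a batched 4D matmul (Q·Kᵀ or softmax-weights·V) with a scale applied only when `scale > 0`. A hedged PyTorch reference for the `(bs, h, M, K) x (bs, h, K, N)` case, useful for checking the kernel; not part of the commit.

```python
import torch

def qkv_gemm_4d_reference(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
    # a: (bs, h, M, K), b: (bs, h, K, N) -> c: (bs, h, M, N), scaled only when scale > 0
    c = torch.matmul(a, b)
    return c * scale if scale > 0 else c
```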
@@ -3,17 +3,19 @@ import torch
 try:
     import triton
     import triton.language as tl
+
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
     print("please install triton from https://github.com/openai/triton")


 if HAS_TRITON:
-    '''
-    this kernel function is modified from
-    https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
-    '''
+    """
+    this kernel function is modified from
+    https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
+    """
+
     @triton.jit
     def _rms_norm_fwd_fused(
         X,  # pointer to the input
@@ -32,7 +34,7 @@ if HAS_TRITON:
         _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
         for off in range(0, N, BLOCK_SIZE):
             cols = off + tl.arange(0, BLOCK_SIZE)
-            x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+            x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
             _var += x * x
         var = tl.sum(_var, axis=0) / N
         rstd = 1 / tl.sqrt(var + eps)
@@ -41,13 +43,12 @@ if HAS_TRITON:
             cols = off + tl.arange(0, BLOCK_SIZE)
             mask = cols < N
             w = tl.load(W + cols, mask=mask).to(tl.float32)
-            x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
+            x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
             x_hat = x * rstd
             y = x_hat * w
             # Write output
             tl.store(Y + cols, y.to(tl.float16), mask=mask)

-
     def rmsnorm_forward(x, weight, eps):
         # allocate output
         y = torch.empty_like(x)
@@ -66,7 +67,5 @@ if HAS_TRITON:
         BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2
         num_warps = 8
         # enqueue kernel
-        _rms_norm_fwd_fused[(M,)](x_arg, y, weight,
-                                  x_arg.stride(0), N, eps,
-                                  BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
+        _rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
         return y
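A plain PyTorch reference for `rmsnorm_forward` above: each row is scaled by `1/sqrt(mean(x^2) + eps)` in fp32, multiplied by the weight, and written back as fp16 (mirroring `tl.store(..., y.to(tl.float16))`). A sketch, not part of the commit.

```python
import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    x32 = x.float()
    rstd = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * rstd * weight.float()).to(torch.float16)
```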
@@ -29,19 +29,29 @@ def _rotary_kernel(
     dim_range0 = tl.arange(0, HEAD_DIM // 2)
     dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

-    off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
-        None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride
-    off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
-        None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride
+    off_q0 = (
+        current_seq_range[:, None, None] * q_bs_stride
+        + current_head_range[None, :, None] * q_h_stride
+        + dim_range0[None, None, :] * q_d_stride
+    )
+    off_q1 = (
+        current_seq_range[:, None, None] * q_bs_stride
+        + current_head_range[None, :, None] * q_h_stride
+        + dim_range1[None, None, :] * q_d_stride
+    )

     off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride

-    q0 = tl.load(q + off_q0,
-                 mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
-                 other=0.0)
-    q1 = tl.load(q + off_q1,
-                 mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
-                 other=0.0)
+    q0 = tl.load(
+        q + off_q0,
+        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+        other=0.0,
+    )
+    q1 = tl.load(
+        q + off_q1,
+        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+        other=0.0,
+    )

     cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
     sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
@@ -49,12 +59,16 @@ def _rotary_kernel(
     out0 = q0 * cos - q1 * sin
     out1 = q0 * sin + q1 * cos

-    tl.store(q + off_q0,
-             out0,
-             mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
-    tl.store(q + off_q1,
-             out1,
-             mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
+    tl.store(
+        q + off_q0,
+        out0,
+        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+    )
+    tl.store(
+        q + off_q1,
+        out1,
+        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+    )

     return

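The kernel above splits the head dimension into two halves (`dim_range0` / `dim_range1`) and rotates them in place with `out0 = q0*cos - q1*sin`, `out1 = q0*sin + q1*cos`. A hedged PyTorch reference of that rotation (the layout and broadcasting shown here are assumptions):

```python
import torch

def rotary_reference(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q: (tokens, heads, head_dim); cos/sin broadcast against (tokens, 1, head_dim // 2)
    half = q.shape[-1] // 2
    q0, q1 = q[..., :half], q[..., half:]
    return torch.cat([q0 * cos - q1 * sin, q0 * sin + q1 * cos], dim=-1)
```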
@@ -1,9 +1,8 @@
 import torch
-from torch import nn
-
 try:
     import triton
     import triton.language as tl
+
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
@@ -13,9 +12,10 @@ if HAS_TRITON:
     from .qkv_matmul_kernel import qkv_gemm_4d_kernel
     from .softmax import softmax_kernel

-    def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                                              input_mask: torch.Tensor, scale: float):
-        r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
+    def self_attention_forward_without_fusion(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float
+    ):
+        r"""A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
         Args:
             q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
             k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
@@ -65,7 +65,7 @@ if HAS_TRITON:
             score_output.stride(2),
             score_output.stride(3),
             scale=scale,
-            # currently manually setting, later on we can use auto-tune config to match best setting
+            # currently manually setting, later on we can use auto-tune config to match best setting
             BLOCK_SIZE_M=64,
             BLOCK_SIZE_N=32,
             BLOCK_SIZE_K=32,
@@ -79,7 +79,6 @@ if HAS_TRITON:
         n_rows, n_cols = score_output.shape

         if n_rows <= 350000:
-
             block_size = max(triton.next_power_of_2(n_cols), 2)
             num_warps = 4
             if block_size >= 4096:
@@ -142,15 +141,9 @@ if HAS_TRITON:
         )
         return output.view(batches, -1, d_model)

-    def self_attention_compute_using_triton(qkv,
-                                            input_mask,
-                                            layer_past,
-                                            alibi,
-                                            scale,
-                                            head_size,
-                                            triangular=False,
-                                            use_flash=False):
-
+    def self_attention_compute_using_triton(
+        qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False
+    ):
         assert qkv.is_contiguous()
         assert alibi is None, "current triton self-attention does not support alibi"
         batches = qkv.shape[0]
@@ -158,8 +151,8 @@ if HAS_TRITON:
         num_of_heads = d_model // head_size

         q = qkv[:, :, :d_model]
-        k = qkv[:, :, d_model:d_model * 2]
-        v = qkv[:, :, d_model * 2:]
+        k = qkv[:, :, d_model : d_model * 2]
+        v = qkv[:, :, d_model * 2 :]
         q = q.view(batches, -1, num_of_heads, head_size)
         k = k.view(batches, -1, num_of_heads, head_size)
         v = v.view(batches, -1, num_of_heads, head_size)
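`self_attention_compute_using_triton` above expects a contiguous packed tensor whose last dimension holds `[q | k | v]`, each of width `d_model`, and `alibi` must be `None`. A minimal sketch of that slicing in plain PyTorch (the concrete sizes are assumptions):

```python
import torch

batch, seq_len, head_size, num_heads = 2, 128, 64, 12
d_model = head_size * num_heads
qkv = torch.randn(batch, seq_len, 3 * d_model)

q = qkv[:, :, :d_model].view(batch, -1, num_heads, head_size)
k = qkv[:, :, d_model : d_model * 2].view(batch, -1, num_heads, head_size)
v = qkv[:, :, d_model * 2 :].view(batch, -1, num_heads, head_size)
```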
@@ -1,39 +1,42 @@
 import torch

 try:
     import triton
     import triton.language as tl
+
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
     print("please install triton from https://github.com/openai/triton")

 if HAS_TRITON:
-    '''
-    softmax kernel is modified based on
+    """
+    softmax kernel is modified based on
     https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
-    '''
+    """
+
     @triton.jit
     def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
-        r""" the kernel function for implementing softmax operator
+        r"""the kernel function for implementing softmax operator
         Args:
             output_ptr: the output after finishing softmax operation, (N, hidden_dim)
             input_ptr: the tensor of input, shape should be (N, hidden_dim)
             n_cols(tl.constexpr): the number of cols of input
-            BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
+            BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
         """
         row_idx = tl.program_id(0)
         row_start_ptr = input_ptr + row_idx * row_stride
         col_offsets = tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
-        row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
+        row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf")).to(tl.float32)
         row_minus_max = row - tl.max(row, axis=0)

         if mask_ptr is not None:
-            # load mask into SRAM
+            # load mask into SRAM
             mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets
             mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)

-            # update
+            # update
            row_minus_max = row_minus_max + mask

         numerator = tl.exp(row_minus_max)
@@ -43,17 +46,16 @@ if HAS_TRITON:
         output_ptrs = output_row_start_ptr + col_offsets
         # Write back output to DRAM
         tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)

-
     def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
         if mask is not None:
             assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask"
-        assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention"
+        assert dim == -1 or dim == len(input.shape) - 1, "currently softmax layer only support last dimention"

         hidden_dim = input.shape[-1]
         output = torch.empty_like(input)
         input = input.view(-1, hidden_dim)
-        if mask is not None:
+        if mask is not None:
             mask = mask.view(-1, hidden_dim)
             assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same"

@@ -67,30 +69,31 @@ if HAS_TRITON:
         else:
             num_warps = 4

-        if num_rows <= 350000:
+        if num_rows <= 350000:
             grid = (num_rows,)
-            softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
+            softmax_kernel[grid](
+                output, input, input.stride(0), num_cols, mask, BLOCK_SIZE=block_size, num_warps=num_warps
+            )
         else:
             grid = lambda meta: ()

-            grid = lambda meta: (
-                triton.cdiv(num_rows, meta["BLOCK_M"]),
-            )
+            grid = lambda meta: (triton.cdiv(num_rows, meta["BLOCK_M"]),)

-            BLOCK_M = 32
-            if block_size >= 4096:
-                BLOCK_M = 4
-            elif block_size >= 2048:
-                BLOCK_M = 8
-
-            softmax_kernel[grid](output_ptr = output,
-                                 input_ptr = input,
-                                 row_stride = input.stride(0),
-                                 n_rows = num_rows,
-                                 n_cols = num_cols,
-                                 mask_ptr = mask,
-                                 # currently manually setting up size
-                                 BLOCK_M = 32,
-                                 BLOCK_SIZE = block_size)
+            if block_size >= 4096:
+                pass
+            elif block_size >= 2048:
+                pass
+
+            softmax_kernel[grid](
+                output_ptr=output,
+                input_ptr=input,
+                row_stride=input.stride(0),
+                n_rows=num_rows,
+                n_cols=num_cols,
+                mask_ptr=mask,
+                # currently manually setting up size
+                BLOCK_M=32,
+                BLOCK_SIZE=block_size,
+            )

-        return output
+        return output
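A hedged usage sketch for the `softmax` wrapper above; it only supports the last dimension, and each row gets its own Triton program when the row count is at most 350000. The sizes and the import path are assumptions.

```python
# Hedged sketch, not part of the commit; requires triton and a CUDA device.
import torch
from colossalai.kernel.triton import softmax  # package path assumed

scores = torch.randn(4 * 12 * 128, 128, device="cuda", dtype=torch.float16)
probs = softmax(scores, dim=-1)
```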
@@ -1,12 +1,12 @@
 # Adapted from ModelTC https://github.com/ModelTC/lightllm
-
 import math

 import torch

 try:
     import triton
     import triton.language as tl
+
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
@@ -15,10 +15,28 @@ except ImportError:
 if HAS_TRITON:

     @triton.jit
-    def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
-                             attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride,
-                             q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride,
-                             attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr):
+    def _token_attn_1_kernel(
+        Q,
+        K,
+        sm_scale,
+        kv_cache_loc,
+        kv_cache_start_loc,
+        kv_cache_seqlen,
+        max_kv_cache_len,
+        attn_out,
+        kv_cache_loc_b_stride,
+        kv_cache_loc_s_stride,
+        q_batch_stride,
+        q_head_stride,
+        q_head_dim_stride,
+        k_batch_stride,
+        k_head_stride,
+        k_head_dim_stride,
+        attn_head_stride,
+        attn_batch_stride,
+        HEAD_DIM: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+    ):
         current_batch = tl.program_id(0)
         current_head = tl.program_id(1)
         start_n = tl.program_id(2)
@@ -40,9 +58,11 @@ if HAS_TRITON:
         for start_mark in range(0, block_mask, 1):
             q = tl.load(Q + off_q + start_mark)
             offs_n_new = current_batch_start_index + offs_n
-            k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
-                            mask=offs_n_new < current_batch_end_index,
-                            other=0)
+            k_loc = tl.load(
+                kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
+                mask=offs_n_new < current_batch_end_index,
+                other=0,
+            )
             off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
             k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
             att_value = tl.sum(q[None, :] * k, 1)
@@ -52,11 +72,29 @@ if HAS_TRITON:
         return

     @triton.jit
-    def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen,
-                                   max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride,
-                                   q_batch_stride, q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride,
-                                   k_head_dim_stride, attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr,
-                                   BLOCK_N: tl.constexpr):
+    def _token_attn_1_alibi_kernel(
+        Q,
+        K,
+        sm_scale,
+        alibi,
+        kv_cache_loc,
+        kv_cache_start_loc,
+        kv_cache_seqlen,
+        max_kv_cache_len,
+        attn_out,
+        kv_cache_loc_b_stride,
+        kv_cache_loc_s_stride,
+        q_batch_stride,
+        q_head_stride,
+        q_head_dim_stride,
+        k_batch_stride,
+        k_head_stride,
+        k_head_dim_stride,
+        attn_head_stride,
+        attn_batch_stride,
+        HEAD_DIM: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+    ):
         current_batch = tl.program_id(0)
         current_head = tl.program_id(1)
         start_n = tl.program_id(2)
@@ -79,9 +117,11 @@ if HAS_TRITON:
             alibi_m = tl.load(alibi + current_head)
             q = tl.load(Q + off_q + start_mark)
             offs_n_new = current_batch_start_index + offs_n
-            k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
-                            mask=offs_n_new < current_batch_end_index,
-                            other=0)
+            k_loc = tl.load(
+                kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
+                mask=offs_n_new < current_batch_end_index,
+                other=0,
+            )
             off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
             k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
             att_value = tl.sum(q[None, :] * k, 1)
@@ -92,14 +132,9 @@ if HAS_TRITON:
         return

     @torch.no_grad()
-    def token_attn_fwd_1(q,
-                         k,
-                         attn_out,
-                         kv_cache_loc,
-                         kv_cache_start_loc,
-                         kv_cache_seqlen,
-                         max_kv_cache_len,
-                         alibi=None):
+    def token_attn_fwd_1(
+        q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None
+    ):
         BLOCK = 32
         # shape constraints
         q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
@@ -168,9 +203,17 @@ if HAS_TRITON:
         return

     @triton.jit
-    def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out,
-                                logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride,
-                                BLOCK_SIZE: tl.constexpr):
+    def _token_attn_softmax_fwd(
+        softmax_logics,
+        kv_cache_start_loc,
+        kv_cache_seqlen,
+        softmax_prob_out,
+        logics_head_dim_stride,
+        logics_batch_stride,
+        prob_head_dim_stride,
+        prob_batch_stride,
+        BLOCK_SIZE: tl.constexpr,
+    ):
         current_batch = tl.program_id(0)
         current_head = tl.program_id(1)

@@ -178,20 +221,26 @@ if HAS_TRITON:
         current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
         current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)

-        row = tl.load(softmax_logics + current_head * logics_head_dim_stride +
-                      (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
-                      mask=col_offsets < current_batch_seq_len,
-                      other=-float('inf')).to(tl.float32)
+        row = tl.load(
+            softmax_logics
+            + current_head * logics_head_dim_stride
+            + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
+            mask=col_offsets < current_batch_seq_len,
+            other=-float("inf"),
+        ).to(tl.float32)

         row_minus_max = row - tl.max(row, axis=0)
         numerator = tl.exp(row_minus_max)
         denominator = tl.sum(numerator, axis=0)
         softmax_output = numerator / denominator

-        tl.store(softmax_prob_out + current_head * prob_head_dim_stride +
-                 (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
-                 softmax_output,
-                 mask=col_offsets < current_batch_seq_len)
+        tl.store(
+            softmax_prob_out
+            + current_head * prob_head_dim_stride
+            + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
+            softmax_output,
+            mask=col_offsets < current_batch_seq_len,
+        )
         return

     @torch.no_grad()
@@ -220,11 +269,27 @@ if HAS_TRITON:
         return

     @triton.jit
-    def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
-                             kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride,
-                             v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride,
-                             attn_out_head_stride, attn_out_head_dim_stride, HEAD_DIM: tl.constexpr,
-                             BLOCK_N: tl.constexpr):
+    def _token_attn_2_kernel(
+        Prob,
+        V,
+        attn_out,
+        kv_cache_loc,
+        kv_cache_start_loc,
+        kv_cache_seqlen,
+        max_kv_cache_len,
+        kv_cache_loc_b_stride,
+        kv_cache_loc_s_stride,
+        prob_head_dim_stride,
+        prob_batch_stride,
+        v_batch_stride,
+        v_head_stride,
+        v_head_dim_stride,
+        attn_out_batch_stride,
+        attn_out_head_stride,
+        attn_out_head_dim_stride,
+        HEAD_DIM: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+    ):
         current_batch = tl.program_id(0)
         current_head = tl.program_id(1)

@@ -232,7 +297,6 @@ if HAS_TRITON:
         offs_d = tl.arange(0, HEAD_DIM)
         current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
         current_batch_start_index = max_kv_cache_len - current_batch_seq_len
-        current_batch_end_index = current_batch_seq_len
         current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)

         v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
@@ -242,19 +306,29 @@ if HAS_TRITON:
         acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
         for start_n in range(0, current_batch_seq_len, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
-            p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride,
-                              mask=(start_n + offs_n) < current_batch_seq_len,
-                              other=0.0)
-            v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
-                            mask=(start_n + offs_n) < current_batch_seq_len,
-                            other=0.0)
-            v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride,
-                              mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
-                              other=0.0)
+            p_value = tl.load(
+                Prob + p_offs + start_n * kv_cache_loc_s_stride,
+                mask=(start_n + offs_n) < current_batch_seq_len,
+                other=0.0,
+            )
+            v_loc = tl.load(
+                kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
+                mask=(start_n + offs_n) < current_batch_seq_len,
+                other=0.0,
+            )
+            v_value = tl.load(
+                V + v_offs + v_loc[:, None] * v_batch_stride,
+                mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
+                other=0.0,
+            )
             acc += tl.sum(p_value[:, None] * v_value, 0)

         acc = acc.to(tl.float16)
-        off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride
+        off_o = (
+            current_batch * attn_out_batch_stride
+            + current_head * attn_out_head_stride
+            + offs_d * attn_out_head_dim_stride
+        )
         out_ptrs = attn_out + off_o
         tl.store(out_ptrs, acc)
         return
@@ -296,15 +370,9 @@ if HAS_TRITON:
         return

     @torch.no_grad()
-    def token_attention_fwd(q,
-                            k,
-                            v,
-                            attn_out,
-                            kv_cache_loc,
-                            kv_cache_start_loc,
-                            kv_cache_seq_len,
-                            max_len_in_batch,
-                            alibi=None):
+    def token_attention_fwd(
+        q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None
+    ):
         head_num = k.shape[1]
         batch_size = kv_cache_seq_len.shape[0]
         calcu_shape1 = (batch_size, head_num, k.shape[2])
@@ -312,21 +380,24 @@ if HAS_TRITON:

         att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")

-        token_attn_fwd_1(q.view(calcu_shape1),
-                         k,
-                         att_m_tensor,
-                         kv_cache_loc,
-                         kv_cache_start_loc,
-                         kv_cache_seq_len,
-                         max_len_in_batch,
-                         alibi=alibi)
+        token_attn_fwd_1(
+            q.view(calcu_shape1),
+            k,
+            att_m_tensor,
+            kv_cache_loc,
+            kv_cache_start_loc,
+            kv_cache_seq_len,
+            max_len_in_batch,
+            alibi=alibi,
+        )

         prob = torch.empty_like(att_m_tensor)

         token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
         att_m_tensor = None
-        token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len,
-                         max_len_in_batch)
+        token_attn_fwd_2(
+            prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch
+        )

         prob = None

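For a single (batch, head) pair, the three stages launched by `token_attention_fwd` above compute dot-product logits over the cached keys, a softmax over those positions, and a probability-weighted sum of the cached values. A plain PyTorch reference of that math (a sketch; the `kv_cache_loc` indirection and exact scaling are omitted or assumed):

```python
import torch

def token_attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sm_scale: float) -> torch.Tensor:
    # q: (head_dim,), k/v: (cached_len, head_dim)
    logits = (k * q[None, :]).sum(dim=1) * sm_scale   # stage 1: _token_attn_1_kernel
    probs = torch.softmax(logits.float(), dim=0)      # stage 2: _token_attn_softmax_fwd
    return (probs[:, None] * v.float()).sum(dim=0)    # stage 3: _token_attn_2_kernel
```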