[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,5 +1,6 @@
try:
import triton
HAS_TRITON = True
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
@@ -11,8 +12,14 @@ try:
from .token_attention_kernel import token_attention_fwd
__all__ = [
"llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
"copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
"llama_context_attn_fwd",
"bloom_context_attn_fwd",
"softmax",
"layer_norm",
"rmsnorm_forward",
"copy_kv_cache_to_dest",
"rotary_embedding_fwd",
"token_attention_fwd",
]
except ImportError:
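The try/except guard above is the pattern used throughout this commit: the Triton kernels are only defined when `import triton` succeeds, and HAS_TRITON records the result. A minimal caller-side sketch of that pattern follows; the fallback helper, its name, and the colossalai.kernel.triton import path (suggested by the __all__ list but not shown explicitly) are assumptions, not part of the commit.

import torch

try:
    import triton  # noqa: F401  # only used here to probe availability
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

def softmax_last_dim(x: torch.Tensor) -> torch.Tensor:
    # Prefer the Triton kernel when it is available and the tensor lives on the GPU,
    # otherwise fall back to the eager PyTorch implementation.
    if HAS_TRITON and x.is_cuda:
        from colossalai.kernel.triton import softmax as triton_softmax
        return triton_softmax(x)
    return torch.softmax(x, dim=-1)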

View File

@@ -1,8 +1,11 @@
import torch
import math
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -10,28 +13,42 @@ except ImportError:
if HAS_TRITON:
'''
this function is modified from
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
'''
"""
this function is modified from
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
"""
@triton.jit
def _context_flash_attention_kernel(
Q, K, V, sm_scale,
B_Start_Loc, B_Seqlen,
TMP,
Q,
K,
V,
sm_scale,
B_Start_Loc,
B_Seqlen,
TMP,
alibi_ptr,
Out,
stride_qbs, stride_qh, stride_qd,
stride_kbs, stride_kh, stride_kd,
stride_vbs, stride_vh, stride_vd,
stride_obs, stride_oh, stride_od,
stride_tmp_b, stride_tmp_h, stride_tmp_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_tmp_b,
stride_tmp_h,
stride_tmp_s,
# suggested set-up: 64, 128, 256, 512
BLOCK_M: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
batch_id = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
@@ -40,13 +57,18 @@ if HAS_TRITON:
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# get batch info
# get batch info
cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
block_start_loc = BLOCK_M * start_m
load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
load_p_ptrs = (
Q
+ (cur_batch_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
@@ -56,7 +78,7 @@ if HAS_TRITON:
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
if alibi_ptr is not None:
alibi_m = tl.load(alibi_ptr + cur_head)
@@ -64,8 +86,11 @@ if HAS_TRITON:
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
k = tl.load(
k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
@@ -95,21 +120,25 @@ if HAS_TRITON:
acc_scale = tl.load(t_ptrs)
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
v = tl.load(
v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
off_o = (
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
@torch.no_grad()
def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
BLOCK = 128
@@ -129,17 +158,31 @@ if HAS_TRITON:
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
_context_flash_attention_kernel[grid](
q, k, v, sm_scale,
b_start_loc, b_seq_len,
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
alibi,
o,
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
o.stride(0), o.stride(1), o.stride(2),
tmp.stride(0), tmp.stride(1), tmp.stride(2),
# manually setting this block num, we can use tuning config to further speed-up
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
# manually setting this block num, we can use tuning config to further speed-up
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
@@ -147,7 +190,7 @@ if HAS_TRITON:
num_stages=1,
)
return
@torch.no_grad()
def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
BLOCK = 128
@@ -166,19 +209,34 @@ if HAS_TRITON:
num_warps = 4 if Lk <= 64 else 8
# num_warps = 4
_context_flash_attention_kernel[grid](
q, k, v, sm_scale, b_start_loc, b_seq_len,
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
None,
o,
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
o.stride(0), o.stride(1), o.stride(2),
tmp.stride(0), tmp.stride(1), tmp.stride(2),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
return
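For orientation, a hedged usage sketch of llama_context_attn_fwd as defined above. The packed "no-pad" layout (q/k/v/o of shape (total_tokens, num_heads, head_dim), with b_start_loc/b_seq_len describing where each sequence starts and how long it is) follows the lightllm-style kernel; the concrete shapes, dtype, and the colossalai.kernel.triton import path are assumptions.

import torch
from colossalai.kernel.triton import llama_context_attn_fwd

if torch.cuda.is_available():
    num_heads, head_dim = 8, 64  # head_dim must be one of the sizes the kernel supports
    seq_lens = torch.tensor([13, 7], dtype=torch.int32, device="cuda")
    start_loc = torch.tensor([0, 13], dtype=torch.int32, device="cuda")
    total_tokens = int(seq_lens.sum())
    q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k, v, o = torch.randn_like(q), torch.randn_like(q), torch.empty_like(q)
    # Writes the attention output for every prompt token into o in place.
    llama_context_attn_fwd(q, k, v, o, start_loc, seq_lens, int(seq_lens.max()))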

View File

@@ -3,25 +3,28 @@ import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
@triton.jit
def _fwd_copy_kv_cache_dest(
kv_cache_ptr, dest_index_ptr,
kv_cache_ptr,
dest_index_ptr,
out,
stride_k_bs,
stride_k_h,
stride_k_bs,
stride_k_h,
stride_k_d,
stride_o_bs,
stride_o_h,
stride_o_bs,
stride_o_h,
stride_o_d,
head_num,
BLOCK_DMODEL: tl.constexpr,
BLOCK_HEAD: tl.constexpr
BLOCK_HEAD: tl.constexpr,
):
cur_index = tl.program_id(0)
offs_h = tl.arange(0, BLOCK_HEAD)
@@ -31,15 +34,14 @@ if HAS_TRITON:
cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets
o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
o_ptrs = out + dest_index * stride_o_bs + o_offsets
k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
return
@torch.no_grad()
def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
seq_len = dest_index_ptr.shape[0]
@@ -47,16 +49,18 @@ if HAS_TRITON:
head_dim = k_ptr.shape[2]
assert head_num == out.shape[1], "head_num should be the same for k_ptr and out"
assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"
num_warps = 2
_fwd_copy_kv_cache_dest[(seq_len,)](
k_ptr, dest_index_ptr, out,
k_ptr.stride(0),
k_ptr.stride(1),
k_ptr,
dest_index_ptr,
out,
k_ptr.stride(0),
k_ptr.stride(1),
k_ptr.stride(2),
out.stride(0),
out.stride(1),
out.stride(0),
out.stride(1),
out.stride(2),
head_num,
BLOCK_DMODEL=head_dim,
@@ -65,5 +69,3 @@ if HAS_TRITON:
num_stages=2,
)
return
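A hedged usage sketch for copy_kv_cache_to_dest above: each row of the newly computed K (or V) tensor is scattered into the cache buffer at the slot given by dest_index, which is equivalent to kv_cache[dest_index] = k_new in plain PyTorch. The shapes, dtype, and the colossalai.kernel.triton import path are assumptions; the head_num/head_dim agreement is required by the asserts in the wrapper.

import torch
from colossalai.kernel.triton import copy_kv_cache_to_dest

if torch.cuda.is_available():
    num_tokens, head_num, head_dim, cache_slots = 4, 8, 64, 128
    k_new = torch.randn(num_tokens, head_num, head_dim, device="cuda", dtype=torch.float16)
    kv_cache = torch.zeros(cache_slots, head_num, head_dim, device="cuda", dtype=torch.float16)
    dest_index = torch.tensor([10, 11, 12, 13], device="cuda", dtype=torch.long)
    copy_kv_cache_to_dest(k_new, dest_index, kv_cache)  # same effect as kv_cache[dest_index] = k_new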

View File

@@ -3,6 +3,7 @@ import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -14,13 +15,13 @@ if HAS_TRITON:
@triton.jit
def _layer_norm_fwd_fused(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
@@ -32,15 +33,15 @@ if HAS_TRITON:
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# Compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
x = tl.where(cols < N, x - mean, 0.)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
x = tl.where(cols < N, x - mean, 0.0)
_var += x * x
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
@@ -50,7 +51,7 @@ if HAS_TRITON:
mask = cols < N
w = tl.load(W + cols, mask=mask)
b = tl.load(B + cols, mask=mask)
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
x_hat = (x - mean) * rstd
y = x_hat * w + b
# Write output
@@ -71,13 +72,7 @@ if HAS_TRITON:
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg,
y,
weight,
bias,
x_arg.stride(0),
N,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
_layer_norm_fwd_fused[(M,)](
x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
return y
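For checking the fused kernel numerically, a row-wise PyTorch reference of what _layer_norm_fwd_fused computes (per-row mean and biased variance, rstd = 1/sqrt(var + eps), then x_hat * w + b); the helper name is illustrative.

import torch

def layer_norm_reference(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    x32 = x.float()
    mean = x32.mean(dim=-1, keepdim=True)                 # mean over the N columns of each row
    var = ((x32 - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance, matching the kernel
    rstd = torch.rsqrt(var + eps)
    return ((x32 - mean) * rstd * w.float() + b.float()).to(x.dtype)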

View File

@@ -1,7 +1,7 @@
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -9,9 +9,10 @@ except ImportError:
if HAS_TRITON:
'''
"""
this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
'''
"""
@triton.jit
def qkv_gemm_4d_kernel(
a_ptr,
@@ -34,12 +35,12 @@ if HAS_TRITON:
stride_cn,
scale,
# Meta-parameters
BLOCK_SIZE_M : tl.constexpr = 64,
BLOCK_SIZE_N : tl.constexpr = 32,
BLOCK_SIZE_K : tl.constexpr = 32,
GROUP_SIZE_M : tl.constexpr = 8,
BLOCK_SIZE_M: tl.constexpr = 64,
BLOCK_SIZE_N: tl.constexpr = 32,
BLOCK_SIZE_K: tl.constexpr = 32,
GROUP_SIZE_M: tl.constexpr = 8,
):
r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer,
r"""A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer,
where score_matrix is softmax(Q*K^T/sqrt(hidden_size))
Args:
a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K)
@@ -53,21 +54,21 @@ if HAS_TRITON:
stride_bh(tl.constexpr): stride for h-dimension for tensor array B
stride_bk(tl.constexpr): stride for k-dimension for tensor array B
stride_bn(tl.constexpr): stride for n-dimension for tensor array B
stride_cb(tl.constexpr): stride for bs-dimension for tensor array output
stride_cb(tl.constexpr): stride for bs-dimension for tensor array output
stride_ch(tl.constexpr): stride for h-dimension for tensor array output
stride_cm(tl.constexpr): stride for m-dimension for tensor array output
stride_cn(tl.constexpr): stride for n-dimension for tensor array output
BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a
BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b
BLOCK_SIZE_K : tiling size for K-dimension of a and b
GROUP_SIZE_M : group size for reducing cache miss, more details:
GROUP_SIZE_M : group size for reducing cache miss, more details:
"""
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
batch = tl.program_id(axis = 0)
head = tl.program_id(axis = 1)
pid = tl.program_id(axis = 2)
batch = tl.program_id(axis=0)
head = tl.program_id(axis=1)
pid = tl.program_id(axis=2)
# the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -77,33 +78,38 @@ if HAS_TRITON:
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah +
(offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak))
b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh +
(offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))
a_ptrs = (
a_ptr + batch * stride_ab + head * stride_ah + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
)
b_ptrs = (
b_ptr + batch * stride_bb + head * stride_bh + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)
b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)
a = tl.load(a_ptrs, mask=a_mask, other=0.)
b = tl.load(b_ptrs, mask=b_mask, other=0.)
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
accumulator = accumulator.to(c_ptr.dtype.element_ty)
if scale > 0:
accumulator = accumulator * scale.to(c_ptr.dtype.element_ty)
offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
stride_cn * offs_accumu_n[None, :])
c_ptrs = (
c_ptr
+ batch * stride_cb
+ head * stride_ch
+ stride_cm * offs_accumu_m[:, None]
+ stride_cn * offs_accumu_n[None, :]
)
accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N)
tl.store(c_ptrs, accumulator, mask=accumulator_mask)
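As a plain-PyTorch reference for what qkv_gemm_4d_kernel computes: a matmul batched over the leading (batch, head) dimensions with an optional post-scale, as used for Q @ K^T / sqrt(head_size) and for softmax(QK^T) @ V. The helper name is illustrative.

import torch

def qkv_gemm_4d_reference(a: torch.Tensor, b: torch.Tensor, scale: float = 0.0) -> torch.Tensor:
    # a: (bs, h, M, K), b: (bs, h, K, N) -> c: (bs, h, M, N); the scale is applied
    # only when > 0, mirroring the `if scale > 0` branch in the kernel.
    c = torch.matmul(a, b)
    return c * scale if scale > 0 else c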

View File

@@ -3,17 +3,19 @@ import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
'''
this kernel function is modified from
https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
'''
"""
this kernel function is modified from
https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
"""
@triton.jit
def _rms_norm_fwd_fused(
X, # pointer to the input
@@ -32,7 +34,7 @@ if HAS_TRITON:
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
_var += x * x
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
@@ -41,13 +43,12 @@ if HAS_TRITON:
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
w = tl.load(W + cols, mask=mask).to(tl.float32)
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
x_hat = x * rstd
y = x_hat * w
# Write output
tl.store(Y + cols, y.to(tl.float16), mask=mask)
def rmsnorm_forward(x, weight, eps):
# allocate output
y = torch.empty_like(x)
@@ -66,7 +67,5 @@ if HAS_TRITON:
BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2
num_warps = 8
# enqueue kernel
_rms_norm_fwd_fused[(M,)](x_arg, y, weight,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
_rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
return y
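A hedged usage sketch for rmsnorm_forward(x, weight, eps) as defined above, alongside a plain-PyTorch reference of the same computation (x * rsqrt(mean(x²) + eps) * weight). The shapes, dtype, and the colossalai.kernel.triton import path are assumptions.

import torch
from colossalai.kernel.triton import rmsnorm_forward

if torch.cuda.is_available():
    x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
    w = torch.ones(4096, device="cuda", dtype=torch.float16)
    y = rmsnorm_forward(x, w, 1e-6)
    ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6) * w.float()
    print(torch.allclose(y.float(), ref, atol=1e-2, rtol=1e-2))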

View File

@@ -29,19 +29,29 @@ def _rotary_kernel(
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride
off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride
off_q0 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range0[None, None, :] * q_d_stride
)
off_q1 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range1[None, None, :] * q_d_stride
)
off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride
q0 = tl.load(q + off_q0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0)
q1 = tl.load(q + off_q1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0)
q0 = tl.load(
q + off_q0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)
q1 = tl.load(
q + off_q1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)
cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
@@ -49,12 +59,16 @@ def _rotary_kernel(
out0 = q0 * cos - q1 * sin
out1 = q0 * sin + q1 * cos
tl.store(q + off_q0,
out0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
tl.store(q + off_q1,
out1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
tl.store(
q + off_q0,
out0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)
tl.store(
q + off_q1,
out1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)
return
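A plain-PyTorch reference for the rotation _rotary_kernel applies in place: the head dimension is split into two halves (dim_range0 / dim_range1) and each position is rotated by its cos/sin entry. The (total_len, head_num, head_dim) layout for q and the (total_len, head_dim // 2) shape for cos/sin are read off the strides and offsets above; unlike the kernel, this helper returns a new tensor instead of writing in place.

import torch

def rotary_embedding_reference(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q: (total_len, head_num, head_dim); cos/sin: (total_len, head_dim // 2)
    half = q.shape[-1] // 2
    q0, q1 = q[..., :half], q[..., half:]
    cos, sin = cos[:, None, :], sin[:, None, :]  # broadcast over the head dimension
    out0 = q0 * cos - q1 * sin
    out1 = q0 * sin + q1 * cos
    return torch.cat([out0, out1], dim=-1)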

View File

@@ -1,9 +1,8 @@
import torch
from torch import nn
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -13,9 +12,10 @@ if HAS_TRITON:
from .qkv_matmul_kernel import qkv_gemm_4d_kernel
from .softmax import softmax_kernel
def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
input_mask: torch.Tensor, scale: float):
r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
def self_attention_forward_without_fusion(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float
):
r"""A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
Args:
q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
@@ -65,7 +65,7 @@ if HAS_TRITON:
score_output.stride(2),
score_output.stride(3),
scale=scale,
# currently set manually; later we can use an auto-tune config to match the best setting
# currently set manually; later we can use an auto-tune config to match the best setting
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
@@ -79,7 +79,6 @@ if HAS_TRITON:
n_rows, n_cols = score_output.shape
if n_rows <= 350000:
block_size = max(triton.next_power_of_2(n_cols), 2)
num_warps = 4
if block_size >= 4096:
@@ -142,15 +141,9 @@ if HAS_TRITON:
)
return output.view(batches, -1, d_model)
def self_attention_compute_using_triton(qkv,
input_mask,
layer_past,
alibi,
scale,
head_size,
triangular=False,
use_flash=False):
def self_attention_compute_using_triton(
qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False
):
assert qkv.is_contiguous()
assert alibi is None, "current triton self-attention does not support alibi"
batches = qkv.shape[0]
@@ -158,8 +151,8 @@ if HAS_TRITON:
num_of_heads = d_model // head_size
q = qkv[:, :, :d_model]
k = qkv[:, :, d_model:d_model * 2]
v = qkv[:, :, d_model * 2:]
k = qkv[:, :, d_model : d_model * 2]
v = qkv[:, :, d_model * 2 :]
q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size)
v = v.view(batches, -1, num_of_heads, head_size)
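The tail of the hunk above unpacks a fused (batch, seq_len, 3 * d_model) QKV projection into per-head Q/K/V views of shape (batch, seq_len, num_heads, head_size). Written out as a standalone helper for clarity; the function name is illustrative.

import torch

def split_fused_qkv(qkv: torch.Tensor, head_size: int):
    batches, _, three_d_model = qkv.shape
    d_model = three_d_model // 3
    num_of_heads = d_model // head_size
    q = qkv[:, :, :d_model].view(batches, -1, num_of_heads, head_size)
    k = qkv[:, :, d_model : d_model * 2].view(batches, -1, num_of_heads, head_size)
    v = qkv[:, :, d_model * 2 :].view(batches, -1, num_of_heads, head_size)
    return q, k, v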

View File

@@ -1,39 +1,42 @@
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
'''
softmax kernel is modified based on
"""
softmax kernel is modified based on
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
'''
"""
@triton.jit
def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
r""" the kernel function for implementing softmax operator
r"""the kernel function for implementing softmax operator
Args:
output_ptr: the output after finishing softmax operation, (N, hidden_dim)
input_ptr: the tensor of input, shape should be (N, hidden_dim)
n_cols(tl.constexpr): the number of cols of input
BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
"""
row_idx = tl.program_id(0)
row_start_ptr = input_ptr + row_idx * row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf")).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0)
if mask_ptr is not None:
# load mask into SRAM
# load mask into SRAM
mask_ptrs = (mask_ptr + (row_idx * row_stride)) + col_offsets
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
# update
# update
row_minus_max = row_minus_max + mask
numerator = tl.exp(row_minus_max)
@@ -43,17 +46,16 @@ if HAS_TRITON:
output_ptrs = output_row_start_ptr + col_offsets
# Write back output to DRAM
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
if mask is not None:
assert input.shape[-1] == mask.shape[-1], "the last dimensions should be the same for input and mask"
assert dim == -1 or dim == len(input.shape)-1, "currently the softmax layer only supports the last dimension"
assert dim == -1 or dim == len(input.shape) - 1, "currently the softmax layer only supports the last dimension"
hidden_dim = input.shape[-1]
output = torch.empty_like(input)
input = input.view(-1, hidden_dim)
if mask is not None:
if mask is not None:
mask = mask.view(-1, hidden_dim)
assert input.shape[0] == mask.shape[0], "the first dimension of mask and input should be the same"
@@ -67,30 +69,31 @@ if HAS_TRITON:
else:
num_warps = 4
if num_rows <= 350000:
if num_rows <= 350000:
grid = (num_rows,)
softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
softmax_kernel[grid](
output, input, input.stride(0), num_cols, mask, BLOCK_SIZE=block_size, num_warps=num_warps
)
else:
grid = lambda meta: ()
grid = lambda meta: (
triton.cdiv(num_rows, meta["BLOCK_M"]),
)
grid = lambda meta: (triton.cdiv(num_rows, meta["BLOCK_M"]),)
if block_size >= 4096:
pass
elif block_size >= 2048:
pass
softmax_kernel[grid](
output_ptr=output,
input_ptr=input,
row_stride=input.stride(0),
n_rows=num_rows,
n_cols=num_cols,
mask_ptr=mask,
# currently manually setting up size
BLOCK_M=32,
BLOCK_SIZE=block_size,
)
BLOCK_M = 32
if block_size >= 4096:
BLOCK_M = 4
elif block_size >= 2048:
BLOCK_M = 8
softmax_kernel[grid](output_ptr = output,
input_ptr = input,
row_stride = input.stride(0),
n_rows = num_rows,
n_cols = num_cols,
mask_ptr = mask,
# currently manually setting up size
BLOCK_M = 32,
BLOCK_SIZE = block_size)
return output
return output
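A hedged usage sketch for the softmax wrapper defined above: it only supports the last dimension, and the optional mask is added to the shifted logits before exponentiation. The shapes, dtype, and the colossalai.kernel.triton import path are assumptions.

import torch
from colossalai.kernel.triton import softmax

if torch.cuda.is_available():
    logits = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
    y = softmax(logits)                          # Triton path
    ref = torch.softmax(logits.float(), dim=-1)  # eager reference
    print(torch.allclose(y.float(), ref, atol=1e-2, rtol=1e-2))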

View File

@@ -1,12 +1,12 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import math
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -15,10 +15,28 @@ except ImportError:
if HAS_TRITON:
@triton.jit
def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride,
q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride,
attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr):
def _token_attn_1_kernel(
Q,
K,
sm_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
q_batch_stride,
q_head_stride,
q_head_dim_stride,
k_batch_stride,
k_head_stride,
k_head_dim_stride,
attn_head_stride,
attn_batch_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
start_n = tl.program_id(2)
@@ -40,9 +58,11 @@ if HAS_TRITON:
for start_mark in range(0, block_mask, 1):
q = tl.load(Q + off_q + start_mark)
offs_n_new = current_batch_start_index + offs_n
k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
mask=offs_n_new < current_batch_end_index,
other=0)
k_loc = tl.load(
kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
mask=offs_n_new < current_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
att_value = tl.sum(q[None, :] * k, 1)
@@ -52,11 +72,29 @@ if HAS_TRITON:
return
@triton.jit
def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen,
max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride,
q_batch_stride, q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride,
k_head_dim_stride, attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr):
def _token_attn_1_alibi_kernel(
Q,
K,
sm_scale,
alibi,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
q_batch_stride,
q_head_stride,
q_head_dim_stride,
k_batch_stride,
k_head_stride,
k_head_dim_stride,
attn_head_stride,
attn_batch_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
start_n = tl.program_id(2)
@@ -79,9 +117,11 @@ if HAS_TRITON:
alibi_m = tl.load(alibi + current_head)
q = tl.load(Q + off_q + start_mark)
offs_n_new = current_batch_start_index + offs_n
k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
mask=offs_n_new < current_batch_end_index,
other=0)
k_loc = tl.load(
kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
mask=offs_n_new < current_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
att_value = tl.sum(q[None, :] * k, 1)
@@ -92,14 +132,9 @@ if HAS_TRITON:
return
@torch.no_grad()
def token_attn_fwd_1(q,
k,
attn_out,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
alibi=None):
def token_attn_fwd_1(
q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None
):
BLOCK = 32
# shape constraints
q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
@@ -168,9 +203,17 @@ if HAS_TRITON:
return
@triton.jit
def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out,
logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride,
BLOCK_SIZE: tl.constexpr):
def _token_attn_softmax_fwd(
softmax_logics,
kv_cache_start_loc,
kv_cache_seqlen,
softmax_prob_out,
logics_head_dim_stride,
logics_batch_stride,
prob_head_dim_stride,
prob_batch_stride,
BLOCK_SIZE: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
@@ -178,20 +221,26 @@ if HAS_TRITON:
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
row = tl.load(softmax_logics + current_head * logics_head_dim_stride +
(current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
mask=col_offsets < current_batch_seq_len,
other=-float('inf')).to(tl.float32)
row = tl.load(
softmax_logics
+ current_head * logics_head_dim_stride
+ (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
mask=col_offsets < current_batch_seq_len,
other=-float("inf"),
).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
tl.store(softmax_prob_out + current_head * prob_head_dim_stride +
(current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
softmax_output,
mask=col_offsets < current_batch_seq_len)
tl.store(
softmax_prob_out
+ current_head * prob_head_dim_stride
+ (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
softmax_output,
mask=col_offsets < current_batch_seq_len,
)
return
@torch.no_grad()
@@ -220,11 +269,27 @@ if HAS_TRITON:
return
@triton.jit
def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride,
v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride,
attn_out_head_stride, attn_out_head_dim_stride, HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr):
def _token_attn_2_kernel(
Prob,
V,
attn_out,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
prob_head_dim_stride,
prob_batch_stride,
v_batch_stride,
v_head_stride,
v_head_dim_stride,
attn_out_batch_stride,
attn_out_head_stride,
attn_out_head_dim_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
@@ -232,7 +297,6 @@ if HAS_TRITON:
offs_d = tl.arange(0, HEAD_DIM)
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
current_batch_end_index = current_batch_seq_len
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
@@ -242,19 +306,29 @@ if HAS_TRITON:
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for start_n in range(0, current_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0)
v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0)
v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride,
mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
other=0.0)
p_value = tl.load(
Prob + p_offs + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0,
)
v_loc = tl.load(
kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0,
)
v_value = tl.load(
V + v_offs + v_loc[:, None] * v_batch_stride,
mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
other=0.0,
)
acc += tl.sum(p_value[:, None] * v_value, 0)
acc = acc.to(tl.float16)
off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride
off_o = (
current_batch * attn_out_batch_stride
+ current_head * attn_out_head_stride
+ offs_d * attn_out_head_dim_stride
)
out_ptrs = attn_out + off_o
tl.store(out_ptrs, acc)
return
@@ -296,15 +370,9 @@ if HAS_TRITON:
return
@torch.no_grad()
def token_attention_fwd(q,
k,
v,
attn_out,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
alibi=None):
def token_attention_fwd(
q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None
):
head_num = k.shape[1]
batch_size = kv_cache_seq_len.shape[0]
calcu_shape1 = (batch_size, head_num, k.shape[2])
@@ -312,21 +380,24 @@ if HAS_TRITON:
att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
token_attn_fwd_1(q.view(calcu_shape1),
k,
att_m_tensor,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
alibi=alibi)
token_attn_fwd_1(
q.view(calcu_shape1),
k,
att_m_tensor,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
alibi=alibi,
)
prob = torch.empty_like(att_m_tensor)
token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
att_m_tensor = None
token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len,
max_len_in_batch)
token_attn_fwd_2(
prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch
)
prob = None
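token_attention_fwd chains the three kernels above for single-token decode: stage 1 computes per-head attention logits against the cached K, the softmax stage normalizes them over each sequence's cached length, and stage 2 takes the probability-weighted sum of the cached V. Below is a plain-PyTorch reference of that end-to-end computation for one sequence, with K/V already gathered via kv_cache_loc and assuming the conventional sm_scale = 1/sqrt(head_dim); the helper name is illustrative.

import math
import torch

def token_attention_reference(q: torch.Tensor, k_cached: torch.Tensor, v_cached: torch.Tensor) -> torch.Tensor:
    # q: (head_num, head_dim) for the new token; k_cached/v_cached: (seq_len, head_num, head_dim)
    head_dim = q.shape[-1]
    logits = torch.einsum("hd,shd->hs", q.float(), k_cached.float()) / math.sqrt(head_dim)
    prob = torch.softmax(logits, dim=-1)  # normalize over the cached sequence length
    return torch.einsum("hs,shd->hd", prob, v_cached.float())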