Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 20:40:34 +00:00
[devops] remove post commit ci (#5566)

* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

--------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
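The diff that follows is the output of that repository-wide pre-commit pass. As a rough guide (assuming the repository's .pre-commit-config.yaml is in place), the same pass can be reproduced locally by installing the tool with pip install pre-commit and then running pre-commit run --all-files from the repository root.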
@@ -5,6 +5,7 @@ import torch
try:
import triton
import triton.language as tl

HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -16,6 +17,7 @@ if HAS_TRITON:
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
"""
if triton.__version__ < "2.1.0":

@triton.jit
def _context_flash_attention_kernel(
Q,
@@ -131,29 +133,47 @@ if HAS_TRITON:
m_i = m_i_new

off_o = (
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
(cur_batch_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return

else:
# this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
@triton.jit
def _context_flash_attention_kernel_2(
Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
Out,
kv_group_num,
stride_qbs, stride_qh, stride_qd,
stride_kbs, stride_kh, stride_kd,
stride_vbs, stride_vh, stride_vd,
stride_obs, stride_oh, stride_od,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
Q,
K,
V,
sm_scale,
Alibi,
B_Start_Loc,
B_Seqlen,
Out,
kv_group_num,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)

if kv_group_num is not None:
cur_kv_head = cur_head // kv_group_num
@@ -166,7 +186,11 @@ if HAS_TRITON:
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
if kv_group_num is None or kv_group_num == 1:
off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
@@ -191,8 +215,11 @@ if HAS_TRITON:
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
@@ -220,8 +247,11 @@ if HAS_TRITON:
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)

p = p.to(v.dtype)
acc += tl.dot(p, v)
@@ -229,7 +259,11 @@ if HAS_TRITON:
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
@@ -249,7 +283,7 @@ if HAS_TRITON:
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

num_warps = 4 if Lk <= 64 else 8

if triton.__version__ < "2.1.0":
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
_context_flash_attention_kernel[grid](
@@ -286,20 +320,26 @@ if HAS_TRITON:
)
else:
_context_flash_attention_kernel_2[grid](
q, k, v, sm_scale, alibi, b_start_loc, b_seq_len,
q,
k,
v,
sm_scale,
alibi,
b_start_loc,
b_seq_len,
o,
None,
q.stride(0),
q.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(0),
o.stride(1),
o.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
@@ -307,7 +347,7 @@ if HAS_TRITON:
num_warps=num_warps,
num_stages=1,
)

return

@torch.no_grad()
@@ -327,7 +367,7 @@ if HAS_TRITON:
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
# num_warps = 4

if triton.__version__ < "2.1.0":
_context_flash_attention_kernel[grid](
q,
@@ -337,7 +377,7 @@ if HAS_TRITON:
b_start_loc,
b_seq_len,
tmp,
None,
None,
o,
q.stride(0),
q.stride(1),
@@ -362,32 +402,33 @@ if HAS_TRITON:
)
else:
kv_group_num = q.shape[1] // k.shape[1]
_context_flash_attention_kernel_2[grid](
q,
k,
v,
sm_scale,
_context_flash_attention_kernel_2[grid](
q,
k,
v,
sm_scale,
None,
b_start_loc,
b_start_loc,
b_seq_len,
o,
kv_group_num,
q.stride(0),
q.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(0),
o.stride(1),
o.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,)

return
num_stages=1,
)

return
@@ -1,8 +1,10 @@
# adepted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
import torch

try:
from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2

HAS_LIGHTLLM_KERNEL = True
except:
print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
@@ -10,41 +12,36 @@ except:

if HAS_LIGHTLLM_KERNEL:

def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
BLOCK_SEQ = 256
batch_size = infer_state.batch_size
max_len_in_batch = infer_state.max_len_in_batch

calcu_shape1 = (batch_size, q_head_num, head_dim)

if getattr(infer_state, 'mid_o', None) is None:
infer_state.mid_o = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1,
head_dim],
dtype=torch.float32,
device="cuda")
infer_state.mid_o_logexpsum = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1],
dtype=torch.float32,
device="cuda")
if getattr(infer_state, "mid_o", None) is None:
infer_state.mid_o = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim],
dtype=torch.float32,
device="cuda",
)
infer_state.mid_o_logexpsum = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
)

mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum

flash_decode_stage1(q.view(calcu_shape1),
cache_k,
cache_v,
infer_state.block_loc,
infer_state.seq_len,
infer_state.max_len_in_batch,
mid_o,
mid_o_logexpsum,
BLOCK_SEQ)
flash_decode_stage2(mid_o,
mid_o_logexpsum,
infer_state.seq_len,
o_tensor.view(calcu_shape1),
BLOCK_SEQ)
flash_decode_stage1(
q.view(calcu_shape1),
cache_k,
cache_v,
infer_state.block_loc,
infer_state.seq_len,
infer_state.max_len_in_batch,
mid_o,
mid_o_logexpsum,
BLOCK_SEQ,
)
flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
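The flash-decoding hunk above lazily allocates two scratch buffers, one partial output and one log-sum-exp term per 256-token chunk of the longest sequence in the batch. A minimal shape sketch of that allocation, with illustrative sizes that are not taken from the repository:

import torch

BLOCK_SEQ = 256
batch_size, q_head_num, head_dim, max_len_in_batch = 4, 32, 128, 1024  # illustrative sizes only
num_chunks = max_len_in_batch // BLOCK_SEQ + 1  # one partial result per BLOCK_SEQ-token chunk
# stage 1 (flash_decode_stage1) fills these per-chunk partials
mid_o = torch.empty(batch_size, q_head_num, num_chunks, head_dim, dtype=torch.float32)
mid_o_logexpsum = torch.empty(batch_size, q_head_num, num_chunks, dtype=torch.float32)
# stage 2 (flash_decode_stage2) reduces over the chunk dimension into the final output tensor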
@@ -8,6 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
try:
import triton
import triton.language as tl

HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -26,8 +27,8 @@ if HAS_TRITON:
X_GATE2,
X_UP,
Y,
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
@@ -41,9 +42,9 @@ if HAS_TRITON:
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
# Write output
@@ -58,8 +59,8 @@ if HAS_TRITON:
X_GATE2_GRAD,
X_UP_GRAD,
Y_GRAD,
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
@@ -76,10 +77,10 @@ if HAS_TRITON:
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0)
x_up = tl.load(X_UP + cols, mask=mask, other=0.0)
y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.0)

# forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
@@ -147,14 +148,9 @@ if HAS_TRITON:
# restore setting
ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
# enqueue kernel
_llama_act_combine_forward[(M,)](x_gate1,
x_gate2,
x_up,
y,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
_llama_act_combine_forward[(M,)](
x_gate1, x_gate2, x_up, y, x_up.stride(-2), N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
return y

@staticmethod
@@ -166,20 +162,25 @@ if HAS_TRITON:

# init grad
y_grad = grad_outputs[0]
x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like(
x_gate2), torch.empty_like(x_up)
x_gate1_grad, x_gate2_grad, x_up_grad = (
torch.empty_like(x_gate1),
torch.empty_like(x_gate2),
torch.empty_like(x_up),
)

# enqueue kernel
_llama_act_combine_backward[(M,)](x_gate1,
x_gate2,
x_up,
x_gate1_grad,
x_gate2_grad,
x_up_grad,
y_grad,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
_llama_act_combine_backward[(M,)](
x_gate1,
x_gate2,
x_up,
x_gate1_grad,
x_gate2_grad,
x_up_grad,
y_grad,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
return x_gate_grad, x_up_grad, None, None
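The kernels in this file fuse the LLaMA gate/up activation combine, and a comment in the hunks above spells out the forward math directly. A plain PyTorch sketch of the same element-wise expression, handy for sanity-checking the fused kernel (the function name is illustrative, not part of the repository API):

import torch

def llama_act_combine_reference(x_gate1: torch.Tensor, x_gate2: torch.Tensor, x_up: torch.Tensor) -> torch.Tensor:
    # same math as the Triton forward kernel: y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up,
    # with the sigmoid evaluated in float32 and cast back to the input dtype
    x_gate2_sigmoid = torch.sigmoid(x_gate2.to(torch.float32)).to(x_gate2.dtype)
    return x_gate1 * x_gate2 * x_gate2_sigmoid * x_up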
@@ -13,10 +13,18 @@ except ImportError:
print("please install triton from https://github.com/openai/triton")

try:
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import (
token_att_fwd as lightllm_bloom_token_att_fwd,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import (
token_att_fwd as lightllm_llama_token_att_fwd,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import (
token_att_fwd2 as lightllm_llama_token_att_fwd2,
)
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import (
token_softmax_fwd as lightllm_llama_token_softmax_fwd,
)

HAS_TRITON_TOKEN_ATTENTION = True
except ImportError:
@@ -205,9 +213,7 @@ class Llama2TokenAttentionForwards:

if triton.__version__ == "2.0.0":
prob = torch.empty_like(att_m_tensor)
lightllm_llama_token_softmax_fwd(
att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
)
lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
att_m_tensor = None

lightllm_llama_token_att_fwd2(