mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[Feature] The first PR to Add TP inference engine, kv-cache manager and related kernels for our inference system (#4577)
* [infer] Infer/llama demo (#4503)
* add
* add infer example
* finish
* finish
* stash
* fix
* [Kernels] add inference token attention kernel (#4505)
* add token forward
* fix tests
* fix comments
* add try import triton
* add adapted license
* add tests check
* [Kernels] add necessary kernels (llama & bloom) for attention forward and kv-cache manager (#4485)
* added _vllm_rms_norm
* change place
* added tests
* added tests
* modify
* adding kernels
* added tests:
* adding kernels
* modify
* added
* updating kernels
* adding tests
* added tests
* kernel change
* submit
* modify
* added
* edit comments
* change name
* change comments and fix import
* add
* added
* combine codes (#4509)
* [feature] add KV cache manager for llama & bloom inference (#4495)
* add kv cache memory manager
* add stateinfo during inference
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* file dir change
* [Bug Fix] import llama context ops fix (#4524)
* added _vllm_rms_norm
* change place
* added tests
* added tests
* modify
* adding kernels
* added tests:
* adding kernels
* modify
* added
* updating kernels
* adding tests
* added tests
* kernel change
* submit
* modify
* added
* edit comments
* change name
* change comments and fix import
* add
* added
* fix
* add ops into init.py
* add
* [Infer] Add TPInferEngine and fix file path (#4532)
* add engine for TP inference
* move file path
* update path
* fix TPInferEngine
* remove unused file
* add engine test demo
* revise TPInferEngine
* fix TPInferEngine, add test
* fix
* Add Inference test for llama (#4508)
* add kv cache memory manager
* add stateinfo during inference
* add
* add infer example
* finish
* finish
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* add inference test for llama
* fix conflict
* feature: add some new features for llama engine
* adapt colossalai triton interface
* Change the parent class of llama policy
* add nvtx
* move llama inference code to tensor_parallel
* fix __init__.py
* rm tensor_parallel
* fix: fix bugs in auto_policy.py
* fix:rm some unused codes
* mv colossalai/tpinference to colossalai/inference/tensor_parallel
* change __init__.py
* save change
* fix engine
* Bug fix: Fix hang
* remove llama_infer_engine.py
---------
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [infer] Add Bloom inference policy and replaced methods (#4512)
* add bloom inference methods and policy
* enable pass BatchInferState from model forward
* revise bloom infer layers/policies
* add engine for inference (draft)
* add test for bloom infer
* fix bloom infer policy and flow
* revise bloom test
* fix bloom file path
* remove unused codes
* fix bloom modeling
* fix dir typo
* fix trivial
* fix policy
* clean pr
* trivial fix
* Revert "[infer] Add Bloom inference policy and replaced methods (#4512)" (#4552)
This reverts commit 17cfa57140.
* [Doc] Add colossal inference doc (#4549)
* create readme
* add readme.md
* fix typos
* [infer] Add Bloom inference policy and replaced methods (#4553)
* add bloom inference methods and policy
* enable pass BatchInferState from model forward
* revise bloom infer layers/policies
* add engine for inference (draft)
* add test for bloom infer
* fix bloom infer policy and flow
* revise bloom test
* fix bloom file path
* remove unused codes
* fix bloom modeling
* fix dir typo
* fix trivial
* fix policy
* clean pr
* trivial fix
* trivial
* Fix Bugs In Llama Model Forward (#4550)
* add kv cache memory manager
* add stateinfo during inference
* add
* add infer example
* finish
* finish
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* add inference test for llama
* fix conflict
* feature: add some new features for llama engine
* adapt colossalai triton interface
* Change the parent class of llama policy
* add nvtx
* move llama inference code to tensor_parallel
* fix __init__.py
* rm tensor_parallel
* fix: fix bugs in auto_policy.py
* fix:rm some unused codes
* mv colossalai/tpinference to colossalai/inference/tensor_parallel
* change __init__.py
* save change
* fix engine
* Bug fix: Fix hang
* remove llama_infer_engine.py
* bug fix: fix bugs about infer_state.is_context_stage
* remove policies
* fix: delete unused code
* fix: delete unused code
* remove unused code
* fix conflict
---------
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [doc] add colossal inference fig (#4554)
* create readme
* add readme.md
* fix typos
* upload fig
* [NFC] fix docstring for colossal inference (#4555)
Fix docstring and comments in kv cache manager and bloom modeling
* fix docstring in llama modeling (#4557)
* [Infer] check import vllm (#4559)
* change import vllm
* import apply_rotary_pos_emb
* change import location
* [DOC] add installation req (#4561)
* add installation req
* fix
* slight change
* remove empty
* [Feature] rms-norm transfer into inference llama.py (#4563)
* add installation req
* fix
* slight change
* remove empty
* add rmsnorm policy
* add
* clean codes
* [infer] Fix tp inference engine (#4564)
* fix engine prepare data
* add engine test
* use bloom for testing
* revise on test
* revise on test
* reset shardformer llama (#4569)
* [infer] Fix engine - tensors on different devices (#4570)
* fix diff device in engine
* [codefactor] Feature/colossal inference (#4579)
* code factors
* remove
* change coding (#4581)
* [doc] complete README of colossal inference (#4585)
* complete fig
* Update README.md
* [doc]update readme (#4586)
* update readme
* Update README.md
* bug fix: fix bugs in llama and bloom (#4588)
* [BUG FIX]Fix test engine in CI and non-vllm kernels llama forward (#4592)
* fix tests
* clean
* clean
* fix bugs
* add
* fix llama non-vllm kernels bug
* modify
* clean codes
* [Kernel]Rmsnorm fix (#4598)
* fix tests
* clean
* clean
* fix bugs
* add
* fix llama non-vllm kernels bug
* modify
* clean codes
* add triton rmsnorm
* delete vllm kernel flag
* [Bug Fix]Fix bugs in llama (#4601)
* fix tests
* clean
* clean
* fix bugs
* add
* fix llama non-vllm kernels bug
* modify
* clean codes
* bug fix: remove rotary_positions_ids
---------
Co-authored-by: cuiqing.li <lixx3527@gmail.com>
* [kernel] Add triton layer norm & replace norm for bloom (#4609)
* add layernorm for inference
* add test for layernorm kernel
* add bloom layernorm replacement policy
* trivial: path
* [Infer] Bug fix rotary embedding in llama (#4608)
* fix rotary embedding
* delete print
* fix init seq len bug
* rename pytest
* add benchmark for llama
* refactor codes
* delete useless code
* [bench] Add bloom inference benchmark (#4621)
* add bloom benchmark
* readme - update benchmark res
* trivial - uncomment for testing (#4622)
* [Infer] add check triton and cuda version for tests (#4627)
* fix rotary embedding
* delete print
* fix init seq len bug
* rename pytest
* add benchmark for llama
* refactor codes
* delete useless code
* add check triton and cuda
* Update sharder.py (#4629)
* [Inference] Hot fix some bugs and typos (#4632)
* fix
* fix test
* fix conflicts
* [typo]Comments fix (#4633)
* fallback
* fix comments
* bug fix: fix some bugs in test_llama and test_bloom (#4635)
* [Infer] delete benchmark in tests and fix bug for llama and bloom (#4636)
* fix rotary embedding
* delete print
* fix init seq len bug
* rename pytest
* add benchmark for llama
* refactor codes
* delete useless code
* add check triton and cuda
* delete benchmark and fix infer bugs
* delete benchmark for tests
* delete useless code
* delete benchmark function in utils
* [Fix] Revise TPInferEngine, inference tests and benchmarks (#4642)
* [Fix] revise TPInferEngine methods and inference tests
* fix llama/bloom infer benchmarks
* fix infer tests
* trivial fix: benchmarks
* trivial
* trivial: rm print
* modify utils filename for infer ops test (#4657)
* [Infer] Fix TPInferEngine init & inference tests, benchmarks (#4670)
* fix engine funcs
* TPInferEngine: receive shard config in init
* benchmarks: revise TPInferEngine init
* benchmarks: remove pytest decorator
* trivial fix
* use small model for tests
* [NFC] use args for infer benchmarks (#4674)
* revise infer default (#4683)
* [Fix] optimize/shard model in TPInferEngine init (#4684)
* remove using orig model in engine
* revise inference tests
* trivial: rename
---------
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
@@ -1,7 +1,14 @@
 from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
+from .triton import llama_context_attn_fwd, bloom_context_attn_fwd
+from .triton import softmax
+from .triton import copy_kv_cache_to_dest

 __all__ = [
     "LayerNorm",
     "FusedScaleMaskSoftmax",
     "MultiHeadAttention",
+    "llama_context_attn_fwd",
+    "bloom_context_attn_fwd",
+    "softmax",
+    "copy_kv_cache_to_dest",
 ]
5
colossalai/kernel/triton/__init__.py
Normal file
@@ -0,0 +1,5 @@
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .rms_norm import rmsnorm_forward
from .softmax import softmax
184
colossalai/kernel/triton/context_attention.py
Normal file
@@ -0,0 +1,184 @@
import torch
import math
try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")


if HAS_TRITON:
    # This function is modified from
    # https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
    @triton.jit
    def _context_flash_attention_kernel(
        Q, K, V, sm_scale,
        B_Start_Loc, B_Seqlen,
        TMP,
        alibi_ptr,
        Out,
        stride_qbs, stride_qh, stride_qd,
        stride_kbs, stride_kh, stride_kd,
        stride_vbs, stride_vh, stride_vd,
        stride_obs, stride_oh, stride_od,
        stride_tmp_b, stride_tmp_h, stride_tmp_s,
        # suggested set-up: 64, 128, 256, 512
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):

        # one program per (sequence in batch, attention head, block of BLOCK_M query rows)
        batch_id = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

        # get batch info
        cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
        cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
        block_start_loc = BLOCK_M * start_m

        load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
        q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)

        k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
        v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
        t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s

        # running row maximum (m_i), running softmax normalizer (l_i) and output accumulator
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        if alibi_ptr is not None:
            alibi_m = tl.load(alibi_ptr + cur_head)

        # 0 if this query block starts beyond the sequence, so the loop below is skipped
        block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale

            if alibi_ptr is not None:
                # ALiBi: subtract slope * (query position - key position)
                alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
                qk -= alibi_loc * alibi_m

            # causal mask: a query attends only to keys at the same or earlier positions
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            tl.store(t_ptrs, acc_scale)
            acc_scale = tl.load(t_ptrs)
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
        out_ptrs = Out + off_o
        tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
        return


    @torch.no_grad()
    def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk, "context process only supports equal query, key, value length"
        assert Lk == Lv, "context process only supports equal query, key, value length"
        assert Lk in {16, 32, 64, 128}

        sm_scale = 1.0 / math.sqrt(Lk)
        batch, head = b_seq_len.shape[0], q.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

        num_warps = 4 if Lk <= 64 else 8

        tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)

        _context_flash_attention_kernel[grid](
            q, k, v, sm_scale,
            b_start_loc, b_seq_len,
            tmp,
            alibi,
            o,
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            v.stride(0), v.stride(1), v.stride(2),
            o.stride(0), o.stride(1), o.stride(2),
            tmp.stride(0), tmp.stride(1), tmp.stride(2),
            # the block sizes are set manually here; a tuning config could be used to speed this up further
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return

    @torch.no_grad()
    def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk, "context process only supports equal query, key, value length"
        assert Lk == Lv, "context process only supports equal query, key, value length"
        assert Lk in {16, 32, 64, 128}

        sm_scale = 1.0 / math.sqrt(Lk)
        batch, head = b_seq_len.shape[0], q.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

        tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8
        # num_warps = 4
        _context_flash_attention_kernel[grid](
            q, k, v, sm_scale, b_start_loc, b_seq_len,
            tmp,
            None,
            o,
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            v.stride(0), v.stride(1), v.stride(2),
            o.stride(0), o.stride(1), o.stride(2),
            tmp.stride(0), tmp.stride(1), tmp.stride(2),
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return
|
||||
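The kernel above never materializes a full row of attention scores: it walks over the keys in blocks of BLOCK_N and keeps a running maximum (m_i), a running normalizer (l_i) and an already-normalized accumulator (acc). The pure-Python sketch below is an illustration only, not part of the commit. It mirrors the same update rule on plain lists so that it can be checked against a naive softmax. The helper names (`score`, `naive_causal_attention`, `blockwise_causal_attention`) and the scalar `slope` argument standing in for the per-head ALiBi value are assumptions made for this example.

```python
import math


def score(q_i, k_j, i, j, scale, slope):
    """Scaled dot product minus an ALiBi-style distance penalty (slope=0 disables it)."""
    return scale * sum(a * b for a, b in zip(q_i, k_j)) - slope * (i - j)


def naive_causal_attention(q, k, v, scale, slope=0.0):
    """Reference: one full softmax over keys j <= i for every query i."""
    out = []
    for i, q_i in enumerate(q):
        s = [score(q_i, k[j], i, j, scale, slope) for j in range(i + 1)]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(w[j] * v[j][d] for j in range(i + 1)) / z for d in range(len(v[0]))])
    return out


def blockwise_causal_attention(q, k, v, scale, slope=0.0, block_n=2):
    """Same result, visiting the keys block by block with running m_i, l_i and acc."""
    out = []
    for i, q_i in enumerate(q):
        m_i, l_i, acc = float("-inf"), 0.0, [0.0] * len(v[0])
        for start in range(0, i + 1, block_n):
            cols = range(start, min(start + block_n, i + 1))  # causal: only keys j <= i
            qk = [score(q_i, k[j], i, j, scale, slope) for j in cols]
            m_ij = max(qk)
            p = [math.exp(x - m_ij) for x in qk]
            l_ij = sum(p)
            m_new = max(m_i, m_ij)
            alpha = math.exp(m_i - m_new)   # rescales everything accumulated so far
            beta = math.exp(m_ij - m_new)   # rescales the current block
            l_new = alpha * l_i + beta * l_ij
            acc = [a * (l_i / l_new) * alpha for a in acc]
            for p_j, j in zip(p, cols):
                acc = [a + (p_j * beta / l_new) * v[j][d] for d, a in enumerate(acc)]
            m_i, l_i = m_new, l_new
        out.append(acc)
    return out
```

Because acc is renormalized by l_i / l_new at every step, no final division is needed after the loop, which matches how the kernel stores acc directly into Out.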
69
colossalai/kernel/triton/copy_kv_cache_dest.py
Normal file
@@ -0,0 +1,69 @@
import torch

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    @triton.jit
    def _fwd_copy_kv_cache_dest(
        kv_cache_ptr, dest_index_ptr,
        out,
        stride_k_bs,
        stride_k_h,
        stride_k_d,
        stride_o_bs,
        stride_o_h,
        stride_o_d,
        head_num,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_HEAD: tl.constexpr
    ):
        cur_index = tl.program_id(0)
        offs_h = tl.arange(0, BLOCK_HEAD)
        offs_d = tl.arange(0, BLOCK_DMODEL)

        dest_index = tl.load(dest_index_ptr + cur_index)

        cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
        k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets

        o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
        o_ptrs = out + dest_index * stride_o_bs + o_offsets

        k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
        tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
        return


    @torch.no_grad()
    def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
        seq_len = dest_index_ptr.shape[0]
        head_num = k_ptr.shape[1]
        head_dim = k_ptr.shape[2]
        assert head_num == out.shape[1], "head_num should be the same for k_ptr and out"
        assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"

        num_warps = 2

        _fwd_copy_kv_cache_dest[(seq_len,)](
            k_ptr, dest_index_ptr, out,
            k_ptr.stride(0),
            k_ptr.stride(1),
            k_ptr.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            head_num,
            BLOCK_DMODEL=head_dim,
            BLOCK_HEAD=triton.next_power_of_2(head_num),
            num_warps=num_warps,
            num_stages=2,
        )
        return
|
||||
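As a reading aid for the kernel above: each program takes one source token i and writes its [head_num, head_dim] slice to row dest_index[i] of the destination buffer, i.e. a scatter along the first axis. The following pure-Python sketch shows that behaviour on nested lists. It is an illustration under assumed shapes, not code from the commit, and the name `copy_kv_cache_reference` is hypothetical.

```python
def copy_kv_cache_reference(k, dest_index, out):
    """Scatter k[i] (a [head_num][head_dim] nested list) into out[dest_index[i]]."""
    for src, dst in enumerate(dest_index):
        out[dst] = [list(head) for head in k[src]]
    return out


# 4 cache slots, 1 head, head_dim 2; two new tokens go to slots 3 and 1
cache = [[[0.0, 0.0]] for _ in range(4)]
new_k = [[[1.0, 2.0]], [[3.0, 4.0]]]
copy_kv_cache_reference(new_k, [3, 1], cache)
```

Slots that do not appear in dest_index are left untouched, which is why the kernel needs no information about the total cache size.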
83
colossalai/kernel/triton/fused_layernorm.py
Normal file
@@ -0,0 +1,83 @@
import torch

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    # CREDITS: These functions are adapted from the Triton tutorial
    # https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html

    @triton.jit
    def _layer_norm_fwd_fused(
        X,    # pointer to the input
        Y,    # pointer to the output
        W,    # pointer to the weights
        B,    # pointer to the biases
        stride,    # how much to increase the pointer when moving by 1 row
        N,    # number of columns in X
        eps,    # epsilon to avoid division by zero
        BLOCK_SIZE: tl.constexpr,
    ):
        # Map the program id to the row of X and Y it should compute.
        row = tl.program_id(0)
        Y += row * stride
        X += row * stride
        # Compute mean
        mean = 0
        _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
            _mean += a
        mean = tl.sum(_mean, axis=0) / N
        # Compute variance
        _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
            x = tl.where(cols < N, x - mean, 0.)
            _var += x * x
        var = tl.sum(_var, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        # Normalize and apply linear transformation
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            w = tl.load(W + cols, mask=mask)
            b = tl.load(B + cols, mask=mask)
            x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
            x_hat = (x - mean) * rstd
            y = x_hat * w + b
            # Write output (note: the result is always cast to float16)
            tl.store(Y + cols, y.to(tl.float16), mask=mask)

    @torch.no_grad()
    def layer_norm(x, weight, bias, eps):
        # allocate output
        y = torch.empty_like(x)
        # reshape input data into 2D tensor
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        # heuristics for number of warps
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        # enqueue kernel
        _layer_norm_fwd_fused[(M,)](x_arg,
                                    y,
                                    weight,
                                    bias,
                                    x_arg.stride(0),
                                    N,
                                    eps,
                                    BLOCK_SIZE=BLOCK_SIZE,
                                    num_warps=num_warps)
        return y
|
||||
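For checking the fused kernel by hand, the per-row arithmetic and the launch heuristics of `layer_norm` can be written in a few lines of plain Python. This is a sketch for illustration only, in float64 rather than the kernel's float32/float16. The names `layer_norm_reference`, `next_power_of_2` (a stand-in for `triton.next_power_of_2`) and `launch_config` are assumptions introduced here and do not exist in the commit.

```python
import math


def layer_norm_reference(row, weight, bias, eps):
    """One row of layer norm: (x - mean) / sqrt(var + eps) * w + b, with biased variance."""
    n = len(row)
    mean = sum(row) / n
    var = sum((x - mean) ** 2 for x in row) / n
    rstd = 1.0 / math.sqrt(var + eps)
    return [(x - mean) * rstd * w + b for x, w, b in zip(row, weight, bias)]


def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def launch_config(n_cols, element_size):
    """Mirror of the wrapper's BLOCK_SIZE / num_warps heuristic."""
    max_fused_size = 65536 // element_size          # at most 64KB per row
    block_size = min(max_fused_size, next_power_of_2(n_cols))
    if n_cols > block_size:
        raise RuntimeError("feature dim too large for the fused kernel")
    num_warps = min(max(block_size // 256, 1), 8)
    return block_size, num_warps


print(launch_config(4096, 2))   # hidden size 4096 stored as 2-byte elements
```

With 2-byte elements the cap is 32768 columns, so a hidden size of 4096 fits in a single block and the three loops in the kernel each run once.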
72
colossalai/kernel/triton/rms_norm.py
Normal file
@@ -0,0 +1,72 @@
import torch

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")


if HAS_TRITON:
    # This kernel function is modified from
    # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
    @triton.jit
    def _rms_norm_fwd_fused(
        X,    # pointer to the input
        Y,    # pointer to the output
        W,    # pointer to the weights
        stride,    # how much to increase the pointer when moving by 1 row
        N,    # number of columns in X
        eps,    # epsilon to avoid division by zero
        BLOCK_SIZE: tl.constexpr,
    ):
        # Map the program id to the row of X and Y it should compute.
        row = tl.program_id(0)
        Y += row * stride
        X += row * stride
        # Compute variance (mean of squares; RMS norm does not subtract the mean)
        _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
            _var += x * x
        var = tl.sum(_var, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        # Normalize and apply linear transformation
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            w = tl.load(W + cols, mask=mask).to(tl.float32)
            x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
            x_hat = x * rstd
            y = x_hat * w
            # Write output (note: the result is always cast to float16)
            tl.store(Y + cols, y.to(tl.float16), mask=mask)


    def rmsnorm_forward(x, weight, eps):
        # allocate output
        y = torch.empty_like(x)
        # reshape input data into 2D tensor
        x_arg = x.view(-1, x.shape[-1])
        M, N = x_arg.shape
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        # print("BLOCK_SIZE:", BLOCK_SIZE)
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        # heuristics for number of warps
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        # print(BLOCK_SIZE, num_warps, "block_size, numwarps")
        # NOTE: the heuristic values above are overridden by fixed settings (BLOCK_SIZE = 16384)
        BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2
        num_warps = 8
        # enqueue kernel
        _rms_norm_fwd_fused[(M,)](x_arg, y, weight,
                                  x_arg.stride(0), N, eps,
                                  BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
        return y
|
||||
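The RMS norm kernel differs from the layer norm kernel in two ways: the mean is not subtracted and there is no bias term. A plain-Python sketch of the per-row computation, useful as a float64 reference when testing the Triton output, could look as follows. The function name `rms_norm_reference` is an assumption made for this example and is not part of the commit.

```python
import math


def rms_norm_reference(row, weight, eps):
    """One row of RMS norm: x / sqrt(mean(x^2) + eps) * w."""
    mean_sq = sum(x * x for x in row) / len(row)
    rstd = 1.0 / math.sqrt(mean_sq + eps)
    return [x * rstd * w for x, w in zip(row, weight)]


# with unit weights and eps = 0 the output has a mean square of 1
y = rms_norm_reference([3.0, 4.0], [1.0, 1.0], 0.0)
print(sum(t * t for t in y) / len(y))
```

When comparing against the kernel, keep in mind that the kernel accumulates in float32 and stores float16, so a tolerance appropriate for half precision is needed.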
93
colossalai/kernel/triton/rotary_embedding_kernel.py
Normal file
@@ -0,0 +1,93 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import torch
import triton
import triton.language as tl


@triton.jit
def _rotary_kernel(
    q,
    Cos,
    Sin,
    q_bs_stride,
    q_h_stride,
    q_d_stride,
    cos_bs_stride,
    cos_d_stride,
    total_len,
    HEAD_NUM: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    BLOCK_SEQ: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    current_head_index = tl.program_id(0)
    current_seq_index = tl.program_id(1)

    current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
    current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)

    dim_range0 = tl.arange(0, HEAD_DIM // 2)
    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

    off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
        None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride
    off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
        None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride

    off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride

    q0 = tl.load(q + off_q0,
                 mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
                 other=0.0)
    q1 = tl.load(q + off_q1,
                 mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
                 other=0.0)

    cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
    sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)

    out0 = q0 * cos - q1 * sin
    out1 = q0 * sin + q1 * cos

    tl.store(q + off_q0,
             out0,
             mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
    tl.store(q + off_q1,
             out1,
             mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))

    return


@torch.no_grad()
def rotary_embedding_fwd(q, cos, sin):
    total_len = q.shape[0]
    head_num = q.shape[1]
    head_dim = q.shape[2]
    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
    BLOCK_HEAD = 4
    BLOCK_SEQ = 32
    grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
    if head_dim >= 128:
        num_warps = 8
    else:
        num_warps = 4

    _rotary_kernel[grid](
        q,
        cos,
        sin,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        cos.stride(0),
        cos.stride(1),
        total_len,
        HEAD_NUM=head_num,
        BLOCK_HEAD=BLOCK_HEAD,
        BLOCK_SEQ=BLOCK_SEQ,
        HEAD_DIM=head_dim,
        num_warps=num_warps,
        num_stages=1,
    )
    return
|
||||
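The rotary kernel above splits each head vector into a first half (q0) and a second half (q1), rotates every pair (q0[d], q1[d]) by the angle encoded in cos[d] and sin[d], and writes the result back in place. The sketch below reproduces that arithmetic for a single head vector in plain Python. It is illustrative only: the name `rotary_reference` is hypothetical, and cos and sin are assumed to hold HEAD_DIM // 2 values for the token's position, matching the dim_range0 indexing in the kernel.

```python
import math


def rotary_reference(x, cos, sin):
    """Rotate pairs (x[d], x[d + half]) by the angle given through cos[d], sin[d]."""
    half = len(x) // 2
    x0, x1 = x[:half], x[half:]
    out0 = [a * c - b * s for a, b, c, s in zip(x0, x1, cos, sin)]
    out1 = [a * s + b * c for a, b, c, s in zip(x0, x1, cos, sin)]
    return out0 + out1


# head_dim = 4: rotate the pair (x[0], x[2]) by 90 degrees and leave (x[1], x[3]) alone
theta = [math.pi / 2, 0.0]
rotated = rotary_reference([1.0, 2.0, 3.0, 4.0],
                           [math.cos(t) for t in theta],
                           [math.sin(t) for t in theta])
```

Since each pair undergoes a plain 2D rotation, the vector norm is preserved, which gives a cheap sanity check for the kernel output.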
@@ -11,10 +11,11 @@ except ImportError:

 if HAS_TRITON:
     from .qkv_matmul_kernel import qkv_gemm_4d_kernel
-    from .softmax_kernel import softmax_kernel
+    from .softmax import softmax_kernel

-    def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float):
-        r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
+    def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                                              input_mask: torch.Tensor, scale: float):
+        r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
         Args:
             q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
             k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
@@ -36,39 +37,49 @@ if HAS_TRITON:
         # head_size * num_of_head
         d_model = q.shape[-1] * q.shape[-2]

-        score_output = torch.empty(
-            (batches, H, M, N), device=q.device, dtype=q.dtype)
+        score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)

         grid = lambda meta: (
             batches,
             H,
-            triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
-            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+            triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
         )

         qkv_gemm_4d_kernel[grid](
-            q, k, score_output,
-            M, N, K,
-            q.stride(0), q.stride(2), q.stride(1), q.stride(3),
-            k.stride(0), k.stride(2), k.stride(3), k.stride(1),
-            score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
+            q,
+            k,
+            score_output,
+            M,
+            N,
+            K,
+            q.stride(0),
+            q.stride(2),
+            q.stride(1),
+            q.stride(3),
+            k.stride(0),
+            k.stride(2),
+            k.stride(3),
+            k.stride(1),
+            score_output.stride(0),
+            score_output.stride(1),
+            score_output.stride(2),
+            score_output.stride(3),
             scale=scale,
-            # currently manually setting, later on we can use auto-tune config to match best setting
+            # currently manually setting, later on we can use auto-tune config to match best setting
             BLOCK_SIZE_M=64,
             BLOCK_SIZE_N=32,
             BLOCK_SIZE_K=32,
             GROUP_SIZE_M=8,
         )
-
-        softmax_output = torch.empty(
-            score_output.shape, device=score_output.device, dtype=score_output.dtype)
+
+        softmax_output = torch.empty(score_output.shape, device=score_output.device, dtype=score_output.dtype)
         score_output_shape = score_output.shape

         score_output = score_output.view(-1, score_output.shape[-1])
         n_rows, n_cols = score_output.shape

         if n_rows <= 350000:
-
+
             block_size = max(triton.next_power_of_2(n_cols), 2)
             num_warps = 4
             if block_size >= 4096:
@@ -78,37 +89,39 @@ if HAS_TRITON:
             else:
                 num_warps = 4

-            softmax_kernel[(n_rows, )](
+            softmax_kernel[(n_rows,)](
                 softmax_output,
                 score_output,
                 score_output.stride(0),
                 n_cols,
-                mask_ptr = input_mask,
+                mask_ptr=input_mask,
                 num_warps=num_warps,
                 BLOCK_SIZE=block_size,
             )

         else:
-            #TODO: change softmax kernel functions to make it suitable for large size dimension
+            # NOTE: change softmax kernel functions to make it suitable for large size dimension
             softmax_output = torch.nn.functional.softmax(score_output, dim=-1)
         softmax_output = softmax_output.view(*score_output_shape)

         batches, H, M, K = softmax_output.shape
         N = v.shape[-1]

-        output = torch.empty(
-            (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)
+        output = torch.empty((batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)

         grid = lambda meta: (
             batches,
             H,
-            triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
-            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+            triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
         )

         qkv_gemm_4d_kernel[grid](
-            softmax_output, v, output,
-            M, N, K,
+            softmax_output,
+            v,
+            output,
+            M,
+            N,
+            K,
             softmax_output.stride(0),
             softmax_output.stride(1),
             softmax_output.stride(2),
@@ -129,7 +142,6 @@ if HAS_TRITON:
|
||||
)
|
||||
return output.view(batches, -1, d_model)
|
||||
|
||||
|
||||
def self_attention_compute_using_triton(qkv,
|
||||
input_mask,
|
||||
layer_past,
|
||||
@@ -152,58 +164,6 @@ if HAS_TRITON:
|
||||
k = k.view(batches, -1, num_of_heads, head_size)
|
||||
v = v.view(batches, -1, num_of_heads, head_size)
|
||||
|
||||
data_output_triton = self_attention_forward_without_fusion(
|
||||
q, k, v, input_mask, scale)
|
||||
data_output_triton = self_attention_forward_without_fusion(q, k, v, input_mask, scale)
|
||||
|
||||
return data_output_triton
|
||||
|
||||
|
||||
def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
|
||||
if mask is not None:
|
||||
assert input.shape[-1] == mask.shape[-1], "the last dimensions should be the same for input and mask"
|
||||
assert dim == -1 or dim == len(input.shape) - 1, "currently softmax layer only supports the last dimension"
|
||||
|
||||
hidden_dim = input.shape[-1]
|
||||
output = torch.empty_like(input)
|
||||
input = input.view(-1, hidden_dim)
|
||||
if mask is not None:
|
||||
mask = mask.view(-1, hidden_dim)
|
||||
assert input.shape[0] == mask.shape[0], "the first dimension of mask and input should be the same"
|
||||
|
||||
num_rows, num_cols = input.shape
|
||||
block_size = max(triton.next_power_of_2(num_cols), 2)
|
||||
num_warps = 16
|
||||
if block_size >= 4096:
|
||||
num_warps = 16
|
||||
elif block_size >= 2048:
|
||||
num_warps = 8
|
||||
else:
|
||||
num_warps = 4
|
||||
|
||||
if num_rows <= 350000:
|
||||
grid = (num_rows,)
|
||||
softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
|
||||
else:
|
||||
grid = lambda meta: ()
|
||||
|
||||
grid = lambda meta: (
|
||||
triton.cdiv(num_rows, meta["BLOCK_M"]),
|
||||
)
|
||||
|
||||
BLOCK_M = 32
|
||||
if block_size >= 4096:
|
||||
BLOCK_M = 4
|
||||
elif block_size >= 2048:
|
||||
BLOCK_M = 8
|
||||
|
||||
softmax_kernel_2[grid](output_ptr = output,
|
||||
input_ptr = input,
|
||||
row_stride = input.stride(0),
|
||||
n_rows = num_rows,
|
||||
n_cols = num_cols,
|
||||
mask_ptr = mask,
|
||||
# currently manually setting up size
|
||||
BLOCK_M = 32,
|
||||
BLOCK_SIZE = block_size)
|
||||
|
||||
return output
|
||||
96
colossalai/kernel/triton/softmax.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import torch
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
if HAS_TRITON:
|
||||
'''
|
||||
softmax kernel is modified based on
|
||||
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
|
||||
'''
|
||||
@triton.jit
|
||||
def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
|
||||
r""" the kernel function for implementing softmax operator
|
||||
Args:
|
||||
output_ptr: the output after finishing softmax operation, (N, hidden_dim)
|
||||
input_ptr: the tensor of input, shape should be (N, hidden_dim)
|
||||
n_cols(tl.constexpr): the number of cols of input
|
||||
BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
|
||||
"""
|
||||
row_idx = tl.program_id(0)
|
||||
row_start_ptr = input_ptr + row_idx * row_stride
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
input_ptrs = row_start_ptr + col_offsets
|
||||
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
|
||||
row_minus_max = row - tl.max(row, axis=0)
|
||||
|
||||
if mask_ptr is not None:
|
||||
# load mask into SRAM
|
||||
mask_ptrs = (mask_ptr + (row_idx * row_stride)) + col_offsets
|
||||
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
|
||||
|
||||
# update
|
||||
row_minus_max = row_minus_max + mask
|
||||
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
output_row_start_ptr = output_ptr + row_idx * row_stride
|
||||
output_ptrs = output_row_start_ptr + col_offsets
|
||||
# Write back output to DRAM
|
||||
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
|
||||
|
||||
|
||||
def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
|
||||
if mask is not None:
|
||||
assert input.shape[-1] == mask.shape[-1], "the last dimensions should be the same for input and mask"
|
||||
assert dim == -1 or dim == len(input.shape) - 1, "currently softmax layer only supports the last dimension"
|
||||
|
||||
hidden_dim = input.shape[-1]
|
||||
output = torch.empty_like(input)
|
||||
input = input.view(-1, hidden_dim)
|
||||
if mask is not None:
|
||||
mask = mask.view(-1, hidden_dim)
|
||||
assert input.shape[0] == mask.shape[0], "the first dimension of mask and input should be the same"
|
||||
|
||||
num_rows, num_cols = input.shape
|
||||
block_size = max(triton.next_power_of_2(num_cols), 2)
|
||||
num_warps = 16
|
||||
if block_size >= 4096:
|
||||
num_warps = 16
|
||||
elif block_size >= 2048:
|
||||
num_warps = 8
|
||||
else:
|
||||
num_warps = 4
|
||||
|
||||
if num_rows <= 350000:
|
||||
grid = (num_rows,)
|
||||
softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
|
||||
else:
|
||||
grid = lambda meta: ()
|
||||
|
||||
grid = lambda meta: (
|
||||
triton.cdiv(num_rows, meta["BLOCK_M"]),
|
||||
)
|
||||
|
||||
BLOCK_M = 32
|
||||
if block_size >= 4096:
|
||||
BLOCK_M = 4
|
||||
elif block_size >= 2048:
|
||||
BLOCK_M = 8
|
||||
|
||||
softmax_kernel[grid](output_ptr = output,
|
||||
input_ptr = input,
|
||||
row_stride = input.stride(0),
|
||||
n_rows = num_rows,
|
||||
n_cols = num_cols,
|
||||
mask_ptr = mask,
|
||||
# currently manually setting up size
|
||||
BLOCK_M = 32,
|
||||
BLOCK_SIZE = block_size)
|
||||
|
||||
return output
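Review note: the per-row arithmetic of `softmax_kernel` can be checked against a plain-Python reference. The sketch below is not part of the PR. It assumes the mask is additive (0.0 keeps a position, `-inf` or a large negative value removes it), which is how the kernel applies it in `row_minus_max + mask`. The name `masked_softmax_row` is hypothetical.

```python
import math


def masked_softmax_row(row, mask=None):
    """Reference for one row: subtract the row max, add the mask, exponentiate, normalize."""
    # Like the kernel, take the max over the raw row before the mask is applied.
    row_max = max(row)
    shifted = [x - row_max for x in row]
    if mask is not None:
        # Additive mask (assumption): 0.0 keeps a position, -inf removes it.
        shifted = [s + m for s, m in zip(shifted, mask)]
    numerator = [math.exp(s) for s in shifted]
    denominator = sum(numerator)
    # A fully masked row would give denominator == 0; this sketch does not handle that case.
    return [n / denominator for n in numerator]


probs = masked_softmax_row([1.0, 2.0, 3.0], mask=[0.0, 0.0, float("-inf")])
```

Comparing such a reference against the Triton output on small inputs is one way to catch indexing mistakes in the mask path.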
|
||||
@@ -1,44 +0,0 @@
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
if HAS_TRITON:
|
||||
'''
|
||||
softmax kernel is modified based on
|
||||
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
|
||||
'''
|
||||
@triton.jit
|
||||
def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
|
||||
r""" the kernel function for implementing softmax operator
|
||||
Args:
|
||||
output_ptr: the output after finishing softmax operation, (N, hidden_dim)
|
||||
input_ptr: the tensor of input, shape should be (N, hidden_dim)
|
||||
n_cols(tl.constexpr): the number of cols of input
|
||||
BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
|
||||
"""
|
||||
row_idx = tl.program_id(0)
|
||||
row_start_ptr = input_ptr + row_idx * row_stride
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
input_ptrs = row_start_ptr + col_offsets
|
||||
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
|
||||
row_minus_max = row - tl.max(row, axis=0)
|
||||
|
||||
if mask_ptr is not None:
|
||||
# load mask into SRAM
|
||||
mask_ptrs = (mask_ptr + (row_idx * row_stride)) + col_offsets
|
||||
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
|
||||
|
||||
# update
|
||||
row_minus_max = row_minus_max + mask
|
||||
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
output_row_start_ptr = output_ptr + row_idx * row_stride
|
||||
output_ptrs = output_row_start_ptr + col_offsets
|
||||
# Write back output to DRAM
|
||||
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
|
||||
333
colossalai/kernel/triton/token_attention_kernel.py
Normal file
@@ -0,0 +1,333 @@
|
||||
# Adapted from ModelTC https://github.com/ModelTC/lightllm
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
if HAS_TRITON:
|
||||
|
||||
@triton.jit
|
||||
def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
|
||||
attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride,
|
||||
q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride,
|
||||
attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr):
|
||||
current_batch = tl.program_id(0)
|
||||
current_head = tl.program_id(1)
|
||||
start_n = tl.program_id(2)
|
||||
|
||||
offs_d = tl.arange(0, HEAD_DIM)
|
||||
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
|
||||
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
|
||||
|
||||
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
|
||||
current_batch_end_index = max_kv_cache_len
|
||||
|
||||
off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
|
||||
|
||||
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
block_start_index = start_n * BLOCK_N
|
||||
block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)
|
||||
|
||||
for start_mark in range(0, block_mask, 1):
|
||||
q = tl.load(Q + off_q + start_mark)
|
||||
offs_n_new = current_batch_start_index + offs_n
|
||||
k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
|
||||
mask=offs_n_new < current_batch_end_index,
|
||||
other=0)
|
||||
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
|
||||
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
|
||||
att_value = tl.sum(q[None, :] * k, 1)
|
||||
att_value *= sm_scale
|
||||
off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
|
||||
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
|
||||
return
|
||||
|
||||
@triton.jit
|
||||
def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen,
|
||||
max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride,
|
||||
q_batch_stride, q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride,
|
||||
k_head_dim_stride, attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr):
|
||||
current_batch = tl.program_id(0)
|
||||
current_head = tl.program_id(1)
|
||||
start_n = tl.program_id(2)
|
||||
|
||||
offs_d = tl.arange(0, HEAD_DIM)
|
||||
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
|
||||
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
|
||||
|
||||
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
|
||||
current_batch_end_index = max_kv_cache_len
|
||||
|
||||
off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
|
||||
|
||||
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
block_start_index = start_n * BLOCK_N
|
||||
block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)
|
||||
|
||||
for start_mark in range(0, block_mask, 1):
|
||||
alibi_m = tl.load(alibi + current_head)
|
||||
q = tl.load(Q + off_q + start_mark)
|
||||
offs_n_new = current_batch_start_index + offs_n
|
||||
k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
|
||||
mask=offs_n_new < current_batch_end_index,
|
||||
other=0)
|
||||
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
|
||||
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
|
||||
att_value = tl.sum(q[None, :] * k, 1)
|
||||
att_value *= sm_scale
|
||||
att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n)
|
||||
off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
|
||||
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def token_attn_fwd_1(q,
|
||||
k,
|
||||
attn_out,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
max_kv_cache_len,
|
||||
alibi=None):
|
||||
BLOCK = 32
|
||||
# shape constraints
|
||||
q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
|
||||
assert q_head_dim == k_head_dim
|
||||
assert k_head_dim in {16, 32, 64, 128}
|
||||
sm_scale = 1.0 / (k_head_dim**0.5)
|
||||
|
||||
batch, head_num = kv_cache_loc.shape[0], q.shape[1]
|
||||
|
||||
grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK))
|
||||
|
||||
num_warps = 4 if k_head_dim <= 64 else 8
|
||||
num_warps = 2
|
||||
|
||||
if alibi is not None:
|
||||
_token_attn_1_alibi_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
sm_scale,
|
||||
alibi,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
max_kv_cache_len,
|
||||
attn_out,
|
||||
kv_cache_loc.stride(0),
|
||||
kv_cache_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
attn_out.stride(0),
|
||||
attn_out.stride(1),
|
||||
HEAD_DIM=k_head_dim,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
else:
|
||||
_token_attn_1_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
sm_scale,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
max_kv_cache_len,
|
||||
attn_out,
|
||||
kv_cache_loc.stride(0),
|
||||
kv_cache_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
attn_out.stride(0),
|
||||
attn_out.stride(1),
|
||||
HEAD_DIM=k_head_dim,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
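Review note: `token_attn_fwd_1` launches `triton.cdiv(max_kv_cache_len, BLOCK)` programs along the third grid axis for every (batch, head) pair, and each program skips its work when its block starts at or beyond the request's sequence length (the `block_mask` test). A minimal sketch of that bookkeeping in plain Python follows. `cdiv` is a stand-in for `triton.cdiv`, and `active_blocks` is a hypothetical helper that is not in the PR.

```python
def cdiv(a, b):
    """Ceiling division for positive integers; stand-in for triton.cdiv."""
    return (a + b - 1) // b


def active_blocks(seq_len, max_kv_cache_len, block_n=32):
    """List the block indices along grid axis 2 that actually compute scores."""
    num_blocks = cdiv(max_kv_cache_len, block_n)
    # Mirrors the kernel's test: a block works only if it starts inside the sequence.
    return [start_n for start_n in range(num_blocks) if start_n * block_n < seq_len]


# A request of 40 tokens in a batch whose longest request has 100 tokens.
blocks = active_blocks(seq_len=40, max_kv_cache_len=100)
```

Short requests in a batch with one long request therefore leave most of their programs idle, which is a known cost of sizing the grid by `max_kv_cache_len`.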
|
||||
|
||||
@triton.jit
|
||||
def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out,
|
||||
logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
current_batch = tl.program_id(0)
|
||||
current_head = tl.program_id(1)
|
||||
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
|
||||
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
|
||||
|
||||
row = tl.load(softmax_logics + current_head * logics_head_dim_stride +
|
||||
(current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
|
||||
mask=col_offsets < current_batch_seq_len,
|
||||
other=-float('inf')).to(tl.float32)
|
||||
|
||||
row_minus_max = row - tl.max(row, axis=0)
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
|
||||
tl.store(softmax_prob_out + current_head * prob_head_dim_stride +
|
||||
(current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
|
||||
softmax_output,
|
||||
mask=col_offsets < current_batch_seq_len)
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len):
|
||||
BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len)
|
||||
batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0]
|
||||
|
||||
num_warps = 4
|
||||
if BLOCK_SIZE >= 2048:
|
||||
num_warps = 8
|
||||
if BLOCK_SIZE >= 4096:
|
||||
num_warps = 16
|
||||
|
||||
_token_attn_softmax_fwd[(batch, head_num)](
|
||||
softmax_logics,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
softmax_prob_out,
|
||||
softmax_logics.stride(0),
|
||||
softmax_logics.stride(1),
|
||||
softmax_prob_out.stride(0),
|
||||
softmax_prob_out.stride(1),
|
||||
num_warps=num_warps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
return
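Review note: the launch configuration in `token_attn_softmax_fwd` rounds the row length up to a power of two and then picks `num_warps` from fixed thresholds. The sketch below reproduces that selection without Triton so it can be inspected on a CPU-only machine. `next_power_of_2` here is a stand-in for `triton.next_power_of_2`, and `pick_num_warps` is a hypothetical name.

```python
def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1; stand-in for triton.next_power_of_2."""
    return 1 << (n - 1).bit_length()


def pick_num_warps(block_size):
    """Same thresholds as the wrapper: 4 by default, 8 from 2048, 16 from 4096."""
    if block_size >= 4096:
        return 16
    if block_size >= 2048:
        return 8
    return 4


block_size = next_power_of_2(5000)
num_warps = pick_num_warps(block_size)
```

Because one program loads a whole row of `BLOCK_SIZE` elements, very long caches make each program heavy; the thresholds are hand-tuned values from the diff, not the result of auto-tuning.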
|
||||
|
||||
@triton.jit
|
||||
def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
|
||||
kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride,
|
||||
v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride,
|
||||
attn_out_head_stride, attn_out_head_dim_stride, HEAD_DIM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr):
|
||||
current_batch = tl.program_id(0)
|
||||
current_head = tl.program_id(1)
|
||||
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, HEAD_DIM)
|
||||
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
|
||||
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
|
||||
current_batch_end_index = current_batch_seq_len
|
||||
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
|
||||
|
||||
v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
|
||||
p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride
|
||||
v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride
|
||||
|
||||
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
|
||||
for start_n in range(0, current_batch_seq_len, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride,
|
||||
mask=(start_n + offs_n) < current_batch_seq_len,
|
||||
other=0.0)
|
||||
v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
|
||||
mask=(start_n + offs_n) < current_batch_seq_len,
|
||||
other=0.0)
|
||||
v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride,
|
||||
mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
|
||||
other=0.0)
|
||||
acc += tl.sum(p_value[:, None] * v_value, 0)
|
||||
|
||||
acc = acc.to(tl.float16)
|
||||
off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride
|
||||
out_ptrs = attn_out + off_o
|
||||
tl.store(out_ptrs, acc)
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len):
|
||||
if triton.__version__ >= "2.1.0":
|
||||
BLOCK = 128
|
||||
else:
|
||||
BLOCK = 64
|
||||
batch, head = kv_cache_loc.shape[0], v.shape[1]
|
||||
grid = (batch, head)
|
||||
num_warps = 4
|
||||
dim = v.shape[-1]
|
||||
|
||||
_token_attn_2_kernel[grid](
|
||||
prob,
|
||||
v,
|
||||
attn_out,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seqlen,
|
||||
max_kv_cache_len,
|
||||
kv_cache_loc.stride(0),
|
||||
kv_cache_loc.stride(1),
|
||||
prob.stride(0),
|
||||
prob.stride(1),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
attn_out.stride(0),
|
||||
attn_out.stride(1),
|
||||
attn_out.stride(2),
|
||||
HEAD_DIM=dim,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
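Review note: `token_attn_fwd_2` selects `BLOCK` with `triton.__version__ >= "2.1.0"`, which compares strings character by character rather than numerically. The sketch below shows one stdlib-only alternative that compares numeric tuples. `version_tuple` and `pick_block` are hypothetical helpers, not part of the PR; `packaging.version.parse` would be a more thorough option where that package is available.

```python
import re


def version_tuple(version):
    """Parse the leading numeric components, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    # A missing patch component is treated as 0.
    return tuple(int(part) if part else 0 for part in match.groups())


def pick_block(triton_version):
    """Same choice as the wrapper (128 from 2.1.0 on, else 64), compared numerically."""
    return 128 if version_tuple(triton_version) >= (2, 1, 0) else 64


# String comparison misorders a two-digit major version; tuple comparison does not.
string_says_newer = "10.0.0" >= "2.1.0"
tuple_says_newer = version_tuple("10.0.0") >= (2, 1, 0)
```

The string form happens to give the right answer for the 2.x and 3.x releases, so this is a robustness concern rather than a current failure.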
|
||||
|
||||
@torch.no_grad()
|
||||
def token_attention_fwd(q,
|
||||
k,
|
||||
v,
|
||||
attn_out,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seq_len,
|
||||
max_len_in_batch,
|
||||
alibi=None):
|
||||
head_num = k.shape[1]
|
||||
batch_size = kv_cache_seq_len.shape[0]
|
||||
calcu_shape1 = (batch_size, head_num, k.shape[2])
|
||||
total_token_num = k.shape[0]
|
||||
|
||||
att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
|
||||
|
||||
token_attn_fwd_1(q.view(calcu_shape1),
|
||||
k,
|
||||
att_m_tensor,
|
||||
kv_cache_loc,
|
||||
kv_cache_start_loc,
|
||||
kv_cache_seq_len,
|
||||
max_len_in_batch,
|
||||
alibi=alibi)
|
||||
|
||||
prob = torch.empty_like(att_m_tensor)
|
||||
|
||||
token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
|
||||
att_m_tensor = None
|
||||
token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len,
|
||||
max_len_in_batch)
|
||||
|
||||
prob = None
|
||||
|
||||
return
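Review note: `token_attention_fwd` runs three stages for each decoding step: scaled q·k scores gathered through `kv_cache_loc`, a softmax over each request's scores, and a probability-weighted sum of the cached values. The sketch below is a plain-Python reference for a single request and a single head, written for checking the kernels on small inputs. It is not the PR's implementation. `token_attention_reference`, `slots`, and the list-of-lists cache layout are assumptions made for illustration.

```python
import math


def token_attention_reference(q, k_cache, v_cache, slots, alibi_slope=None):
    """One request, one head: q is [head_dim]; caches are indexed by slot number."""
    head_dim = len(q)
    sm_scale = 1.0 / math.sqrt(head_dim)
    seq_len = len(slots)

    # Stage 1: scaled dot product against each cached key, found via its slot index.
    scores = []
    for pos, slot in enumerate(slots):
        score = sum(qi * ki for qi, ki in zip(q, k_cache[slot])) * sm_scale
        if alibi_slope is not None:
            # ALiBi penalty grows with distance from the newest token.
            score -= alibi_slope * (seq_len - 1 - pos)
        scores.append(score)

    # Stage 2: numerically stable softmax over the request's scores.
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Stage 3: weighted sum of the cached values.
    out = [0.0] * head_dim
    for p, slot in zip(probs, slots):
        for d in range(head_dim):
            out[d] += p * v_cache[slot][d]
    return probs, out


k_cache = [[1.0, 0.0], [9.0, 9.0], [0.0, 1.0], [9.0, 9.0]]
v_cache = [[10.0, 0.0], [0.0, 0.0], [0.0, 10.0], [0.0, 0.0]]
# The request's tokens live in slots 2 and 0, in that order; slots 1 and 3 belong to others.
probs, out = token_attention_reference([1.0, 0.0], k_cache, v_cache, slots=[2, 0])
probs_alibi, _ = token_attention_reference([1.0, 0.0], k_cache, v_cache, slots=[2, 0], alibi_slope=100.0)
```

The slot indirection is the point of the design: a request's keys and values need not be contiguous in the cache, so the kernels gather them through `kv_cache_loc` instead of slicing.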
|
||||