Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-04-28 03:43:01 +00:00
* [infer] Infer/llama demo (#4503)
* add
* add infer example
* finish
* finish
* stash
* fix
* [Kernels] add inference token attention kernel (#4505)
* add token forward
* fix tests
* fix comments
* add try import triton
* add adapted license
* add tests check
* [Kernels] add necessary kernels (llama & bloom) for attention forward and kv-cache manager (#4485)
* added _vllm_rms_norm
* change place
* added tests
* added tests
* modify
* adding kernels
* added tests:
* adding kernels
* modify
* added
* updating kernels
* adding tests
* added tests
* kernel change
* submit
* modify
* added
* edit comments
* change name
* change comments and fix import
* add
* added
* combine codes (#4509)
* [feature] add KV cache manager for llama & bloom inference (#4495)
* add kv cache memory manager
* add stateinfo during inference
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* file dir change
* [Bug FIx] import llama context ops fix (#4524)
* added _vllm_rms_norm
* change place
* added tests
* added tests
* modify
* adding kernels
* added tests:
* adding kernels
* modify
* added
* updating kernels
* adding tests
* added tests
* kernel change
* submit
* modify
* added
* edit comments
* change name
* change comments and fix import
* add
* added
* fix
* add ops into init.py
* add
* [Infer] Add TPInferEngine and fix file path (#4532)
* add engine for TP inference
* move file path
* update path
* fix TPInferEngine
* remove unused file
* add engine test demo
* revise TPInferEngine
* fix TPInferEngine, add test
* fix
* Add Inference test for llama (#4508)
* add kv cache memory manager
* add stateinfo during inference
* add
* add infer example
* finish
* finish
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* add inference test for llama
* fix conflict
* feature: add some new features for llama engine
* adapt colossalai triton interface
* Change the parent class of llama policy
* add nvtx
* move llama inference code to tensor_parallel
* fix __init__.py
* rm tensor_parallel
* fix: fix bugs in auto_policy.py
* fix:rm some unused codes
* mv colossalai/tpinference to colossalai/inference/tensor_parallel
* change __init__.py
* save change
* fix engine
* Bug fix: Fix hang
* remove llama_infer_engine.py
---------
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [infer] Add Bloom inference policy and replaced methods (#4512)
* add bloom inference methods and policy
* enable pass BatchInferState from model forward
* revise bloom infer layers/policies
* add engine for inference (draft)
* add test for bloom infer
* fix bloom infer policy and flow
* revise bloom test
* fix bloom file path
* remove unused codes
* fix bloom modeling
* fix dir typo
* fix trivial
* fix policy
* clean pr
* trivial fix
* Revert "[infer] Add Bloom inference policy and replaced methods (#4512)" (#4552)
This reverts commit 17cfa57140.
* [Doc] Add colossal inference doc (#4549)
* create readme
* add readme.md
* fix typos
* [infer] Add Bloom inference policy and replaced methods (#4553)
* add bloom inference methods and policy
* enable pass BatchInferState from model forward
* revise bloom infer layers/policies
* add engine for inference (draft)
* add test for bloom infer
* fix bloom infer policy and flow
* revise bloom test
* fix bloom file path
* remove unused codes
* fix bloom modeling
* fix dir typo
* fix trivial
* fix policy
* clean pr
* trivial fix
* trivial
* Fix Bugs In Llama Model Forward (#4550)
* add kv cache memory manager
* add stateinfo during inference
* add
* add infer example
* finish
* finish
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* add inference test for llama
* fix conflict
* feature: add some new features for llama engine
* adapt colossalai triton interface
* Change the parent class of llama policy
* add nvtx
* move llama inference code to tensor_parallel
* fix __init__.py
* rm tensor_parallel
* fix: fix bugs in auto_policy.py
* fix:rm some unused codes
* mv colossalai/tpinference to colossalai/inference/tensor_parallel
* change __init__.py
* save change
* fix engine
* Bug fix: Fix hang
* remove llama_infer_engine.py
* bug fix: fix bugs about infer_state.is_context_stage
* remove policies
* fix: delete unused code
* fix: delete unused code
* remove unused code
* fix conflict
---------
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [doc] add colossal inference fig (#4554)
* create readme
* add readme.md
* fix typos
* upload fig
* [NFC] fix docstring for colossal inference (#4555)
Fix docstring and comments in kv cache manager and bloom modeling
* fix docstring in llama modeling (#4557)
* [Infer] check import vllm (#4559)
* change import vllm
* import apply_rotary_pos_emb
* change import location
* [DOC] add installation req (#4561)
* add installation req
* fix
* slight change
* remove empty
* [Feature] rms-norm transfer into inference llama.py (#4563)
* add installation req
* fix
* slight change
* remove empty
* add rmsnorm policy
* add
* clean codes
* [infer] Fix tp inference engine (#4564)
* fix engine prepare data
* add engine test
* use bloom for testing
* revise on test
* revise on test
* reset shardformer llama (#4569)
* [infer] Fix engine - tensors on different devices (#4570)
* fix diff device in engine
* [codefactor] Feature/colossal inference (#4579)
* code factors
* remove
* change coding (#4581)
* [doc] complete README of colossal inference (#4585)
* complete fig
* Update README.md
* [doc]update readme (#4586)
* update readme
* Update README.md
* bug fix: fix bugs in llama and bloom (#4588)
* [BUG FIX]Fix test engine in CI and non-vllm kernels llama forward (#4592)
* fix tests
* clean
* clean
* fix bugs
* add
* fix llama non-vllm kernels bug
* modify
* clean codes
* [Kernel]Rmsnorm fix (#4598)
* fix tests
* clean
* clean
* fix bugs
* add
* fix llama non-vllm kernels bug
* modify
* clean codes
* add triton rmsnorm
* delete vllm kernel flag
* [Bug Fix]Fix bugs in llama (#4601)
* fix tests
* clean
* clean
* fix bugs
* add
* fix llama non-vllm kernels bug
* modify
* clean codes
* bug fix: remove rotary_positions_ids
---------
Co-authored-by: cuiqing.li <lixx3527@gmail.com>
* [kernel] Add triton layer norm & replace norm for bloom (#4609)
* add layernorm for inference
* add test for layernorm kernel
* add bloom layernorm replacement policy
* trivial: path
* [Infer] Bug fix rotary embedding in llama (#4608)
* fix rotary embedding
* delete print
* fix init seq len bug
* rename pytest
* add benchmark for llama
* refactor codes
* delete useless code
* [bench] Add bloom inference benchmark (#4621)
* add bloom benchmark
* readme - update benchmark res
* trivial - uncomment for testing (#4622)
* [Infer] add check triton and cuda version for tests (#4627)
* fix rotary embedding
* delete print
* fix init seq len bug
* rename pytest
* add benchmark for llama
* refactor codes
* delete useless code
* add check triton and cuda
* Update sharder.py (#4629)
* [Inference] Hot fix some bugs and typos (#4632)
* fix
* fix test
* fix conflicts
* [typo]Comments fix (#4633)
* fallback
* fix comments
* bug fix: fix some bugs in test_llama and test_bloom (#4635)
* [Infer] delete benchmark in tests and fix bug for llama and bloom (#4636)
* fix rotary embedding
* delete print
* fix init seq len bug
* rename pytest
* add benchmark for llama
* refactor codes
* delete useless code
* add check triton and cuda
* delete benchmark and fix infer bugs
* delete benchmark for tests
* delete useless code
* delete benchmark function in utils
* [Fix] Revise TPInferEngine, inference tests and benchmarks (#4642)
* [Fix] revise TPInferEngine methods and inference tests
* fix llama/bloom infer benchmarks
* fix infer tests
* trivial fix: benchmarks
* trivial
* trivial: rm print
* modify utils filename for infer ops test (#4657)
* [Infer] Fix TPInferEngine init & inference tests, benchmarks (#4670)
* fix engine funcs
* TPInferEngine: receive shard config in init
* benchmarks: revise TPInferEngine init
* benchmarks: remove pytest decorator
* trivial fix
* use small model for tests
* [NFC] use args for infer benchmarks (#4674)
* revise infer default (#4683)
* [Fix] optimize/shard model in TPInferEngine init (#4684)
* remove using orig model in engine
* revise inference tests
* trivial: rename
---------
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
184 lines
6.9 KiB
Python
import torch
import math

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    '''
    This function is modified from
    https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
    '''

    @triton.jit
    def _context_flash_attention_kernel(
        Q, K, V, sm_scale,
        B_Start_Loc, B_Seqlen,
        TMP,
        alibi_ptr,
        Out,
        stride_qbs, stride_qh, stride_qd,
        stride_kbs, stride_kh, stride_kd,
        stride_vbs, stride_vh, stride_vd,
        stride_obs, stride_oh, stride_od,
        stride_tmp_b, stride_tmp_h, stride_tmp_s,
        # suggested values: 64, 128, 256, 512
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        batch_id = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

        # get batch info
        cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
        cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
        block_start_loc = BLOCK_M * start_m

        load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
        q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)

        k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
        v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
        t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s

        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        if alibi_ptr is not None:
            alibi_m = tl.load(alibi_ptr + cur_head)

        block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale

            if alibi_ptr is not None:
                alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
                qk -= alibi_loc * alibi_m

            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            tl.store(t_ptrs, acc_scale)
            acc_scale = tl.load(t_ptrs)
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
        out_ptrs = Out + off_o
        tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
        return

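    # ----------------------------------------------------------------------
    # Editorial note (not part of the original source): the loop in the kernel
    # above is a FlashAttention-style "online softmax". For each block of query
    # rows it keeps a running max m_i, a running normalizer l_i and a rescaled
    # accumulator acc, so the full attention matrix is never materialized. The
    # dense PyTorch sketch below computes the same quantity for one head of one
    # sequence and can be used to sanity-check the kernel; the function name and
    # signature are illustrative only and do not exist in this module.
    def _reference_context_attention(q, k, v, alibi_slope=None):
        # q, k, v: (seq_len, head_dim) slices for a single head of a single sequence
        sm_scale = 1.0 / math.sqrt(q.shape[-1])
        scores = (q.float() @ k.float().t()) * sm_scale
        if alibi_slope is not None:
            # same bias as the kernel: subtract slope * (query_pos - key_pos)
            pos = torch.arange(q.shape[0], device=q.device)
            scores -= (pos[:, None] - pos[None, :]).float() * alibi_slope
        # causal mask: each query attends only to itself and earlier positions
        causal = torch.tril(torch.ones_like(scores, dtype=torch.bool))
        scores = scores.masked_fill(~causal, float("-inf"))
        return torch.softmax(scores, dim=-1).to(v.dtype) @ v
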
    @torch.no_grad()
    def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk, "context process only supports equal query, key, value length"
        assert Lk == Lv, "context process only supports equal query, key, value length"
        assert Lk in {16, 32, 64, 128}

        sm_scale = 1.0 / math.sqrt(Lk)
        batch, head = b_seq_len.shape[0], q.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

        num_warps = 4 if Lk <= 64 else 8

        tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)

        _context_flash_attention_kernel[grid](
            q, k, v, sm_scale,
            b_start_loc, b_seq_len,
            tmp,
            alibi,
            o,
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            v.stride(0), v.stride(1), v.stride(2),
            o.stride(0), o.stride(1), o.stride(2),
            tmp.stride(0), tmp.stride(1), tmp.stride(2),
            # block sizes are set manually here; a tuning config could speed this up further
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return

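    # ----------------------------------------------------------------------
    # Editorial note (not part of the original source): judging from the strides
    # passed to the kernel, q, k, v and o are packed "nopad" tensors of shape
    # (total_tokens, num_heads, head_dim), with all sequences of the batch
    # concatenated along dim 0; b_start_loc and b_seq_len give each sequence's
    # start offset and length in that packed dimension. A small illustrative
    # helper (not part of this module) for building the two metadata tensors:
    def _build_batch_metadata(seq_lens, device="cuda"):
        b_seq_len = torch.tensor(seq_lens, dtype=torch.int32, device=device)
        # start offsets = exclusive prefix sum of the sequence lengths
        b_start_loc = torch.zeros_like(b_seq_len)
        b_start_loc[1:] = torch.cumsum(b_seq_len, dim=0)[:-1]
        return b_start_loc, b_seq_len
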
    @torch.no_grad()
    def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk, "context process only supports equal query, key, value length"
        assert Lk == Lv, "context process only supports equal query, key, value length"
        assert Lk in {16, 32, 64, 128}

        sm_scale = 1.0 / math.sqrt(Lk)
        batch, head = b_seq_len.shape[0], q.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

        tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8
        # num_warps = 4
        _context_flash_attention_kernel[grid](
            q, k, v, sm_scale, b_start_loc, b_seq_len,
            tmp,
            None,
            o,
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            v.stride(0), v.stride(1), v.stride(2),
            o.stride(0), o.stride(1), o.stride(2),
            tmp.stride(0), tmp.stride(1), tmp.stride(2),
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return
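
A minimal smoke test for the wrappers above might look like the sketch below. It assumes the packed (total_tokens, num_heads, head_dim) layout described in the note after bloom_context_attn_fwd, a CUDA device with Triton installed, and an import path of colossalai.kernel.triton.context_attention (an assumption, not verified here); all other names are local to the sketch.

import torch

from colossalai.kernel.triton.context_attention import llama_context_attn_fwd  # assumed import path


def _smoke_test():
    seq_lens = [12, 7, 20]
    num_heads, head_dim = 8, 64              # head_dim must be one of {16, 32, 64, 128}
    total_tokens, max_input_len = sum(seq_lens), max(seq_lens)

    # packed "nopad" tensors: all sequences concatenated along dim 0
    q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    o = torch.empty_like(q)

    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    b_start_loc = torch.zeros_like(b_seq_len)
    b_start_loc[1:] = torch.cumsum(b_seq_len, dim=0)[:-1]

    llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len)
    assert o.shape == (total_tokens, num_heads, head_dim)


if __name__ == "__main__":
    _smoke_test()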