Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
[Refactor] Integrated some lightllm kernels into token-attention (#4946)
* add some reqs for inference
* clean codes
* add codes
* add some lightllm deps
* clean codes
* hello
* delete rms files
* add some comments
* add comments
* add doc
* add lightllm deps
* add lightllm chatglm2 kernels
* add lightllm chatglm2 kernels
* replace rotary embedding with lightllm kernel
* add some comments
* add some comments
* add some comments
* add
* replace fwd kernel att1
* fix an arg
* add
* add
* fix token attention
* add some comments
* clean codes
* modify comments
* fix readme
* fix bug
* fix bug

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
@@ -11,6 +11,7 @@ except ImportError:

if HAS_TRITON:

    # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
    @triton.jit
    def _fwd_copy_kv_cache_dest(
        kv_cache_ptr,

@@ -42,6 +43,7 @@ if HAS_TRITON:

        tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
        return

    # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
    @torch.no_grad()
    def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
        seq_len = dest_index_ptr.shape[0]
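The `destindex_copy_kv` kernel referenced in this diff scatters each incoming key/value vector into the KV-cache row selected by a per-token destination index. The Triton kernel body is truncated here, but its effect can be sketched in plain PyTorch; `copy_kv_cache_to_dest_reference` and the toy shapes below are illustrative stand-ins, not the actual ColossalAI or lightllm API:

```python
import torch


@torch.no_grad()
def copy_kv_cache_to_dest_reference(k, dest_index, cache):
    # Plain-PyTorch stand-in for the Triton destindex_copy_kv kernel:
    # write token i's (head_num, head_dim) key/value slice into
    # cache row dest_index[i] via advanced-index assignment.
    cache[dest_index] = k


# Toy shapes: 4 new tokens, 2 heads, head_dim 8, cache capacity 10.
k = torch.randn(4, 2, 8)
cache = torch.zeros(10, 2, 8)
dest = torch.tensor([3, 7, 1, 9])
copy_kv_cache_to_dest_reference(k, dest, cache)
```

The Triton version does the same scatter but launches one program per token so the copy into the paged cache runs in parallel on the GPU without a Python-side loop.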