[Refactor] Integrated some lightllm kernels into token-attention (#4946)

* add some req for inference

* clean codes

* add codes

* add some lightllm deps

* clean codes

* hello

* delete rms files

* add some comments

* add comments

* add doc

* add lightllm deps

* add lightllm cahtglm2 kernels

* add lightllm cahtglm2 kernels

* replace rotary embedding with lightllm kernel

* add some commnets

* add some comments

* add some comments

* add

* replace fwd kernel att1

* fix a arg

* add

* add

* fix token attention

* add some comments

* clean codes

* modify comments

* fix readme

* fix bug

* fix bug

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
This commit is contained in:
Cuiqing Li
2023-10-19 22:22:47 +08:00
committed by GitHub
parent 11009103be
commit 3a41e8304e
20 changed files with 160 additions and 1555 deletions

View File

@@ -11,6 +11,7 @@ except ImportError:
if HAS_TRITON:
# adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
@triton.jit
def _fwd_copy_kv_cache_dest(
kv_cache_ptr,
@@ -42,6 +43,7 @@ if HAS_TRITON:
tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
return
# adepted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
@torch.no_grad()
def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
seq_len = dest_index_ptr.shape[0]