[Inference]Fused kv copy into rotary calculation (#5383)

* revise rotary embedding

* remove useless print

* adapt

* fix

* add

* fix

* modeling

* fix

* fix

* fix

* fused kv copy

* fused copy

* colossalai/kernel/triton/no_pad_rotary_embedding.py

* del padding llama

* del
This commit is contained in:
Jianghai
2024-02-21 11:31:48 +08:00
committed by GitHub
parent b21aac5bae
commit 730103819d
8 changed files with 391 additions and 498 deletions

View File

@@ -45,21 +45,21 @@ def _copy_to_kvcache_seqlen1_kernel(
k = tl.load(K + offsets_kv)
v = tl.load(V + offsets_kv)
offsets_kvcache = (
offsets_kcache = (
block_id * stride_cachekb
+ cur_kv_head_idx * stride_cachekh
+ offsets_in_last_block * stride_cachekbs
+ offsets_dmodel * stride_cachekd
)
offsets_kvcache = (
offsets_vcache = (
block_id * stride_cachevb
+ cur_kv_head_idx * stride_cachevh
+ offsets_in_last_block * stride_cachevbs
+ offsets_dmodel * stride_cachevd
)
tl.store(KCache + offsets_kvcache, k)
tl.store(VCache + offsets_kvcache, v)
tl.store(KCache + offsets_kcache, k)
tl.store(VCache + offsets_vcache, v)
return