mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-14 05:33:23 +00:00
[Inference]Fused kv copy into rotary calculation (#5383)
* revise rotary embedding
* remove useless print
* adapt
* fix
* add
* fix
* modeling
* fix
* fix
* fix
* fused kv copy
* fused copy
* colossalai/kernel/triton/no_pad_rotary_embedding.py
* del padding llama
* del
This commit is contained in:
@@ -45,21 +45,21 @@ def _copy_to_kvcache_seqlen1_kernel(
|
||||
k = tl.load(K + offsets_kv)
|
||||
v = tl.load(V + offsets_kv)
|
||||
|
||||
offsets_kvcache = (
|
||||
offsets_kcache = (
|
||||
block_id * stride_cachekb
|
||||
+ cur_kv_head_idx * stride_cachekh
|
||||
+ offsets_in_last_block * stride_cachekbs
|
||||
+ offsets_dmodel * stride_cachekd
|
||||
)
|
||||
offsets_kvcache = (
|
||||
offsets_vcache = (
|
||||
block_id * stride_cachevb
|
||||
+ cur_kv_head_idx * stride_cachevh
|
||||
+ offsets_in_last_block * stride_cachevbs
|
||||
+ offsets_dmodel * stride_cachevd
|
||||
)
|
||||
|
||||
tl.store(KCache + offsets_kvcache, k)
|
||||
tl.store(VCache + offsets_kvcache, v)
|
||||
tl.store(KCache + offsets_kcache, k)
|
||||
tl.store(VCache + offsets_vcache, v)
|
||||
return
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user