[Inference] Adapt to Fused rotary (#5348)

* revise rotary embedding

* remove useless print

* adapt

* fix

* add

* fix

* modeling

* fix

* fix

* fix
Jianghai
2024-02-07 11:36:04 +08:00
committed by GitHub
parent 35382a7fbf
commit 9f4ab2eb92
5 changed files with 161 additions and 22 deletions

@@ -75,7 +75,6 @@ def copy_kv_to_blocked_cache(
    block_size = k_cache.size(-2)
    num_warps = 8 if head_dim > 128 else 4
    grid = (bsz, num_kv_heads)
    _copy_to_kvcache_seqlen1_kernel[grid](
        k,