[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)

* revise shape of kvcache (context attn kernel); a layout sketch follows the commit metadata below

* revise shape of kvcache (flash decoding kernel)

* revise shape of kvcache (kvcache copy) and attn func

* init of kvcache in kvcache manager

* revise llama modeling

* revise block size retrieval

* use torch for rms_norm benchmarking; a timing sketch follows the diff below

* revise block size retrieval
Author: Yuanheng Zhao
Date: 2024-01-30 16:06:09 +08:00
Committed by: GitHub
Parent: e8f0642f28
Commit: 5f98a9d68a

14 changed files with 171 additions and 145 deletions
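
The shape revisions listed in the commit message all follow from one layout decision. Below is a minimal sketch of how a blocked KV cache with this layout might be allocated, assuming the dimension order [num_blocks, num_kv_heads, block_size, head_dim] inferred from the k_cache.size(-2) retrieval in the diffs below; the class and parameter names are hypothetical, not the repository's actual KVCache manager API.

import torch

class BlockedKVCacheSketch:
    # Hypothetical allocator mirroring the assumed blocked layout:
    #   [num_blocks, num_kv_heads, block_size, head_dim]
    # block_size sits at dim -2, matching k_cache.size(-2) in the diffs.
    def __init__(self, num_layers: int, num_blocks: int, num_kv_heads: int,
                 block_size: int, head_dim: int,
                 dtype: torch.dtype = torch.float16, device: str = "cuda"):
        shape = (num_blocks, num_kv_heads, block_size, head_dim)
        self.k_caches = [torch.zeros(shape, dtype=dtype, device=device)
                         for _ in range(num_layers)]
        self.v_caches = [torch.zeros(shape, dtype=dtype, device=device)
                         for _ in range(num_layers)]

    def get_block_size(self) -> int:
        # Same retrieval the attention kernels use after this commit.
        return self.k_caches[0].size(-2)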


@@ -171,7 +171,7 @@ def llama_attn_forward(
     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-    _, _, _, block_size = k_cache.shape
+    block_size = k_cache.size(-2)
     if is_prompts:
         attn_output = context_attention_unpadded(
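
The same one-line change appears at a second call site below. Under the assumed layout, the difference is which dimension block_size is read from: the old code unpacked a 4-D shape and took the last dimension, while the new code reads dim -2. A quick illustration with assumed sizes:

import torch

# Assumed layout: [num_blocks, num_kv_heads, block_size, head_dim]
k_cache = torch.empty(128, 8, 16, 64)
block_size = k_cache.size(-2)           # new retrieval -> 16
# Old retrieval, valid only when block_size was the last dimension:
# _, _, _, block_size = k_cache.shape   # under the new layout this would yield head_dim (64)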


@@ -226,7 +226,7 @@ def llama_attn_forward(
     rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-    _, _, _, block_size = k_cache.shape
+    block_size = k_cache.size(-2)
     if is_prompts:
         attn_output = context_attention_unpadded(
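
For the "use torch for rms_norm benchmarking" bullet, here is a minimal sketch of how a plain-torch RMSNorm baseline could be timed with CUDA events; the function, tensor sizes, and iteration counts are illustrative assumptions, not the benchmark script actually used in the repository.

import torch

def torch_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reference RMSNorm: scale by the reciprocal root-mean-square over the last dim.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)
w = torch.ones(4096, device="cuda", dtype=torch.float16)

for _ in range(10):  # warmup
    torch_rms_norm(x, w)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    torch_rms_norm(x, w)
end.record()
torch.cuda.synchronize()
print(f"avg latency: {start.elapsed_time(end) / 100:.3f} ms")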