[Inference/opt] Fused KVCache Memcopy (#5374)

* fused kv memcopy

* add TODO in test_kvcache_copy.py
This commit is contained in:
yuehuayingxueluo
2024-02-07 17:15:42 +08:00
committed by GitHub
parent 58740b5f68
commit 6fb4bcbb24
4 changed files with 75 additions and 30 deletions

View File

@@ -301,8 +301,9 @@ class NopadLlamaAttention(LlamaAttention):
sm_scale=sm_scale,
)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,

View File

@@ -356,8 +356,9 @@ class PadLlamaAttention(LlamaAttention):
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,