mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[Inference/opt] Fused KVCahce Memcopy (#5374)
* fused kv memcopy * add TODO in test_kvcache_copy.py
This commit is contained in:
@@ -301,8 +301,9 @@ class NopadLlamaAttention(LlamaAttention):
|
||||
sm_scale=sm_scale,
|
||||
)
|
||||
else:
|
||||
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||
copy_kv_to_blocked_cache(
|
||||
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
|
||||
)
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
|
@@ -356,8 +356,9 @@ class PadLlamaAttention(LlamaAttention):
|
||||
if attention_mask is not None:
|
||||
attn_output = pad_input(attn_output, indices, bsz, q_len)
|
||||
else:
|
||||
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||
copy_kv_to_blocked_cache(
|
||||
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
|
||||
)
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
|
Reference in New Issue
Block a user