mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[Inference/kernel]Add Fused Rotary Embedding and KVCache Memcopy CUDA Kernel (#5418)
* add rotary embedding kernel * add rotary_embedding_kernel * add fused rotary_emb and kvcache memcopy * add fused_rotary_emb_and_cache_kernel.cu * add fused_rotary_emb_and_memcopy * fix bugs in fused_rotary_emb_and_cache_kernel.cu * fix ci bugs * use vec memcopy and opt the gloabl memory access * fix code style * fix test_rotary_embdding_unpad.py * codes revised based on the review comments * fix bugs about include path * rm inline
This commit is contained in:
@@ -320,8 +320,12 @@ class NopadLlamaAttention(LlamaAttention):
|
||||
)
|
||||
|
||||
block_size = k_cache.size(-2)
|
||||
|
||||
if is_prompts:
|
||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||
if use_cuda_kernel:
|
||||
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||
else:
|
||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||
attn_output = context_attention_unpadded(
|
||||
q=query_states,
|
||||
k=key_states,
|
||||
@@ -337,9 +341,16 @@ class NopadLlamaAttention(LlamaAttention):
|
||||
)
|
||||
else:
|
||||
if use_cuda_kernel:
|
||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||
inference_ops.decode_kv_cache_memcpy(
|
||||
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
|
||||
inference_ops.rotary_embedding_and_cache_copy(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cos_sin[0],
|
||||
cos_sin[1],
|
||||
k_cache,
|
||||
v_cache,
|
||||
sequence_lengths,
|
||||
block_tables,
|
||||
)
|
||||
else:
|
||||
decoding_fused_rotary_embedding(
|
||||
|
Reference in New Issue
Block a user