Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference]Fused kv copy into rotary calculation (#5383)
* revise rotary embedding
* remove useless print
* adapt
* fix
* add
* fix
* modeling
* fix
* fix
* fix
* fused kv copy
* fused copy
* colossalai/kernel/triton/no_pad_rotary_embedding.py
* del padding llama
* del
@@ -16,7 +16,7 @@ from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.kernel.triton import (
     context_attention_unpadded,
-    copy_kv_to_blocked_cache,
+    decoding_fused_rotary_embedding,
     flash_decoding_attention,
     get_xine_cache,
     rotary_embedding,
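For orientation, the kernels imported here all operate on rotary position embeddings. Below is a minimal plain-PyTorch reference of that transform, assuming the half-split ("rotate-half") layout used by Llama-style models and cos/sin tables of shape [num_tokens, head_dim // 2]; the exact element layout inside ColossalAI's Triton kernels is not shown in this diff.

import torch

def rotary_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_dim]; cos/sin: [num_tokens, head_dim // 2]
    # Pair the first and second halves of head_dim and rotate each pair by the
    # per-position angle encoded in cos/sin (rotate-half reference, assumed layout).
    d = x.size(-1)
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    c, s = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head dimension
    return torch.cat((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1)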
@@ -281,11 +281,10 @@ class NopadLlamaAttention(LlamaAttention):
             torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
         )
 
-        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-
         block_size = k_cache.size(-2)
 
         if is_prompts:
+            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
             attn_output = context_attention_unpadded(
                 q=query_states,
                 k=key_states,
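With this hunk, rotary_embedding is applied explicitly only on the prefill (prompt) branch. On the decode branch the rotation is performed inside the fused kernel introduced in the next hunk, so keeping the previous unconditional call would have rotated the keys a second time.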
@@ -300,8 +299,16 @@ class NopadLlamaAttention(LlamaAttention):
                 sm_scale=sm_scale,
             )
         else:
-            copy_kv_to_blocked_cache(
-                key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
-            )
+            decoding_fused_rotary_embedding(
+                query_states,
+                key_states,
+                value_states,
+                cos_sin[0],
+                cos_sin[1],
+                k_cache,
+                v_cache,
+                block_tables,
+                sequence_lengths,
+            )
             attn_output = flash_decoding_attention(
                 q=query_states,
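The decode branch now makes a single call to decoding_fused_rotary_embedding, which folds the per-step KV-cache write into the rotary kernel instead of launching copy_kv_to_blocked_cache separately; this should remove one kernel launch and an extra read/write of the new key tokens per decoding step. Below is a minimal plain-PyTorch sketch of the equivalent work, assuming one new token per sequence, a blocked cache laid out as [num_blocks, num_heads, block_size, head_dim] (the diff only confirms that block_size sits at dim -2 of k_cache), and the same rotate-half layout as above; it illustrates the fused step, not the Triton kernel's actual implementation.

import torch

def decoding_fused_rotary_ref(q, k, v, cos, sin, k_cache, v_cache, block_tables, kv_lengths):
    # q, k, v: [bsz, num_heads, head_dim] for the newest token of each sequence.
    # cos/sin: [bsz, head_dim // 2]; kv_lengths: [bsz] sequence lengths including this token.
    # k_cache/v_cache: [num_blocks, num_heads, block_size, head_dim] (assumed layout).
    # block_tables: [bsz, max_blocks_per_seq] mapping logical blocks to physical blocks.
    d = q.size(-1)
    block_size = k_cache.size(-2)

    def rotate(x):
        x1, x2 = x[..., : d // 2], x[..., d // 2 :]
        c, s = cos.unsqueeze(1), sin.unsqueeze(1)
        return torch.cat((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1)

    q_rot, k_rot = rotate(q), rotate(k)             # rotary is applied to q and k only
    for i in range(q.size(0)):
        pos = int(kv_lengths[i]) - 1                # cache slot of the newest token
        block_id = int(block_tables[i, pos // block_size])
        offset = pos % block_size
        k_cache[block_id, :, offset, :] = k_rot[i]  # fused copy: rotated k goes into the cache
        v_cache[block_id, :, offset, :] = v[i]      # v is copied unrotated
    return q_rot, k_rot

In the modified forward pass, the rotated query then feeds directly into flash_decoding_attention, which reads the keys and values back out of the blocked caches.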