[Inference] Fused KV copy into rotary calculation (#5383)

* revise rotary embedding

* remove useless print

* adapt

* fix

* add

* fix

* modeling

* fix

* fix

* fix

* fused kv copy

* fused copy

* colossalai/kernel/triton/no_pad_rotary_embedding.py

* del padding llama

* del
Author: Jianghai
Date: 2024-02-21 11:31:48 +08:00
Committed by: GitHub
Parent: b21aac5bae
Commit: 730103819d
8 changed files with 391 additions and 498 deletions
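On the decode path, the diff below replaces two separate Triton kernel launches (rotary_embedding followed by copy_kv_to_blocked_cache) with a single decoding_fused_rotary_embedding call that applies RoPE to the query/key states and writes the new key/value entries into the blocked KV cache in the same pass; the prefill (is_prompts) path keeps the standalone rotary_embedding call. The following is a pure-PyTorch reference sketch of what the fused decode step does semantically, not the Triton kernel itself (which, per the commit message, lives in colossalai/kernel/triton/no_pad_rotary_embedding.py); the cache layout, the half-split RoPE pairing, and the cos/sin shapes are assumptions made for illustration only:

# Pure-PyTorch reference sketch of the fused decode step (illustrative only;
# tensor layouts below are assumptions, the real kernel is written in Triton).
import torch


def decoding_fused_rotary_embedding_ref(q, k, v, cos, sin, k_cache, v_cache, block_tables, kv_lengths):
    # q: [bsz, num_heads, head_dim]        one newly decoded token per sequence
    # k, v: [bsz, num_kv_heads, head_dim]
    # cos, sin: [bsz, head_dim // 2]       per-token rotary coefficients (cos_sin[0], cos_sin[1])
    # k_cache, v_cache: [num_blocks, num_kv_heads, block_size, head_dim]  (assumed layout)
    # block_tables: [bsz, max_blocks_per_seq]; kv_lengths: [bsz], lengths including the new token
    half = q.size(-1) // 2
    c, s = cos[:, None, :], sin[:, None, :]  # broadcast over the head dimension

    def rope(x):
        # Half-split RoPE pairing (assumed): (x1, x2) -> (x1*c - x2*s, x2*c + x1*s)
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1)

    q.copy_(rope(q))   # q is only rotated; it feeds flash_decoding_attention next
    k_rot = rope(k)

    # Fused part: scatter the rotated k and the untouched v straight into the
    # blocked-cache slot of the current token, instead of a second kernel launch.
    block_size = k_cache.size(-2)
    pos = kv_lengths - 1                                  # index of the new token
    block_ids = block_tables[torch.arange(q.size(0), device=q.device), pos // block_size]
    offsets = pos % block_size
    k_cache[block_ids, :, offsets, :] = k_rot
    v_cache[block_ids, :, offsets, :] = v

The point of the fusion is that key_states no longer has to be written out by one kernel and re-read by the next: the rotated keys go directly into the cache, saving a kernel launch and an extra pass over the key tensor on every decode step.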

@@ -16,7 +16,7 @@ from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.kernel.triton import (
     context_attention_unpadded,
-    copy_kv_to_blocked_cache,
+    decoding_fused_rotary_embedding,
     flash_decoding_attention,
     get_xine_cache,
     rotary_embedding,
@@ -281,11 +281,10 @@ class NopadLlamaAttention(LlamaAttention):
                 torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
             )
 
-        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-
         block_size = k_cache.size(-2)
 
         if is_prompts:
+            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
             attn_output = context_attention_unpadded(
                 q=query_states,
                 k=key_states,
@@ -300,8 +299,16 @@ class NopadLlamaAttention(LlamaAttention):
                 sm_scale=sm_scale,
             )
         else:
-            copy_kv_to_blocked_cache(
-                key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
+            decoding_fused_rotary_embedding(
+                query_states,
+                key_states,
+                value_states,
+                cos_sin[0],
+                cos_sin[1],
+                k_cache,
+                v_cache,
+                block_tables,
+                sequence_lengths,
             )
             attn_output = flash_decoding_attention(
                 q=query_states,