[Inference] Adapt to Fused rotary (#5348)

* revise rotary embedding

* remove useless print

* adapt

* fix

* add

* fix

* modeling

* fix

* fix

* fix
Jianghai authored on 2024-02-07 11:36:04 +08:00, committed by GitHub
parent 35382a7fbf
commit 9f4ab2eb92
5 changed files with 161 additions and 22 deletions


@@ -282,11 +282,10 @@ class NopadLlamaAttention(LlamaAttention):
             torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
         )
-        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-
         block_size = k_cache.size(-2)
         if is_prompts:
+            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
             attn_output = context_attention_unpadded(
                 q=query_states,
                 k=key_states,
@@ -301,7 +300,7 @@ class NopadLlamaAttention(LlamaAttention):
                 sm_scale=sm_scale,
             )
         else:
-            copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], k_cache, block_tables, sequence_lengths)
             copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             attn_output = flash_decoding_attention(
                 q=query_states,
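
For reference, the sketch below is a minimal pure-PyTorch approximation (an assumption for illustration, not the actual Triton kernel shipped with this change) of what the fused decode-path call rotary_embedding(query_states, key_states, cos, sin, k_cache, block_tables, sequence_lengths) is expected to do: apply rotary position embedding to the current token's query and key, then write the rotated key directly into the blocked KV cache, replacing the separate copy_kv_to_blocked_cache call for keys. The helper name fused_rotary_and_cache_ref, the tensor shapes, and the half-split rotary convention are all assumptions.

import torch

def fused_rotary_and_cache_ref(q, k, cos, sin, k_cache, block_tables, seq_lengths):
    # Assumed shapes (decode step, one new token per sequence):
    #   q, k:          [num_seqs, num_heads, head_dim]
    #   cos, sin:      [num_seqs, head_dim // 2], gathered at each token's position
    #   k_cache:       [num_blocks, num_heads, block_size, head_dim]
    #   block_tables:  [num_seqs, max_blocks_per_seq]
    #   seq_lengths:   [num_seqs], lengths that include the current token
    half = q.size(-1) // 2

    def rotate(x):
        # Half-split rotary convention; broadcast cos/sin over the head dimension.
        x0, x1 = x[..., :half], x[..., half:]
        c, s = cos.unsqueeze(1), sin.unsqueeze(1)
        return torch.cat((x0 * c - x1 * s, x0 * s + x1 * c), dim=-1)

    q.copy_(rotate(q))
    k.copy_(rotate(k))

    # The fused kernel also scatters the rotated key of the current token into its
    # slot of the blocked cache, so a separate key-copy kernel launch is no longer needed.
    block_size = k_cache.size(-2)
    for seq_id in range(k.size(0)):
        slot = int(seq_lengths[seq_id]) - 1
        block_id = int(block_tables[seq_id, slot // block_size])
        k_cache[block_id, :, slot % block_size, :] = k[seq_id]

In the prefill branch the diff keeps the plain four-argument rotary_embedding call, so only the decode path switches to the fused form.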