Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] Adapt to Fused rotary (#5348)
* revise rotary embedding
* remove useless print
* adapt
* fix
* add
* fix
* modeling
* fix
* fix
* fix
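In this change, the rotary embedding moves out of the shared path and into the two branches: the standalone rotary_embedding call before the cache-size lookup is removed; the prefill (is_prompts) branch keeps the plain in-place call, while the decode branch switches to a fused variant that also takes k_cache, block_tables, and sequence_lengths, rotating the key states and writing them straight into the blocked KV cache in place of the separate copy_kv_to_blocked_cache call for keys.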
@@ -282,11 +282,10 @@ class NopadLlamaAttention(LlamaAttention):
             torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
         )
 
-        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-
         block_size = k_cache.size(-2)
 
         if is_prompts:
+            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
             attn_output = context_attention_unpadded(
                 q=query_states,
                 k=key_states,
@@ -301,7 +300,7 @@ class NopadLlamaAttention(LlamaAttention):
                 sm_scale=sm_scale,
             )
         else:
-            copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], k_cache, block_tables, sequence_lengths)
             copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             attn_output = flash_decoding_attention(
                 q=query_states,
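For orientation, here is a minimal PyTorch sketch of the semantics the fused decode-path call is expected to have. This is an illustration under stated assumptions, not the kernel shipped in this PR: the name fused_rotary_embedding_reference is hypothetical, the cache layout [num_blocks, num_kv_heads, block_size, head_dim] is inferred only from block_size = k_cache.size(-2) in the diff, and the interleaved even/odd RoPE pairing and the cos/sin shapes are assumptions.

import torch

def fused_rotary_embedding_reference(
    q: torch.Tensor,             # [bsz, num_heads, head_dim], one decode token per sequence
    k: torch.Tensor,             # [bsz, num_kv_heads, head_dim]
    cos: torch.Tensor,           # [bsz, head_dim // 2], cos_sin[0] for the current positions
    sin: torch.Tensor,           # [bsz, head_dim // 2], cos_sin[1]
    k_cache: torch.Tensor,       # [num_blocks, num_kv_heads, block_size, head_dim] (assumed layout)
    block_tables: torch.Tensor,  # [bsz, max_blocks_per_seq], logical block -> physical block
    kv_lengths: torch.Tensor,    # [bsz], sequence length including the current token
) -> None:
    """Rotate q/k in place, then scatter each rotated key into the blocked cache."""
    block_size = k_cache.size(-2)  # matches block_size = k_cache.size(-2) in the diff

    def rotate(x: torch.Tensor) -> torch.Tensor:
        # Standard RoPE on interleaved even/odd channel pairs (pairing is an assumption).
        x0, x1 = x[..., 0::2], x[..., 1::2]
        c, s = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head axis
        out = torch.empty_like(x)
        out[..., 0::2] = x0 * c - x1 * s
        out[..., 1::2] = x0 * s + x1 * c
        return out

    q.copy_(rotate(q))
    k.copy_(rotate(k))

    # Write the newest key of each sequence into its slot in the paged KV cache,
    # which is what lets the fused call replace copy_kv_to_blocked_cache for keys.
    for seq in range(kv_lengths.size(0)):
        pos = int(kv_lengths[seq]) - 1                     # slot index of the current token
        block = int(block_tables[seq, pos // block_size])  # physical block holding that slot
        k_cache[block, :, pos % block_size, :] = k[seq]

Fusing the rotation with the cache write means the rotated keys are produced and stored in one kernel launch instead of a rotary pass followed by a separate copy, which is presumably the motivation for the change; the value states still go through copy_kv_to_blocked_cache, since no rotation applies to them.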