Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 12:01:39 +00:00
[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy operator (#5663)
* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator
* refactor decode_kv_cache_memcpy
* enable alibi in pagedattention
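Note: "alibi" here is Attention with Linear Biases, which Baichuan-13B uses in place of rotary embeddings; each head adds a linear distance penalty slope * (key_pos - query_pos) to its attention scores before softmax. A minimal sketch of the standard slope computation from the ALiBi paper (the repo's own helper may differ):

import math

def get_alibi_slopes(num_heads: int) -> list[float]:
    # Standard ALiBi slopes: for n heads (n a power of two) the slopes are
    # the geometric sequence 2**(-8/n), 2**(-16/n), ..., 2**(-8).
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_pow2) - 3)))
    slopes = [base**i for i in range(1, closest_pow2 + 1)]
    if closest_pow2 != num_heads:
        # Non-power-of-two head counts (e.g. Baichuan-13B's 40) interleave
        # slopes from the sequence for 2 * closest_pow2 heads, per the paper.
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_pow2) - 3)))
        num_extra = num_heads - closest_pow2
        slopes += [extra_base**i for i in range(1, 2 * num_extra + 1, 2)]
    return slopes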
@@ -310,6 +310,7 @@ class NopadBaichuanAttention(ParallelModule):
                 alibi_slopes=self.alibi_slopes,
                 max_seq_len=kv_seq_len,
                 sm_scale=sm_scale,
+                use_new_kcache_layout=use_cuda_kernel,
             )
         else:
             q_len = tokens_to_verify + 1 if is_verifier else 1
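For reference, use_new_kcache_layout appears to switch the K cache from a flat per-head layout to a vLLM-style packed layout whose innermost axis groups the elements of one 16-byte vectorized load. The shapes below illustrate that reading and are assumptions, not constants taken from the repo:

import torch

# Illustrative sizes only (assumptions, not values from ColossalAI).
num_blocks, num_kv_heads, head_dim, block_size = 128, 40, 128, 16
dtype = torch.float16
x = 16 // (torch.finfo(dtype).bits // 8)  # elements per 16-byte vector, here 8

# Old layout: one token's head_dim values are strided by block_size.
k_cache_old = torch.empty(num_blocks, num_kv_heads, head_dim, block_size, dtype=dtype)

# Assumed new layout: the innermost axis packs x consecutive head_dim
# elements, so each 16-byte load in the CUDA kernel hits contiguous memory.
k_cache_new = torch.empty(num_blocks, num_kv_heads, head_dim // x, block_size, x, dtype=dtype)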
@@ -332,6 +333,21 @@ class NopadBaichuanAttention(ParallelModule):
                     inference_ops.decode_kv_cache_memcpy(
                         key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
                     )
+                inference_ops.flash_decoding_attention(
+                    output_tensor,
+                    query_states,
+                    k_cache,
+                    v_cache,
+                    sequence_lengths,
+                    block_tables,
+                    block_size,
+                    kv_seq_len,
+                    fd_inter_tensor.mid_output,
+                    fd_inter_tensor.mid_output_lse,
+                    self.alibi_slopes,
+                    sm_scale,
+                )
+                attn_output = output_tensor
             else:
                 if not is_verifier and not self.use_alibi_attn:
                     decoding_fused_rotary_embedding(
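decode_kv_cache_memcpy is one of the ops this PR refactors; during decoding it appends each sequence's single new K/V token into the paged cache. A hedged pure-PyTorch reference of the intended semantics, assuming the flat [num_blocks, num_heads, head_dim, block_size] cache layout and [num_seqs, num_heads, head_dim] new-token tensors:

import torch

def decode_kv_cache_memcpy_ref(key_states, value_states, k_cache, v_cache,
                               sequence_lengths, block_tables):
    # For each sequence, locate the cache slot of its newest token
    # (position seq_len - 1) via the block table, then copy K and V in.
    block_size = k_cache.shape[-1]
    for seq_idx, seq_len in enumerate(sequence_lengths.tolist()):
        pos = seq_len - 1
        block_id = block_tables[seq_idx, pos // block_size].item()
        offset = pos % block_size
        k_cache[block_id, :, :, offset] = key_states[seq_idx]
        v_cache[block_id, :, :, offset] = value_states[seq_idx]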
@@ -355,21 +371,21 @@ class NopadBaichuanAttention(ParallelModule):
                         value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                     )
 
-            attn_output = flash_decoding_attention(
-                q=query_states,
-                k_cache=k_cache,
-                v_cache=v_cache,
-                kv_seq_len=sequence_lengths,
-                block_tables=block_tables,
-                block_size=block_size,
-                max_seq_len_in_batch=kv_seq_len,
-                output=output_tensor,
-                mid_output=fd_inter_tensor.mid_output,
-                mid_output_lse=fd_inter_tensor.mid_output_lse,
-                alibi_slopes=self.alibi_slopes,
-                sm_scale=sm_scale,
-                q_len=q_len,
-            )
+                attn_output = flash_decoding_attention(
+                    q=query_states,
+                    k_cache=k_cache,
+                    v_cache=v_cache,
+                    kv_seq_len=sequence_lengths,
+                    block_tables=block_tables,
+                    block_size=block_size,
+                    max_seq_len_in_batch=kv_seq_len,
+                    output=output_tensor,
+                    mid_output=fd_inter_tensor.mid_output,
+                    mid_output_lse=fd_inter_tensor.mid_output_lse,
+                    alibi_slopes=self.alibi_slopes,
+                    sm_scale=sm_scale,
+                    q_len=q_len,
+                )
 
         attn_output = attn_output.view(-1, self.hidden_size)
         attn_output = self.o_proj(attn_output)
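The fd_inter_tensor.mid_output / mid_output_lse buffers passed to both attention calls hold flash-decoding partials: each KV chunk computes a locally-softmaxed output plus the log-sum-exp of its attention logits, and the chunks are merged afterwards. A sketch of that reduction for a single head and query (shapes assumed for illustration):

import torch

def merge_flash_decoding_partials(mid_output, mid_output_lse):
    # mid_output:     [num_chunks, head_dim], per-chunk attention outputs
    # mid_output_lse: [num_chunks], log-sum-exp of each chunk's logits
    # Softmax over the per-chunk LSEs yields each chunk's share of the
    # global softmax mass, so a weighted sum recovers the exact output.
    weights = torch.softmax(mid_output_lse, dim=0)
    return (weights[:, None] * mid_output).sum(dim=0)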