[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy operator (#5663)

* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator

* refactor decode_kv_cache_memcpy

* enable alibi in pagedattention
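Since this commit enables ALiBi in the paged-attention path, the kernels now take per-head slope tensors (`self.alibi_slopes` in the hunks below). For reference, a minimal sketch of the standard ALiBi slope computation; the function name and exact helper are assumptions, and the repo's own utility may differ:

    import math
    import torch

    def get_alibi_slopes(num_heads: int) -> torch.Tensor:
        # Geometric sequence from the ALiBi paper: with m = num_heads a power
        # of two, head i (1-indexed) gets slope (2 ** (-8 / m)) ** i.
        closest_pow2 = 2 ** math.floor(math.log2(num_heads))
        base = torch.tensor(2.0 ** (-8.0 / closest_pow2))
        slopes = torch.pow(base, torch.arange(1, closest_pow2 + 1))
        if closest_pow2 != num_heads:
            # Non-power-of-two head counts interleave extra slopes drawn from
            # the sequence for 2 * closest_pow2 heads.
            extra_base = torch.tensor(2.0 ** (-4.0 / closest_pow2))
            num_extra = num_heads - closest_pow2
            extra = torch.pow(extra_base, torch.arange(1, 2 * num_extra + 1, 2))
            slopes = torch.cat([slopes, extra], dim=0)
        return slopes  # shape [num_heads], e.g. stored as self.alibi_slopes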
Author: Steve Luo
Date: 2024-04-30 15:52:23 +08:00
Committed by: GitHub
Parent: 5f00002e43
Commit: 5cd75ce4c7
14 changed files with 368 additions and 235 deletions


@@ -310,6 +310,7 @@ class NopadBaichuanAttention(ParallelModule):
alibi_slopes=self.alibi_slopes,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
use_new_kcache_layout=use_cuda_kernel,
)
else:
q_len = tokens_to_verify + 1 if is_verifier else 1
@@ -332,6 +333,21 @@ class NopadBaichuanAttention(ParallelModule):
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
)
inference_ops.flash_decoding_attention(
output_tensor,
query_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.mid_output_lse,
self.alibi_slopes,
sm_scale,
)
attn_output = output_tensor
else:
if not is_verifier and not self.use_alibi_attn:
decoding_fused_rotary_embedding(
@@ -355,21 +371,21 @@ class NopadBaichuanAttention(ParallelModule):
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
alibi_slopes=self.alibi_slopes,
sm_scale=sm_scale,
q_len=q_len,
)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = self.o_proj(attn_output)
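
The CUDA `inference_ops.flash_decoding_attention` op above takes its arguments positionally. A hedged wrapper making that order explicit; only the argument order is read off the diff, while the wrapper itself and the comments on each tensor's role are illustrative:

    import torch

    def cuda_flash_decoding_attention(
        inference_ops,                   # compiled extension exposing the CUDA op
        output_tensor: torch.Tensor,     # pre-allocated output buffer
        query_states: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        sequence_lengths: torch.Tensor,  # per-sequence KV lengths
        block_tables: torch.Tensor,      # paged-KV block indices per sequence
        block_size: int,
        kv_seq_len: int,                 # max KV length in the batch
        mid_output: torch.Tensor,        # split-KV partial outputs
        mid_output_lse: torch.Tensor,    # log-sum-exp of the partials
        alibi_slopes,                    # per-head slopes, or None to disable ALiBi
        sm_scale: float,
    ) -> torch.Tensor:
        # Mirrors the positional call sites in this commit: the Baichuan module
        # passes self.alibi_slopes, the LLaMA module passes None.
        inference_ops.flash_decoding_attention(
            output_tensor, query_states, k_cache, v_cache,
            sequence_lengths, block_tables, block_size, kv_seq_len,
            mid_output, mid_output_lse, alibi_slopes, sm_scale,
        )
        return output_tensor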


@@ -98,15 +98,8 @@ def llama_model_forward(
"""
block_tables = inputmetadata.block_tables
sequence_lengths = inputmetadata.sequence_lengths
batch_size = inputmetadata.batch_size
kv_seq_len = inputmetadata.kv_seq_len
# NOTE: After testing, this configuration performs relatively well. As the CUDA kernel
# implementation is updated and optimized, the heuristic below should be re-evaluated
# in more detail.
if batch_size >= 32 and kv_seq_len > 512:
use_cuda_kernel = False
# NOTE (yuanheng-zhao): for now, only Triton kernels support the verification process
# during speculative decoding (`q_len > 1`).
# We will explicitly disable `use_cuda_kernel` here when speculative decoding is enabled.
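
A minimal sketch of the gating that NOTE describes; the flag name `use_spec_dec` is an assumption:

    def resolve_use_cuda_kernel(use_cuda_kernel: bool, use_spec_dec: bool) -> bool:
        # Only the Triton kernels handle the verification step of speculative
        # decoding (q_len > 1), so the CUDA fast path is forced off in that case.
        return use_cuda_kernel and not use_spec_dec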
@@ -575,6 +568,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
use_new_kcache_layout=use_cuda_kernel,
)
else:
q_len = tokens_to_verify + 1 if is_verifier else 1
@@ -592,20 +586,21 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
block_tables,
high_precision,
)
# inference_ops.flash_decoding_attention(
# output_tensor,
# query_states,
# k_cache,
# v_cache,
# sequence_lengths,
# block_tables,
# block_size,
# kv_seq_len,
# fd_inter_tensor.mid_output,
# fd_inter_tensor.mid_output_lse,
# sm_scale,
# )
# attn_output = output_tensor
inference_ops.flash_decoding_attention(
output_tensor,
query_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.mid_output_lse,
None,
sm_scale,
)
attn_output = output_tensor
else:
if is_verifier:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
@@ -627,21 +622,21 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
kv_group_num=self.num_key_value_groups,
q_len=q_len,
)
attn_output = attn_output.view(-1, self.hidden_size)
attn_output = self.o_proj(attn_output)
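
For intuition about the `decode_kv_cache_memcpy` calls in both files, a pure-PyTorch reference of the single-token copy, under an assumed `[num_blocks, num_kv_heads, block_size, head_dim]` cache layout (this commit also refactors the K-cache layout, so treat the shapes as illustrative rather than the kernel's actual layout):

    import torch

    def decode_kv_cache_memcpy_ref(
        key_states: torch.Tensor,        # [bsz, num_kv_heads, head_dim], new-token K
        value_states: torch.Tensor,      # [bsz, num_kv_heads, head_dim], new-token V
        k_cache: torch.Tensor,           # [num_blocks, num_kv_heads, block_size, head_dim]
        v_cache: torch.Tensor,           # same assumed layout as k_cache
        sequence_lengths: torch.Tensor,  # [bsz], lengths including the new token
        block_tables: torch.Tensor,      # [bsz, max_blocks_per_seq]
    ) -> None:
        block_size = k_cache.size(2)
        for seq_idx in range(key_states.size(0)):
            token_pos = int(sequence_lengths[seq_idx]) - 1       # slot of the new token
            block_id = int(block_tables[seq_idx, token_pos // block_size])
            offset = token_pos % block_size
            k_cache[block_id, :, offset, :] = key_states[seq_idx]
            v_cache[block_id, :, offset, :] = value_states[seq_idx]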