mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 23:18:36 +00:00
[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy oper… (#5663)
* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator * refactor decode_kv_cache_memcpy * enable alibi in pagedattention
This commit is contained in:
@@ -90,9 +90,18 @@ class KVCacheManager:
|
||||
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
|
||||
|
||||
# Physical cache allocation
|
||||
alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
|
||||
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
|
||||
self._kv_caches = self._init_device_caches(alloc_shape)
|
||||
if config.use_cuda_kernel:
|
||||
x = 16 // torch.tensor([], dtype=config.dtype).element_size()
|
||||
kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x)
|
||||
valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
|
||||
self.logger.info(
|
||||
f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks."
|
||||
)
|
||||
self._kv_caches = self._init_device_caches(kalloc_shape, valloc_shape)
|
||||
else:
|
||||
alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
|
||||
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
|
||||
self._kv_caches = self._init_device_caches(alloc_shape, alloc_shape)
|
||||
self.total_physical_cache_size_in_bytes = (
|
||||
self.elem_size_in_bytes
|
||||
* self.num_layers
|
||||
@@ -479,7 +488,9 @@ class KVCacheManager:
|
||||
blocks.append(cache_block)
|
||||
return blocks
|
||||
|
||||
def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _init_device_caches(
|
||||
self, kalloc_shape: Tuple[int, ...], valloc_shape: Tuple[int, ...]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Initialize the physical cache on the device.
|
||||
|
||||
For each layer of the model, we allocate two tensors for key and value respectively,
|
||||
@@ -488,6 +499,6 @@ class KVCacheManager:
|
||||
k_cache: List[torch.Tensor] = []
|
||||
v_cache: List[torch.Tensor] = []
|
||||
for _ in range(self.num_layers):
|
||||
k_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
|
||||
v_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
|
||||
k_cache.append(torch.zeros(kalloc_shape, dtype=self.dtype, device=self.device))
|
||||
v_cache.append(torch.zeros(valloc_shape, dtype=self.dtype, device=self.device))
|
||||
return k_cache, v_cache
|
||||
|
@@ -310,6 +310,7 @@ class NopadBaichuanAttention(ParallelModule):
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
max_seq_len=kv_seq_len,
|
||||
sm_scale=sm_scale,
|
||||
use_new_kcache_layout=use_cuda_kernel,
|
||||
)
|
||||
else:
|
||||
q_len = tokens_to_verify + 1 if is_verifier else 1
|
||||
@@ -332,6 +333,21 @@ class NopadBaichuanAttention(ParallelModule):
|
||||
inference_ops.decode_kv_cache_memcpy(
|
||||
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
|
||||
)
|
||||
inference_ops.flash_decoding_attention(
|
||||
output_tensor,
|
||||
query_states,
|
||||
k_cache,
|
||||
v_cache,
|
||||
sequence_lengths,
|
||||
block_tables,
|
||||
block_size,
|
||||
kv_seq_len,
|
||||
fd_inter_tensor.mid_output,
|
||||
fd_inter_tensor.mid_output_lse,
|
||||
self.alibi_slopes,
|
||||
sm_scale,
|
||||
)
|
||||
attn_output = output_tensor
|
||||
else:
|
||||
if not is_verifier and not self.use_alibi_attn:
|
||||
decoding_fused_rotary_embedding(
|
||||
@@ -355,21 +371,21 @@ class NopadBaichuanAttention(ParallelModule):
|
||||
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
|
||||
)
|
||||
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
kv_seq_len=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=block_size,
|
||||
max_seq_len_in_batch=kv_seq_len,
|
||||
output=output_tensor,
|
||||
mid_output=fd_inter_tensor.mid_output,
|
||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sm_scale=sm_scale,
|
||||
q_len=q_len,
|
||||
)
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
kv_seq_len=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=block_size,
|
||||
max_seq_len_in_batch=kv_seq_len,
|
||||
output=output_tensor,
|
||||
mid_output=fd_inter_tensor.mid_output,
|
||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sm_scale=sm_scale,
|
||||
q_len=q_len,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(-1, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
@@ -98,15 +98,8 @@ def llama_model_forward(
|
||||
"""
|
||||
block_tables = inputmetadata.block_tables
|
||||
sequence_lengths = inputmetadata.sequence_lengths
|
||||
batch_size = inputmetadata.batch_size
|
||||
kv_seq_len = inputmetadata.kv_seq_len
|
||||
|
||||
# NOTE: After testing, the performance of this configuration is relatively good. With updates
|
||||
# and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
|
||||
# selection should be conducted.
|
||||
if batch_size >= 32 and kv_seq_len > 512:
|
||||
use_cuda_kernel = False
|
||||
|
||||
# NOTE (yuanheng-zhao): fow now, only triton kernels support verification process
|
||||
# during speculative-decoding (`q_len > 1`)
|
||||
# We will expicitly disable `use_cuda_kernel` here when speculative-decoding is enabled
|
||||
@@ -575,6 +568,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
|
||||
output=output_tensor,
|
||||
max_seq_len=kv_seq_len,
|
||||
sm_scale=sm_scale,
|
||||
use_new_kcache_layout=use_cuda_kernel,
|
||||
)
|
||||
else:
|
||||
q_len = tokens_to_verify + 1 if is_verifier else 1
|
||||
@@ -592,20 +586,21 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
|
||||
block_tables,
|
||||
high_precision,
|
||||
)
|
||||
# inference_ops.flash_decoding_attention(
|
||||
# output_tensor,
|
||||
# query_states,
|
||||
# k_cache,
|
||||
# v_cache,
|
||||
# sequence_lengths,
|
||||
# block_tables,
|
||||
# block_size,
|
||||
# kv_seq_len,
|
||||
# fd_inter_tensor.mid_output,
|
||||
# fd_inter_tensor.mid_output_lse,
|
||||
# sm_scale,
|
||||
# )
|
||||
# attn_output = output_tensor
|
||||
inference_ops.flash_decoding_attention(
|
||||
output_tensor,
|
||||
query_states,
|
||||
k_cache,
|
||||
v_cache,
|
||||
sequence_lengths,
|
||||
block_tables,
|
||||
block_size,
|
||||
kv_seq_len,
|
||||
fd_inter_tensor.mid_output,
|
||||
fd_inter_tensor.mid_output_lse,
|
||||
None,
|
||||
sm_scale,
|
||||
)
|
||||
attn_output = output_tensor
|
||||
else:
|
||||
if is_verifier:
|
||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||
@@ -627,21 +622,21 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
|
||||
block_tables,
|
||||
sequence_lengths,
|
||||
)
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
kv_seq_len=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=block_size,
|
||||
max_seq_len_in_batch=kv_seq_len,
|
||||
output=output_tensor,
|
||||
mid_output=fd_inter_tensor.mid_output,
|
||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||
sm_scale=sm_scale,
|
||||
kv_group_num=self.num_key_value_groups,
|
||||
q_len=q_len,
|
||||
)
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
kv_seq_len=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=block_size,
|
||||
max_seq_len_in_batch=kv_seq_len,
|
||||
output=output_tensor,
|
||||
mid_output=fd_inter_tensor.mid_output,
|
||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||
sm_scale=sm_scale,
|
||||
kv_group_num=self.num_key_value_groups,
|
||||
q_len=q_len,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(-1, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
Reference in New Issue
Block a user