mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 01:06:00 +00:00
[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)
* revise shape of kvcache (context attn kernel)
* revise shape of kvcache (flash decoding kernel)
* revise shape of kvcache (kvcache copy) and attn func
* init of kvcache in kvcache manager
* revise llama modeling
* revise block size retrieval
* use torch for rms_norm benchmarking
* revise block size retrieval
This commit is contained in:
@@ -79,10 +79,10 @@ class KVCacheManager:
|
||||
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
|
||||
|
||||
# Physical cache allocation
|
||||
alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
|
||||
if verbose:
|
||||
alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
|
||||
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
|
||||
self._kv_caches = self._init_device_caches()
|
||||
self._kv_caches = self._init_device_caches(alloc_shape)
|
||||
self.total_physical_cache_size_in_bytes = (
|
||||
self.elem_size_in_bytes
|
||||
* self.num_layers
|
||||
@@ -297,15 +297,12 @@ class KVCacheManager:
|
||||
blocks.append(cache_block)
|
||||
return blocks
|
||||
|
||||
def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Initialize the physical cache on the device.
|
||||
|
||||
For each layer of the model, we allocate two tensors for key and value respectively,
|
||||
with shape of [num_blocks, num_kv_heads, head_size, block_size]
|
||||
with shape of [num_blocks, num_kv_heads, block_size, head_size]
|
||||
"""
|
||||
alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
|
||||
# TODO: Explore the performance when using different shapes with kernel-related optimizations
|
||||
# e.g. [num_blocks, num_kv_heads // x, head_size, block_size, x]
|
||||
k_cache: List[torch.Tensor] = []
|
||||
v_cache: List[torch.Tensor] = []
|
||||
for _ in range(self.num_layers):
|
||||
|
Reference in New Issue
Block a user