Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 01:06:00 +00:00
[Feat] Inference RPC Server Support (#5705)

* RPC support source
* KV cache logical/physical disaggregation
* Sampler refactor
* Built-in colossalai launch
* Unit tests
* rpyc support

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
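For context on the "rpyc support" item: rpyc's general pattern is a service class whose exposed_* methods become remotely callable, served per worker process and invoked by the engine over a connection. The sketch below is only a minimal, generic illustration of that pattern under assumed names (InferenceWorkerService, exposed_generate, port 18861); it is not ColossalAI's actual RPC server code.

# Minimal rpyc sketch of the remote-worker pattern the commit message refers to.
# `InferenceWorkerService` and `exposed_generate` are illustrative names only,
# not the actual ColossalAI RPC interface.
import rpyc
from rpyc.utils.server import ThreadedServer


class InferenceWorkerService(rpyc.Service):
    def exposed_generate(self, token_ids):
        # A real worker would run a model forward pass here; this stub just echoes.
        return list(token_ids)


if __name__ == "__main__":
    # Worker side: serve the service on a port (start() blocks until stopped).
    server = ThreadedServer(InferenceWorkerService, port=18861)
    server.start()

On the engine side, a client would call rpyc.connect("localhost", 18861) and then invoke conn.root.generate(...), which dispatches to exposed_generate on the worker.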
@@ -1,4 +1,4 @@
 from .block_cache import CacheBlock
-from .kvcache_manager import KVCacheManager
+from .kvcache_manager import KVCacheManager, RPCKVCacheManager

-__all__ = ["CacheBlock", "KVCacheManager"]
+__all__ = ["CacheBlock", "KVCacheManager", "RPCKVCacheManager"]
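With the new re-export, the RPC manager is importable from the kv cache package alongside the existing classes. Assuming the package path is colossalai.inference.kv_cache (the relative imports suggest this, but the path itself is not shown in the diff):

# Hypothetical absolute import path; only the package-relative imports appear in the diff.
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager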
@@ -497,3 +497,80 @@ class KVCacheManager:
             k_cache.append(torch.zeros(kalloc_shape, dtype=self.kv_cache_dtype, device=self.device))
             v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device))
         return k_cache, v_cache
+
+
+class RPCKVCacheManager(KVCacheManager):
+    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
+        self.logger = get_dist_logger(__name__)
+        self.device = get_current_device()
+        self.config = config
+
+        # Parallel settings
+        self.tp_size = config.tp_size
+        # Model settings
+        self.dtype = config.dtype
+        self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
+        self.num_layers = model_config.num_hidden_layers
+        self.head_num = model_config.num_attention_heads
+        self.head_size = model_config.hidden_size // self.head_num
+        if hasattr(model_config, "num_key_value_heads"):
+            self.kv_head_num = model_config.num_key_value_heads
+        else:
+            self.kv_head_num = self.head_num
+
+        if config.kv_cache_dtype is None:
+            self.kv_cache_dtype = config.dtype
+        else:
+            self.kv_cache_dtype = config.kv_cache_dtype
+
+        assert (
+            self.kv_head_num % self.tp_size == 0
+        ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
+        self.kv_head_num //= self.tp_size
+        self.beam_width = config.beam_width
+        self.max_batch_size = config.max_batch_size
+        self.max_input_length = config.max_input_len
+        self.max_output_length = config.max_output_len
+        # Cache block settings
+        self.block_size = config.block_size
+        # NOTE: `num_blocks` is not configured directly; it is derived from the maximum input/output lengths and the maximum batch size
+        self.max_blocks_per_sequence = (
+            self.max_input_length + self.max_output_length + self.block_size - 1
+        ) // self.block_size
+        self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
+
+        # Logical cache blocks allocation
+        self._available_blocks = self.num_blocks
+        self._cache_blocks = tuple(self._init_logical_caches())
+        # block availability state: 0 -> allocated, 1 -> free
+        self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool)
+        self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64)
+        self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64)
+
+    def get_physical_cache_shape(self) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
+        # Physical cache allocation
+        if self.config.use_cuda_kernel:
+            x = 16 // torch.tensor([], dtype=self.config.dtype).element_size()
+            kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x)
+            valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+            self.logger.info(
+                f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks."
+            )
+        else:
+            alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
+            kalloc_shape = alloc_shape
+            valloc_shape = alloc_shape
+            self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
+        return kalloc_shape, valloc_shape
+
+    def get_kv_cache(self):
+        """Get k_cache and v_cache"""
+        raise NotImplementedError
+
+    def _init_logical_caches(self):
+        """Initialize the logical cache blocks."""
+        blocks = []
+        for i in range(self.num_blocks):
+            cache_block = CacheBlock(i, self.block_size, self.elem_size_in_bytes, k_ptrs=None, v_ptrs=None)
+            blocks.append(cache_block)
+        return blocks
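To make the block-budget arithmetic in __init__ concrete, here is a small worked example with made-up config values; the numbers are illustrative only and do not come from this commit.

# Worked example of the `num_blocks` derivation above (hypothetical values).
max_input_length, max_output_length = 512, 256
block_size, max_batch_size, beam_width = 16, 8, 1

# Ceiling division: enough blocks to hold the longest possible sequence.
max_blocks_per_sequence = (max_input_length + max_output_length + block_size - 1) // block_size
num_blocks = max_blocks_per_sequence * max_batch_size * beam_width

print(max_blocks_per_sequence, num_blocks)  # 48 384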
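The split this class introduces appears to be: the engine-side RPCKVCacheManager only tracks logical blocks (CacheBlock objects created with k_ptrs=None, v_ptrs=None), while each RPC worker allocates the physical tensors from the shapes returned by get_physical_cache_shape(). Below is a sketch of the worker-side allocation under that reading; the allocate_physical_cache helper and the idea that shapes arrive via an RPC call are assumptions, not code from this commit.

# Sketch of worker-side physical KV cache allocation from the shapes computed by
# RPCKVCacheManager.get_physical_cache_shape(). The helper name and the RPC handoff
# of shapes are assumptions for illustration.
from typing import List, Tuple

import torch


def allocate_physical_cache(
    kalloc_shape: Tuple[int, ...],
    valloc_shape: Tuple[int, ...],
    num_layers: int,
    kv_cache_dtype: torch.dtype = torch.float16,
    device: str = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # One K tensor and one V tensor per transformer layer, mirroring the per-layer
    # append loop shown in the context lines of the diff above.
    k_cache = [torch.zeros(kalloc_shape, dtype=kv_cache_dtype, device=device) for _ in range(num_layers)]
    v_cache = [torch.zeros(valloc_shape, dtype=kv_cache_dtype, device=device) for _ in range(num_layers)]
    return k_cache, v_cache

Under this arrangement the engine only needs block indices and shapes to schedule requests, and no device tensors have to cross the RPC boundary, which is presumably the point of the "KV cache logical/physical disaggregation" item in the commit message.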