mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[Inference] add logit processor and request handler (#5166)
* add logit processor and request handler * add * add * add * fix * add search tokens and update func * finish request handler * add running list test * fix test * fix some bug * add * add * fix bugs * fix some bugs * fix bug * fix * fix * add copy fun * del useless attn * fix request status --------- Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.struct import Sequence
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
@@ -99,11 +100,13 @@ class KVCacheManager:
|
||||
self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64)
|
||||
self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64)
|
||||
|
||||
def get_total_num_blocks(self) -> int:
|
||||
@property
|
||||
def total_num_blocks(self) -> int:
|
||||
"""Get the total number of logical cache blocks."""
|
||||
return self.num_blocks
|
||||
|
||||
def get_num_available_blocks(self) -> int:
|
||||
@property
|
||||
def num_available_blocks(self) -> int:
|
||||
"""Get the number of available cache blocks."""
|
||||
return self._available_blocks
|
||||
|
||||
@@ -114,6 +117,10 @@ class KVCacheManager:
|
||||
# in the current batch.
|
||||
return self.max_blocks_per_sequence
|
||||
|
||||
def check_allocation(self, seq: Sequence) -> bool:
|
||||
num_blocks_needed = (seq.prompt_len + self.max_output_length + self.block_size - 1) // self.block_size
|
||||
return num_blocks_needed <= self.num_available_blocks
|
||||
|
||||
def get_block_kv_ptrs(self, block_id: int, layer_id: int) -> Tuple[List[int], List[int]]:
|
||||
"""Get the key and value pointers of physical caches (of specific layer) corresponding to a logical cache block."""
|
||||
block: CacheBlock = self._cache_blocks[block_id]
|
||||
|
Reference in New Issue
Block a user