[Inference] add logit processor and request handler (#5166)

* add logit processor and request handler

* add

* add

* add

* fix

* add search tokens and update func

* finish request handler

* add running list test

* fix test

* fix some bug

* add

* add

* fix bugs

* fix some bugs

* fix bug

* fix

* fix

* add copy fun

* del useless attn

* fix request status

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
This commit is contained in:
Jianghai
2023-12-25 12:15:15 +08:00
committed by FrankLeeeee
parent 8daee26989
commit 0e616462a7
10 changed files with 463 additions and 66 deletions

View File

@@ -4,6 +4,7 @@ import torch
from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
@@ -99,11 +100,13 @@ class KVCacheManager:
self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64)
self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64)
def get_total_num_blocks(self) -> int:
@property
def total_num_blocks(self) -> int:
"""Get the total number of logical cache blocks."""
return self.num_blocks
def get_num_available_blocks(self) -> int:
@property
def num_available_blocks(self) -> int:
"""Get the number of available cache blocks."""
return self._available_blocks
@@ -114,6 +117,10 @@ class KVCacheManager:
# in the current batch.
return self.max_blocks_per_sequence
def check_allocation(self, seq: Sequence) -> bool:
num_blocks_needed = (seq.prompt_len + self.max_output_length + self.block_size - 1) // self.block_size
return num_blocks_needed <= self.num_available_blocks
def get_block_kv_ptrs(self, block_id: int, layer_id: int) -> Tuple[List[int], List[int]]:
"""Get the key and value pointers of physical caches (of specific layer) corresponding to a logical cache block."""
block: CacheBlock = self._cache_blocks[block_id]