[Inference] Fix request handler and add recycle logic (#5260)

* fix request handler

* fix comment
Jianghai 2024-01-15 17:50:46 +08:00 committed by GitHub
parent c597678da4
commit d8db500efc
3 changed files with 37 additions and 7 deletions


@@ -57,6 +57,9 @@ class RunningList:
     def is_empty(self):
         return not self.decoding and not self.prefill
 
+    @property
+    def total_seq_num(self):
+        # Total number of scheduled sequences (prefill + decoding).
+        return len(self.decoding) + len(self.prefill)
 
 class RequestHandler:
     """
@@ -105,6 +108,11 @@ class RequestHandler:
                 )
                 self.abort_sequence(seq.request_id)
                 break
+            # Stop feeding new sequences into the running list, so that every
+            # running sequence can still claim a fresh cache block while decoding.
+            if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num:
+                break
             # Try to allocate cache blocks for the sequence.
             if self.cache_manager.check_allocation(seq):
                 # If allocation succeeds, add the sequence to the running list.
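The new gate keeps admission conservative: during decoding, each running sequence may need one fresh cache block for its next token, so the scheduler stops admitting once the free-block count no longer exceeds the number of scheduled sequences. A minimal sketch of that invariant (MiniCacheManager, MiniRunningList, and admit are hypothetical stand-ins, not the real classes):

class MiniCacheManager:
    def __init__(self, num_blocks: int):
        self.num_available_blocks = num_blocks

class MiniRunningList:
    def __init__(self):
        self.decoding, self.prefill = [], []

    @property
    def total_seq_num(self):
        return len(self.decoding) + len(self.prefill)

def admit(waiting, running, cache):
    # Admit waiting sequences only while every already-scheduled sequence
    # could still claim one spare block for its next decoding step.
    for seq in list(waiting):
        if cache.num_available_blocks <= running.total_seq_num:
            break  # same gate as in RequestHandler.schedule
        running.prefill.append(seq)
        waiting.remove(seq)

running, cache = MiniRunningList(), MiniCacheManager(num_blocks=2)
admit(["s1", "s2", "s3"], running, cache)
assert running.total_seq_num == 2  # "s3" stays in the waiting list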
@@ -113,6 +121,7 @@ class RequestHandler:
                     self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
             for seq in remove_list:
                 lst.remove(seq)
+
         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
@@ -121,7 +130,12 @@ class RequestHandler:
         if not self.running_batch.is_empty:
-            for seq in self.running_batch.sequences_set:
-                self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
+            # Iterate over a copy: recycled sequences are removed from the set.
+            for seq in list(self.running_batch.sequences_set):
+                recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
+                if recycle:
+                    # Recycled sequences are rescheduled with the highest priority.
+                    seq.recycle()
+                    self.running_batch.remove(seq)
+                    self.waiting_list[-1].append(seq)
         return self.running_batch
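With this change the decoding path degrades gracefully under cache pressure instead of asserting: a failed token allocation sends the sequence back to the waiting list. A toy trace of that flow (ToySeq and step are simplified illustrations, not the real RequestHandler API):

from enum import Enum

class Status(Enum):
    WAITING = 0
    RUNNING = 1

class ToySeq:
    def __init__(self, rid):
        self.request_id, self.status = rid, Status.RUNNING

    def recycle(self):
        self.status = Status.WAITING

def step(batch, waiting, free_blocks):
    # One decoding step: sequences that cannot get a block are recycled.
    for seq in list(batch):
        if free_blocks > 0:
            free_blocks -= 1      # token allocation succeeded
        else:
            seq.recycle()         # allocation failed -> back to waiting
            batch.remove(seq)
            waiting.append(seq)   # re-queued to be rescheduled first
    return free_blocks

batch, waiting = [ToySeq("a"), ToySeq("b"), ToySeq("c")], []
step(batch, waiting, free_blocks=2)
assert [s.request_id for s in waiting] == ["c"]
assert waiting[0].status is Status.WAITING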


@@ -208,9 +208,9 @@ class KVCacheManager:
         # The last allocated block may be either partially or fully occupied.
         # `alloc_local_block_idx` is the index of the block to be allocated on the provided block table.
         alloc_local_block_idx = context_len // self.block_size
-        self.allocate_single_block(block_table, alloc_local_block_idx, 1)
+        return self.allocate_single_block(block_table, alloc_local_block_idx)
 
-    def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int, space_asked: int) -> int:
+    def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
         """Allocate space asked on a single block in the block table, specified by the provided position id,
         and updates the provided block table with the allocated block.
@@ -221,11 +221,14 @@ class KVCacheManager:
         Returns:
             The remaining space required to be allocated (in other blocks).
         """
-        assert block_table.dim() == 1
+        space_asked = 1
         block_global_id = block_table[block_local_idx].item()
         if block_global_id < 0:
             # Allocate a new block if the current position is not assigned a block yet
-            assert self._available_blocks > 0, "No available blocks to allocate."
+            if self._available_blocks <= 0:
+                # No blocks left: free this sequence's blocks and tell the caller to recycle it.
+                self.free_block_table(block_table)
+                return True
             free_block_id = torch.nonzero(self._block_states == 1).view(-1)[0]
             block: CacheBlock = self._cache_blocks[free_block_id]
             block.add_ref()
@@ -235,6 +238,7 @@ class KVCacheManager:
         block_table[block_local_idx] = block_global_id
         block: CacheBlock = self._cache_blocks[block_global_id]
+        # Only when the space asked is fully satisfied will the return value be zero.
         return self._allocate_on_block(block, space_asked)
 
     def free_block_table(self, block_table: torch.Tensor) -> None:
         """Free the logical cache blocks for **a single sequence**."""
@@ -269,7 +273,9 @@ class KVCacheManager:
         Returns:
             The remaining space required to be allocated (in other blocks).
         """
-        assert block.available_space > 0, "No available space on block to allocate."
+        assert (
+            block.available_space > 0
+        ), "Tried to allocate some space but found no available space left in chosen block."
         space_to_allocate = min(block.available_space, space_asked)
         block.allocate(space_to_allocate)
         return space_asked - space_to_allocate
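_allocate_on_block never over-allocates: it grants min(available, asked) and reports what is still owed. With the new single-token call pattern the result is therefore always 0 once the assert passes. The same arithmetic as a standalone function (illustrative only):

def allocate_on_block(available_space: int, space_asked: int) -> int:
    # Grant min(available, asked); return the space still owed elsewhere.
    assert available_space > 0, "Tried to allocate some space but found no available space left in chosen block."
    space_to_allocate = min(available_space, space_asked)
    return space_asked - space_to_allocate

assert allocate_on_block(available_space=3, space_asked=1) == 0  # fully satisfied
assert allocate_on_block(available_space=2, space_asked=5) == 3  # 3 still owed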


@@ -134,6 +134,16 @@ class Sequence:
         """
         self.status = RequestStatus.ABORTED
 
+    def recycle(self) -> None:
+        """
+        Recycle a running sequence back to the waiting list.
+        """
+        assert (
+            not self.status.is_finished and self.status != RequestStatus.ABORTED
+        ), "The sequence is already finished but is still in the running list."
+        self.status = RequestStatus.WAITING
+
     def __repr__(self) -> str:
         return (
             f"(request_id={self.request_id}, "