adapted to pad_context_forward

This commit is contained in:
yuehuayingxueluo
2024-01-09 13:52:53 +08:00
committed by FrankLeeeee
parent 47e53eaa1c
commit fa4fbdbffb
9 changed files with 42 additions and 41 deletions

View File

@@ -8,6 +8,9 @@ from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, Sequence
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
class RunningList:
@@ -93,17 +96,23 @@ class RequestHandler:
# Try to allocate cache blocks for the sequence using a priority of prompt length.
for lst in reversed(self.waiting_list):
if lst:
remove_list = []
for seq in lst:
if seq.input_len > self.inference_config.max_input_len:
# If the prompt length is longer than max_input_len, abort the sequence.
logger.warning(
f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
)
self.abort_sequence(seq.request_id)
break
remove_list.append(seq)
# Try to allocate cache blocks for the sequence.
if self.cache_manager.check_allocation(seq):
# If succeed, add the sequence to running list.
remove_list.append(seq)
self.running_list.append(seq)
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
lst.clear()
for seq in remove_list:
lst.remove(seq)
if self.running_list.ready_for_prefill():
for seq in self.running_list.prefill:
seq.mark_running()
@@ -130,10 +139,9 @@ class RequestHandler:
"""
Abort the request.
"""
seq, priority = self._find_sequence(request_id)
seq, _ = self._find_sequence(request_id)
if seq.status.is_waiting:
seq.mark_aborted()
self.waiting_list[priority].remove(seq)
elif seq.status.is_running():
self.cache_manager.free_block_table(seq.block_table)
self.running_list.remove(seq)