fix bugs in request_handler

This commit is contained in:
yuehuayingxueluo
2024-01-02 13:02:20 +08:00
committed by FrankLeeeee
parent 62fd08ee44
commit 62968588d1
5 changed files with 21 additions and 13 deletions

View File

@@ -5,6 +5,7 @@ from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, Sequence
@@ -179,10 +180,10 @@ class RequestHandler:
"""
# do logit processor
# NOTE: need to decide the granularity to process logits (sequence or batch)
# for type in ["top_p", "top_k", "min_p"]:
# config_dict = generation_config.to_dict()
# if type in config_dict:
# logits = logit_processor(type, logits, config_dict[type])
for type in ["top_p", "top_k", "min_p"]:
config_dict = generation_config.to_dict()
if type in config_dict:
logits = logit_processor(type, logits, config_dict[type])
torch.cuda.synchronize()
@@ -207,11 +208,12 @@ class RequestHandler:
self.running_list.prefill.clear()
self.prefill_batch.clear_batch()
for seq in self.running_batch.sequences_set:
if seq.check_finish():
self.done_list.append(seq)
self.running_list.remove(seq)
self.running_batch.sequences_set.remove(seq)
self.cache_manager.free_block_table(seq.block_table)
finish_seqs = self.running_batch.fliter_batch()
return self.done_list
for seq in finish_seqs:
self.running_list.remove(seq)
self.cache_manager.free_block_table(seq.block_table)
self.done_list.extend(finish_seqs)
return finish_seqs