Fixed a bug in the inference frame

This commit is contained in:
yuehuayingxueluo
2023-12-26 21:34:27 +08:00
committed by FrankLeeeee
parent 86853a37d5
commit 62fd08ee44
8 changed files with 261 additions and 90 deletions

View File

@@ -5,7 +5,6 @@ from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, Sequence
@@ -49,7 +48,7 @@ class RunningList:
def ready_for_prefill(self):
if not self.decoding:
return len(self.prefill) > 0
return len(self.prefill) / len(self.decoding) >= self.ratio
return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
def is_empty(self):
return not self.decoding and not self.prefill
@@ -72,8 +71,9 @@ class RequestHandler:
self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
self.waiting_list: List[List] = [[], [], []]
self.done_list: List[Sequence] = []
self.running_batch = BatchInfo(is_prompts=False)
self.prefill_batch = BatchInfo(is_prompts=True)
device = torch.cuda.current_device()
self.running_batch = BatchInfo(is_prompts=False, device=device)
self.prefill_batch = BatchInfo(is_prompts=True, device=device)
def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)
@@ -81,6 +81,9 @@ class RequestHandler:
def _has_waiting(self) -> bool:
return any(lst for lst in self.waiting_list)
def get_kvcache(self):
return self.cache_manager.get_kv_cache()
def schedule(self):
"""
The main logic of request handler.
@@ -90,7 +93,7 @@ class RequestHandler:
for lst in reversed(self.waiting_list):
if lst:
for seq in lst:
if seq.prompt_len > self.inference_config.max_input_len:
if seq.input_len > self.inference_config.max_input_len:
# If the prompt length is longer than max_input_len, abort the sequence.
self.abort_sequence(seq.request_id)
break
@@ -98,9 +101,8 @@ class RequestHandler:
if self.cache_manager.check_allocation(seq):
# If succeed, add the sequence to running list.
self.running_list.append(seq)
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.prompt_len)
lst.remove(seq)
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
lst.clear()
if self.running_list.ready_for_prefill():
for seq in self.running_list.prefill:
seq.mark_running()
@@ -115,10 +117,9 @@ class RequestHandler:
"""
assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
assert (
req.prompt_len < self.inference_config.max_input_len
req.input_len < self.inference_config.max_input_len
), f"Sequence {req.request_id} exceeds input length limit"
self.waiting_list[req.prompt_len * 3 // self.inference_config.max_input_len].append(req)
self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req)
def abort_sequence(self, request_id: str):
"""
@@ -178,9 +179,12 @@ class RequestHandler:
"""
# do logit processor
# NOTE: need to decide the granularity to process logits (sequence or batch)
for type in ["top_p", "top_k", "min_p"]:
if type in generation_config:
logits = logit_processor(type, logits)
# for type in ["top_p", "top_k", "min_p"]:
# config_dict = generation_config.to_dict()
# if type in config_dict:
# logits = logit_processor(type, logits, config_dict[type])
torch.cuda.synchronize()
# calculate probs
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@@ -188,7 +192,10 @@ class RequestHandler:
# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
self.running_batch.update_batch_tokens(sample_tokens)
if not self.prefill_batch.is_empty:
self.prefill_batch.update_batch_tokens(sample_tokens)
else:
self.running_batch.update_batch_tokens(sample_tokens)
def update(self):
"""