Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00
Fixed a bug in the inference frame
Committed by FrankLeeeee
Parent: 86853a37d5
Commit: 62fd08ee44
@@ -5,7 +5,6 @@ from transformers.configuration_utils import PretrainedConfig
 
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import KVCacheManager
-from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import BatchInfo, Sequence
 
@@ -49,7 +48,7 @@ class RunningList:
     def ready_for_prefill(self):
         if not self.decoding:
             return len(self.prefill) > 0
-        return len(self.prefill) / len(self.decoding) >= self.ratio
+        return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
 
     def is_empty(self):
         return not self.decoding and not self.prefill
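The one-line fix above matters because the constructor stores the value as `prefill_ratio`, so reading `self.ratio` would raise `AttributeError` on the first scheduling step. A minimal, self-contained sketch of the gate (a hypothetical stand-in, not the actual ColossalAI class):

```python
# Sketch of the ready_for_prefill gate; hypothetical stand-in, not the
# real colossalai.inference RunningList.
class RunningListSketch:
    def __init__(self, prefill_ratio: float):
        # The constructor stores `prefill_ratio`, so methods must read the
        # same attribute name; `self.ratio` does not exist.
        self.prefill_ratio = prefill_ratio
        self.prefill = []   # sequences still waiting for their prompt (prefill) pass
        self.decoding = []  # sequences already generating tokens

    def ready_for_prefill(self) -> bool:
        # With nothing decoding, any queued prefill request is enough.
        if not self.decoding:
            return len(self.prefill) > 0
        # Otherwise, prefill only once the queue is large enough relative
        # to the in-flight decoding work.
        return len(self.prefill) / len(self.decoding) >= self.prefill_ratio


rl = RunningListSketch(prefill_ratio=1.2)
rl.decoding = ["seq-a", "seq-b"]
rl.prefill = ["seq-c", "seq-d", "seq-e"]
print(rl.ready_for_prefill())  # True: 3 / 2 = 1.5 >= 1.2
```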
@@ -72,8 +71,9 @@ class RequestHandler:
         self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
         self.waiting_list: List[List] = [[], [], []]
         self.done_list: List[Sequence] = []
-        self.running_batch = BatchInfo(is_prompts=False)
-        self.prefill_batch = BatchInfo(is_prompts=True)
+        device = torch.cuda.current_device()
+        self.running_batch = BatchInfo(is_prompts=False, device=device)
+        self.prefill_batch = BatchInfo(is_prompts=True, device=device)
 
     def _init_cache(self, model_config):
         self.cache_manager = KVCacheManager(self.inference_config, model_config)
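Passing `device=torch.cuda.current_device()` pins both batches to the GPU the handler is initialized on, so tensors they allocate later do not silently land on a default device. A hedged sketch of the pattern (`BatchInfoSketch` is illustrative; the real `BatchInfo` lives in `colossalai.inference.struct` and carries more state):

```python
import torch

# Illustrative device-pinning pattern, not the real BatchInfo.
class BatchInfoSketch:
    def __init__(self, is_prompts: bool, device=None):
        self.is_prompts = is_prompts
        if device is None:
            # Fall back to the active GPU (or CPU) when no device is given.
            device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
        # torch.cuda.current_device() returns an int index; normalize it.
        self.device = torch.device("cuda", device) if isinstance(device, int) else torch.device(device)

    def new_token_buffer(self, batch_size: int) -> torch.Tensor:
        # Tensors allocated by the batch inherit the pinned device.
        return torch.zeros(batch_size, dtype=torch.long, device=self.device)


dev = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
running_batch = BatchInfoSketch(is_prompts=False, device=dev)
prefill_batch = BatchInfoSketch(is_prompts=True, device=dev)
print(prefill_batch.new_token_buffer(4).device)
```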
@@ -81,6 +81,9 @@ class RequestHandler:
     def _has_waiting(self) -> bool:
         return any(lst for lst in self.waiting_list)
 
+    def get_kvcache(self):
+        return self.cache_manager.get_kv_cache()
+
     def schedule(self):
         """
         The main logic of request handler.
@@ -90,7 +93,7 @@ class RequestHandler:
         for lst in reversed(self.waiting_list):
             if lst:
                 for seq in lst:
-                    if seq.prompt_len > self.inference_config.max_input_len:
+                    if seq.input_len > self.inference_config.max_input_len:
                         # If the prompt length is longer than max_input_len, abort the sequence.
                         self.abort_sequence(seq.request_id)
                         break
@@ -98,9 +101,8 @@ class RequestHandler:
                     if self.cache_manager.check_allocation(seq):
                         # If succeed, add the sequence to running list.
                         self.running_list.append(seq)
-                        self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.prompt_len)
-                        lst.remove(seq)
-
+                        self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
+                lst.clear()
         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
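Swapping the per-item `lst.remove(seq)` for a single `lst.clear()` after the loop avoids mutating a list while iterating over it, a pattern that silently skips every other element. A standalone illustration in plain Python, independent of the handler:

```python
# Removing items during iteration skips elements, because the iterator's
# index no longer lines up with the shrinking list.
waiting = ["seq0", "seq1", "seq2", "seq3"]
for seq in waiting:
    waiting.remove(seq)        # the pattern the commit moves away from
print(waiting)                 # ['seq1', 'seq3'] -- half the queue was never visited

# Processing everything first and clearing once touches every element.
waiting = ["seq0", "seq1", "seq2", "seq3"]
admitted = []
for seq in waiting:
    admitted.append(seq)       # stand-in for running_list.append / cache allocation
waiting.clear()
print(admitted, waiting)       # all four admitted, queue emptied
```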
@@ -115,10 +117,9 @@ class RequestHandler:
         """
         assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
         assert (
-            req.prompt_len < self.inference_config.max_input_len
+            req.input_len < self.inference_config.max_input_len
         ), f"Sequence {req.request_id} exceeds input length limit"
-
-        self.waiting_list[req.prompt_len * 3 // self.inference_config.max_input_len].append(req)
+        self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req)
 
     def abort_sequence(self, request_id: str):
         """
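The index expression `req.input_len * 3 // self.inference_config.max_input_len` buckets each request into one of the three waiting sub-lists by prompt length; requests with `input_len >= max_input_len` never reach it because the assert above rejects them. A quick worked example with an illustrative `max_input_len` of 256 (not a value from the commit):

```python
max_input_len = 256  # illustrative value only

def bucket(input_len: int) -> int:
    # 0 = shortest third, 1 = middle third, 2 = longest third.
    return input_len * 3 // max_input_len

for length in (10, 100, 200, 255):
    print(length, "->", bucket(length))
# 10  -> 0
# 100 -> 1
# 200 -> 2
# 255 -> 2
```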
@@ -178,9 +179,12 @@ class RequestHandler:
         """
         # do logit processor
         # NOTE: need to decide the granularity to process logits (sequence or batch)
-        for type in ["top_p", "top_k", "min_p"]:
-            if type in generation_config:
-                logits = logit_processor(type, logits)
+        # for type in ["top_p", "top_k", "min_p"]:
+        #     config_dict = generation_config.to_dict()
+        #     if type in config_dict:
+        #         logits = logit_processor(type, logits, config_dict[type])
+
+        torch.cuda.synchronize()
 
         # calculate probs
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
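The commented-out block marks where top-k/top-p style logit processing is meant to hook in once the processor takes its parameter from `generation_config.to_dict()`. For orientation only, a generic top-k filter in plain PyTorch (this is not the `colossalai.inference.logit_processors` implementation):

```python
import torch

def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Keep the k largest logits per row; push the rest to -inf so softmax
    # assigns them zero probability.
    kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_value, float("-inf"))

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
filtered = top_k_filter(logits, k=2)
probs = torch.softmax(filtered, dim=-1, dtype=torch.float)
print(probs)  # only the two largest logits keep non-zero probability
```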
@@ -188,7 +192,10 @@ class RequestHandler:
 
         # sample the next tokens
         sample_tokens = self._sample(probs, logprobs, generation_config)
-        self.running_batch.update_batch_tokens(sample_tokens)
+        if not self.prefill_batch.is_empty:
+            self.prefill_batch.update_batch_tokens(sample_tokens)
+        else:
+            self.running_batch.update_batch_tokens(sample_tokens)
 
     def update(self):
         """
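Routing the sampled tokens to `prefill_batch` when it is non-empty (falling back to `running_batch` otherwise) reflects that, on a prefill step, the freshly sampled first tokens belong to the prompt batch. A minimal sketch of that dispatch with a stub batch class; `is_empty` is presumably a property on the real `BatchInfo`, given the attribute-style access in the diff:

```python
import torch

# Stub standing in for BatchInfo; only what the dispatch needs.
class BatchStub:
    def __init__(self, sequences):
        self.sequences = list(sequences)
        self.appended = []

    @property
    def is_empty(self) -> bool:
        return not self.sequences

    def update_batch_tokens(self, tokens: torch.Tensor) -> None:
        self.appended.append(tokens)


prefill_batch = BatchStub(sequences=["new-seq"])   # a prompt is being prefilled
running_batch = BatchStub(sequences=["old-seq"])
sample_tokens = torch.tensor([42])

# The fix: during a prefill step the sampled tokens go to the prefill batch.
if not prefill_batch.is_empty:
    prefill_batch.update_batch_tokens(sample_tokens)
else:
    running_batch.update_batch_tokens(sample_tokens)

print(prefill_batch.appended, running_batch.appended)
```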