Repository: https://github.com/hpcaitech/ColossalAI.git

Fixed a bug in the inference frame

commit 62fd08ee44
parent 86853a37d5
committed by FrankLeeeee
@@ -49,6 +49,7 @@ class InferenceEngine:
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.inference_config = inference_config
         self.model_config = model.config
+        self.device = torch.device("cuda")
 
         if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
             self.dtype = torch.float32
@@ -76,6 +77,7 @@ class InferenceEngine:
         self.logger = get_dist_logger(__name__)
 
         self.request_handler = RequestHandler(self.inference_config, self.model_config)
+        self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
         self.counter = count()
 
     def _verify_config(self) -> None:
@@ -170,7 +172,11 @@ class InferenceEngine:
 
         if prompts_token_ids is None:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts)["input_ids"]
+            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
+
+        assert (
+            len(prompts_token_ids[0]) < self.inference_config.max_input_len
+        ), "The length of input prompts must be less than max_input_len."
 
         prompts_num = len(prompts_token_ids)
 
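A side note on the padding=True change in the hunk above: without padding, batch_encode_plus returns input_ids lists of different lengths, which cannot be stacked into a single batch tensor. A minimal sketch of the difference, assuming a Hugging Face tokenizer is available (the checkpoint name is an arbitrary example, not taken from this commit):

# Sketch only: shows what padding=True changes in batch_encode_plus.
# Requires the `transformers` package; "gpt2" is an arbitrary example checkpoint.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # mirrors the engine's tokenizer setup

prompts = ["Hello world", "A much longer prompt with several more tokens"]

ragged = tokenizer.batch_encode_plus(prompts)["input_ids"]
padded = tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]

print([len(ids) for ids in ragged])  # e.g. [2, 9] -- lengths differ per prompt
print([len(ids) for ids in padded])  # e.g. [9, 9] -- padded to the longest prompt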
@@ -183,13 +189,14 @@ class InferenceEngine:
                 prompt = None
             else:
                 prompt = prompts[i]
+            block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
             sequence = Sequence(
                 request_id,
                 prompt,
                 prompts_token_ids[i],
                 block_size,
                 None,
-                None,
+                block_table,
                 self.tokenizer.eos_token_id,
                 self.inference_config.max_output_len,
             )
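For context on the block_table introduced above: each sequence now carries a tensor of max_seq_len slots initialised to -1, which the cache manager later fills with physical KV-cache block indices. A rough sketch of that bookkeeping, assuming a hypothetical allocate helper (only torch.full and the -1 sentinel come from the diff; the helper is not ColossalAI's allocate_context_from_block_table):

# Sketch of the block-table idea: -1 marks "no cache block assigned yet".
# `allocate` is a hypothetical stand-in, not the ColossalAI implementation.
import torch

max_seq_len = 8
block_size = 2

block_table = torch.full((max_seq_len,), -1, dtype=torch.long)

def allocate(table: torch.Tensor, num_tokens: int, next_free_block: int) -> int:
    """Assign one physical block per `block_size` tokens of context."""
    needed = (num_tokens + block_size - 1) // block_size
    for slot in range(needed):
        table[slot] = next_free_block
        next_free_block += 1
    return next_free_block

free = allocate(block_table, num_tokens=5, next_free_block=0)
print(block_table.tolist())  # [0, 1, 2, -1, -1, -1, -1, -1]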
@@ -211,14 +218,15 @@ class InferenceEngine:
         self.logger.info("Running generation step")
 
         output_list = []
-        batch, k_cache, v_cache = self.request_handler.schedule()
+        batch = self.request_handler.schedule()
 
         logits = self.model(
             batch,
-            k_cache,
-            v_cache,
+            self.k_cahce,
+            self.v_cache,
         )
-        self.request_handler.search_tokens(logits, self.generation_config)
+
+        self.request_handler.search_tokens(self.generation_config, logits)
 
         finished_sequences = self.request_handler.update()
 
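Taken together with the earlier get_kvcache() addition, the hunk above changes who owns the KV cache during a generation step: the engine fetches self.k_cahce / self.v_cache from the request handler once and passes them to the model itself, while schedule() now returns only the batch. A rough sketch of that control flow with placeholder classes (nothing below is the real ColossalAI code):

# Rough sketch of the per-step flow after this commit; every class and model
# below is a placeholder used only to show who owns the KV cache.
import torch

class DummyHandler:
    def __init__(self):
        # The handler owns the KV cache but hands it to the engine once.
        self.k_cache = torch.zeros(1, 4, 8)
        self.v_cache = torch.zeros(1, 4, 8)

    def get_kvcache(self):
        return self.k_cache, self.v_cache

    def schedule(self):
        return {"input_ids": torch.tensor([[1, 2, 3]])}  # batch only, no cache

def dummy_model(batch, k_cache, v_cache):
    # A real model would read and update the cache; here we just fake logits.
    return torch.randn(batch["input_ids"].shape[0], 32000)

handler = DummyHandler()
k_cache, v_cache = handler.get_kvcache()   # done once, at engine init
batch = handler.schedule()                 # done every generation step
logits = dummy_model(batch, k_cache, v_cache)
print(logits.shape)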
@@ -5,7 +5,6 @@ from transformers.configuration_utils import PretrainedConfig
 
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import KVCacheManager
-from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import BatchInfo, Sequence
 
@@ -49,7 +48,7 @@ class RunningList:
     def ready_for_prefill(self):
         if not self.decoding:
             return len(self.prefill) > 0
-        return len(self.prefill) / len(self.decoding) >= self.ratio
+        return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
 
     def is_empty(self):
         return not self.decoding and not self.prefill
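The fix above makes ready_for_prefill compare against the attribute that actually exists, prefill_ratio. A minimal, self-contained sketch of that policy (the class body is reconstructed for illustration; the ratio value is an arbitrary example):

# Minimal sketch of the ready_for_prefill logic shown in the hunk above;
# reconstructed for illustration, not copied from the repository.
class RunningList:
    def __init__(self, prefill_ratio: float):
        self.prefill_ratio = prefill_ratio
        self.prefill = []    # sequences still waiting for their prefill pass
        self.decoding = []   # sequences already generating token by token

    def ready_for_prefill(self) -> bool:
        if not self.decoding:
            return len(self.prefill) > 0
        return len(self.prefill) / len(self.decoding) >= self.prefill_ratio

rl = RunningList(prefill_ratio=1.2)  # arbitrary example value
rl.decoding = ["seq0", "seq1"]
rl.prefill = ["seq2", "seq3"]
print(rl.ready_for_prefill())  # False: 2 / 2 = 1.0 < 1.2
rl.prefill.append("seq4")
print(rl.ready_for_prefill())  # True: 3 / 2 = 1.5 >= 1.2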
@@ -72,8 +71,9 @@ class RequestHandler:
         self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
         self.waiting_list: List[List] = [[], [], []]
         self.done_list: List[Sequence] = []
-        self.running_batch = BatchInfo(is_prompts=False)
-        self.prefill_batch = BatchInfo(is_prompts=True)
+        device = torch.cuda.current_device()
+        self.running_batch = BatchInfo(is_prompts=False, device=device)
+        self.prefill_batch = BatchInfo(is_prompts=True, device=device)
 
     def _init_cache(self, model_config):
         self.cache_manager = KVCacheManager(self.inference_config, model_config)
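On the device change above: BatchInfo is now told explicitly which device its tensors belong on, taken from torch.cuda.current_device(), instead of relying on defaults. A minimal illustration of the idea, guarded so it also runs without a GPU (the guard is not part of the commit):

# Minimal illustration of explicit device placement as in the hunk above.
import torch

device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"

default_tensor = torch.zeros(4)                 # lands on CPU by default
placed_tensor = torch.zeros(4, device=device)   # follows the chosen device

print(default_tensor.device, placed_tensor.device)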
@@ -81,6 +81,9 @@ class RequestHandler:
     def _has_waiting(self) -> bool:
         return any(lst for lst in self.waiting_list)
 
+    def get_kvcache(self):
+        return self.cache_manager.get_kv_cache()
+
     def schedule(self):
         """
         The main logic of request handler.
@@ -90,7 +93,7 @@ class RequestHandler:
         for lst in reversed(self.waiting_list):
             if lst:
                 for seq in lst:
-                    if seq.prompt_len > self.inference_config.max_input_len:
+                    if seq.input_len > self.inference_config.max_input_len:
                         # If the prompt length is longer than max_input_len, abort the sequence.
                         self.abort_sequence(seq.request_id)
                         break
@@ -98,9 +101,8 @@ class RequestHandler:
                     if self.cache_manager.check_allocation(seq):
                         # If succeed, add the sequence to running list.
                         self.running_list.append(seq)
-                        self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.prompt_len)
-                        lst.remove(seq)
-
+                        self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
+                lst.clear()
         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
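The hunk above switches the length check to seq.input_len and stops removing sequences from the bucket while iterating over it, clearing the bucket after the scan instead. A condensed, speculative sketch of that admission loop (every name below is a stand-in, not the real RequestHandler code):

# Speculative sketch: over-long requests are aborted, admitted requests are
# handed to the scheduler, and a scanned bucket is cleared afterwards rather
# than mutated during iteration. All names are stand-ins.
class Seq:
    def __init__(self, request_id, input_len):
        self.request_id = request_id
        self.input_len = input_len

def schedule_waiting(waiting_list, max_input_len, admit, abort):
    for lst in reversed(waiting_list):
        if lst:
            for seq in lst:
                if seq.input_len > max_input_len:
                    abort(seq.request_id)
                    break
                admit(seq)
            lst.clear()  # the whole bucket has been handled

buckets = [[Seq("a", 700)], [Seq("b", 120)], []]
schedule_waiting(
    buckets,
    max_input_len=512,
    admit=lambda s: print("admit", s.request_id),
    abort=lambda rid: print("abort", rid),
)
print([len(b) for b in buckets])  # [0, 0, 0] -- every scanned bucket was cleared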
@@ -115,10 +117,9 @@ class RequestHandler:
         """
         assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
         assert (
-            req.prompt_len < self.inference_config.max_input_len
+            req.input_len < self.inference_config.max_input_len
         ), f"Sequence {req.request_id} exceeds input length limit"
-
-        self.waiting_list[req.prompt_len * 3 // self.inference_config.max_input_len].append(req)
+        self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req)
 
     def abort_sequence(self, request_id: str):
         """
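The renamed expression above sorts each admitted request into one of the three waiting buckets according to how much of max_input_len its prompt uses. A small worked example of the index arithmetic, with an arbitrary max_input_len:

# Worked example of the bucket index req.input_len * 3 // max_input_len,
# with waiting_list = [[], [], []] as in the handler. max_input_len here is
# an arbitrary example value, not taken from the commit.
max_input_len = 512

for input_len in (10, 170, 171, 341, 342, 511):
    bucket = input_len * 3 // max_input_len
    print(f"input_len={input_len:>3} -> bucket {bucket}")

# input_len= 10 -> bucket 0   (short prompts)
# input_len=170 -> bucket 0
# input_len=171 -> bucket 1   (medium prompts)
# input_len=341 -> bucket 1
# input_len=342 -> bucket 2   (long prompts)
# input_len=511 -> bucket 2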
@@ -178,9 +179,12 @@ class RequestHandler:
         """
         # do logit processor
         # NOTE: need to decide the granularity to process logits (sequence or batch)
-        for type in ["top_p", "top_k", "min_p"]:
-            if type in generation_config:
-                logits = logit_processor(type, logits)
+        # for type in ["top_p", "top_k", "min_p"]:
+        #     config_dict = generation_config.to_dict()
+        #     if type in config_dict:
+        #         logits = logit_processor(type, logits, config_dict[type])
+
+        torch.cuda.synchronize()
 
         # calculate probs
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
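For reference on the commented-out call above: a logit processor of type "top_k" typically masks everything outside the k most likely tokens before sampling. A generic sketch of that behaviour (this is not ColossalAI's registered logit_processor implementation):

# Generic top-k logit filter, shown only to illustrate what the commented
# logit_processor(type, logits, config_dict[type]) call would do for "top_k".
import torch

def top_k_logit_processor(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest logits per row and mask the rest to -inf."""
    topk_vals, _ = torch.topk(logits, k, dim=-1)
    threshold = topk_vals[..., -1, None]
    return torch.where(logits < threshold, torch.full_like(logits, float("-inf")), logits)

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
print(top_k_logit_processor(logits, k=2))
# tensor([[2., -inf, 1., -inf]])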
@@ -188,7 +192,10 @@ class RequestHandler:
 
         # sample the next tokens
         sample_tokens = self._sample(probs, logprobs, generation_config)
-        self.running_batch.update_batch_tokens(sample_tokens)
+        if not self.prefill_batch.is_empty:
+            self.prefill_batch.update_batch_tokens(sample_tokens)
+        else:
+            self.running_batch.update_batch_tokens(sample_tokens)
 
     def update(self):
         """
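The final hunk routes the sampled tokens to whichever batch is actually being stepped: the prefill batch while it still holds sequences, otherwise the decoding batch. A tiny sketch of that dispatch, with a stand-in for BatchInfo:

# Tiny sketch of the prefill-vs-decoding dispatch in the final hunk; MiniBatch
# is a stand-in for ColossalAI's BatchInfo, not the real class.
import torch

class MiniBatch:
    def __init__(self, sequences):
        self.sequences = list(sequences)
        self.sampled = []

    @property
    def is_empty(self):
        return len(self.sequences) == 0

    def update_batch_tokens(self, tokens):
        self.sampled.append(tokens)

prefill_batch = MiniBatch(sequences=["seq0"])   # first pass for a new request
running_batch = MiniBatch(sequences=[])         # nothing decoding yet

sample_tokens = torch.tensor([42])
if not prefill_batch.is_empty:
    prefill_batch.update_batch_tokens(sample_tokens)
else:
    running_batch.update_batch_tokens(sample_tokens)

print(prefill_batch.sampled)  # [tensor([42])] -- tokens went to the prefill batch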