Fixed a bug in the inference frame

This commit is contained in:
yuehuayingxueluo
2023-12-26 21:34:27 +08:00
committed by FrankLeeeee
parent 86853a37d5
commit 62fd08ee44
8 changed files with 261 additions and 90 deletions

View File

@@ -49,6 +49,7 @@ class InferenceEngine:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.inference_config = inference_config
self.model_config = model.config
self.device = torch.device("cuda")
if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32
@@ -76,6 +77,7 @@ class InferenceEngine:
self.logger = get_dist_logger(__name__)
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
self.counter = count()
def _verify_config(self) -> None:
@@ -170,7 +172,11 @@ class InferenceEngine:
if prompts_token_ids is None:
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts)["input_ids"]
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
assert (
len(prompts_token_ids[0]) < self.inference_config.max_input_len
), "The length of input prompts must be less than max_input_len."
prompts_num = len(prompts_token_ids)
@@ -183,13 +189,14 @@ class InferenceEngine:
prompt = None
else:
prompt = prompts[i]
block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
None,
block_table,
self.tokenizer.eos_token_id,
self.inference_config.max_output_len,
)
@@ -211,14 +218,15 @@ class InferenceEngine:
self.logger.info("Running generation step")
output_list = []
batch, k_cache, v_cache = self.request_handler.schedule()
batch = self.request_handler.schedule()
logits = self.model(
batch,
k_cache,
v_cache,
self.k_cahce,
self.v_cache,
)
self.request_handler.search_tokens(logits, self.generation_config)
self.request_handler.search_tokens(self.generation_config, logits)
finished_sequences = self.request_handler.update()

View File

@@ -5,7 +5,6 @@ from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *
from colossalai.inference.struct import BatchInfo, Sequence
@@ -49,7 +48,7 @@ class RunningList:
def ready_for_prefill(self):
if not self.decoding:
return len(self.prefill) > 0
return len(self.prefill) / len(self.decoding) >= self.ratio
return len(self.prefill) / len(self.decoding) >= self.prefill_ratio
def is_empty(self):
return not self.decoding and not self.prefill
@@ -72,8 +71,9 @@ class RequestHandler:
self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
self.waiting_list: List[List] = [[], [], []]
self.done_list: List[Sequence] = []
self.running_batch = BatchInfo(is_prompts=False)
self.prefill_batch = BatchInfo(is_prompts=True)
device = torch.cuda.current_device()
self.running_batch = BatchInfo(is_prompts=False, device=device)
self.prefill_batch = BatchInfo(is_prompts=True, device=device)
def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)
@@ -81,6 +81,9 @@ class RequestHandler:
def _has_waiting(self) -> bool:
return any(lst for lst in self.waiting_list)
def get_kvcache(self):
return self.cache_manager.get_kv_cache()
def schedule(self):
"""
The main logic of request handler.
@@ -90,7 +93,7 @@ class RequestHandler:
for lst in reversed(self.waiting_list):
if lst:
for seq in lst:
if seq.prompt_len > self.inference_config.max_input_len:
if seq.input_len > self.inference_config.max_input_len:
# If the prompt length is longer than max_input_len, abort the sequence.
self.abort_sequence(seq.request_id)
break
@@ -98,9 +101,8 @@ class RequestHandler:
if self.cache_manager.check_allocation(seq):
# If succeed, add the sequence to running list.
self.running_list.append(seq)
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.prompt_len)
lst.remove(seq)
self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
lst.clear()
if self.running_list.ready_for_prefill():
for seq in self.running_list.prefill:
seq.mark_running()
@@ -115,10 +117,9 @@ class RequestHandler:
"""
assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
assert (
req.prompt_len < self.inference_config.max_input_len
req.input_len < self.inference_config.max_input_len
), f"Sequence {req.request_id} exceeds input length limit"
self.waiting_list[req.prompt_len * 3 // self.inference_config.max_input_len].append(req)
self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req)
def abort_sequence(self, request_id: str):
"""
@@ -178,9 +179,12 @@ class RequestHandler:
"""
# do logit processor
# NOTE: need to decide the granularity to process logits (sequence or batch)
for type in ["top_p", "top_k", "min_p"]:
if type in generation_config:
logits = logit_processor(type, logits)
# for type in ["top_p", "top_k", "min_p"]:
# config_dict = generation_config.to_dict()
# if type in config_dict:
# logits = logit_processor(type, logits, config_dict[type])
torch.cuda.synchronize()
# calculate probs
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@@ -188,7 +192,10 @@ class RequestHandler:
# sample the next tokens
sample_tokens = self._sample(probs, logprobs, generation_config)
self.running_batch.update_batch_tokens(sample_tokens)
if not self.prefill_batch.is_empty:
self.prefill_batch.update_batch_tokens(sample_tokens)
else:
self.running_batch.update_batch_tokens(sample_tokens)
def update(self):
"""