Fixed a bug in the inference framework

yuehuayingxueluo
2023-12-26 21:34:27 +08:00
committed by FrankLeeeee
parent 86853a37d5
commit 62fd08ee44
8 changed files with 261 additions and 90 deletions


@@ -49,6 +49,7 @@ class InferenceEngine:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.inference_config = inference_config
self.model_config = model.config
+ self.device = torch.device("cuda")
if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32
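Note on the hunk above: the added self.device pins the engine to CUDA, while the surrounding context maps the configured dtype string (or torch dtype) onto a concrete torch.dtype. A minimal stand-alone sketch of that pattern, with illustrative names that are not the engine's actual attributes and an assumed fp16 fallback:

import torch

# Illustrative dtype resolution mirroring the fp32 branch shown in the hunk;
# the fp16/bf16 branches are assumed to follow the same shape.
def resolve_dtype(dtype) -> torch.dtype:
    if dtype == "fp32" or dtype == torch.float32:
        return torch.float32
    return torch.float16

device = torch.device("cuda")      # the commit fixes the engine's device to CUDA
dtype = resolve_dtype("fp32")
print(device, dtype)               # cuda torch.float32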
@@ -76,6 +77,7 @@ class InferenceEngine:
self.logger = get_dist_logger(__name__)
self.request_handler = RequestHandler(self.inference_config, self.model_config)
+ self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
self.counter = count()
def _verify_config(self) -> None:
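Two details in this hunk: the engine now fetches the KV-cache tensors (self.k_cahce, self.v_cache) from the request handler once, at construction time, and request ids are drawn from a counter, presumably itertools.count. A minimal sketch of the counter behaviour, under the assumption that it hands out monotonically increasing request ids:

from itertools import count

counter = count()                              # yields 0, 1, 2, ... on each next() call
request_ids = [next(counter) for _ in range(3)]
assert request_ids == [0, 1, 2]                # assumption: one fresh id per incoming request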
@@ -170,7 +172,11 @@ class InferenceEngine:
if prompts_token_ids is None:
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
- prompts_token_ids = self.tokenizer.batch_encode_plus(prompts)["input_ids"]
+ prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
+ assert (
+     len(prompts_token_ids[0]) < self.inference_config.max_input_len
+ ), "The length of input prompts must be less than max_input_len."
prompts_num = len(prompts_token_ids)
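As background on the tokenizer change above: batch_encode_plus with padding=True pads every prompt in the batch to the length of the longest one, using the pad token that __init__ sets to eos. A hedged sketch with a hypothetical checkpoint name (any causal-LM tokenizer with a pad token set behaves the same way):

from transformers import AutoTokenizer

# Hypothetical tokenizer checkpoint, used only for illustration.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.eos_token  # mirrors the engine's __init__

prompts = ["Hello world", "A much longer prompt than the first one"]
ids = tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
# With padding=True every row has the same length, so len(ids[0]) is a
# meaningful quantity to compare against max_input_len.
assert len(ids[0]) == len(ids[1])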
@@ -183,13 +189,14 @@ class InferenceEngine:
prompt = None
else:
prompt = prompts[i]
+ block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
- None,
+ block_table,
self.tokenizer.eos_token_id,
self.inference_config.max_output_len,
)
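The block_table introduced above is a per-sequence tensor with max_seq_len slots initialised to -1, i.e. "no cache block allocated yet", which the sequence now carries instead of a None placeholder. A small sketch of that convention; the allocation step and the values used are assumptions for illustration, not taken from this commit, and the real tensor lives on self.device:

import torch

max_seq_len = 8                              # stand-in for inference_config.max_seq_len
block_table = torch.full([max_seq_len], -1)  # -1 marks slots with no physical block yet

# Hypothetical allocation of the first two logical blocks to physical blocks 3 and 7.
block_table[0], block_table[1] = 3, 7
print(block_table.tolist())                  # [3, 7, -1, -1, -1, -1, -1, -1]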
@@ -211,14 +218,15 @@ class InferenceEngine:
self.logger.info("Running generation step")
output_list = []
- batch, k_cache, v_cache = self.request_handler.schedule()
+ batch = self.request_handler.schedule()
logits = self.model(
batch,
- k_cache,
- v_cache,
+ self.k_cahce,
+ self.v_cache,
)
- self.request_handler.search_tokens(logits, self.generation_config)
+ self.request_handler.search_tokens(self.generation_config, logits)
finished_sequences = self.request_handler.update()
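Taken together, the last hunk changes the generation step so that the scheduler returns only the batch while the engine supplies its own k_cahce/v_cache handles to the model, and search_tokens now receives the generation config before the logits. A self-contained sketch of that call flow; the stub classes and shapes below are stand-ins, not the real ColossalAI APIs:

import torch

class StubRequestHandler:
    """Stand-in for RequestHandler: owns the KV cache and hands out batches."""

    def __init__(self, num_layers=2, num_blocks=4, block_size=16, head_dim=8):
        shape = (num_blocks, block_size, head_dim)
        self._k = [torch.zeros(shape) for _ in range(num_layers)]
        self._v = [torch.zeros(shape) for _ in range(num_layers)]

    def get_kvcache(self):
        # Retrieved once in __init__; the engine keeps these references.
        return self._k, self._v

    def schedule(self):
        # After this commit, schedule() returns only the batch.
        return torch.randint(0, 100, (2, 4))  # dummy batch of token ids

    def search_tokens(self, generation_config, logits):
        # Argument order as fixed by the commit: config first, logits second.
        return logits.argmax(dim=-1)

def stub_model(batch, k_cache, v_cache):
    # Stand-in for self.model(batch, self.k_cahce, self.v_cache).
    vocab_size = 100
    return torch.randn(batch.shape[0], vocab_size)

handler = StubRequestHandler()
k_cache, v_cache = handler.get_kvcache()   # engine-owned handles, as in __init__
batch = handler.schedule()                 # scheduler no longer returns the cache
logits = stub_model(batch, k_cache, v_cache)
next_tokens = handler.search_tokens(None, logits)  # generation_config stubbed as None
print(next_tokens.shape)                   # torch.Size([2]): one token id per sequence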