[Inference]Optimize generation process of inference engine (#5356)

* opt inference engine

* fix run_benchmark.sh

* fix generate in engine.py

* rollback tesh_inference_engine.py
This commit is contained in:
yuehuayingxueluo
2024-02-02 15:38:21 +08:00
committed by GitHub
parent 21ad4a27f9
commit 631862f339
3 changed files with 21 additions and 16 deletions

View File

@@ -134,12 +134,16 @@ class InferenceEngine:
def generate(
self,
prompts: List[str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
generation_config: GenerationConfig = None,
) -> List[str]:
"""
Executing the inference step.
Args:
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
Returns:
@@ -147,13 +151,23 @@ class InferenceEngine:
"""
self.generation_config = generation_config
if prompts is not None or prompts_token_ids is not None:
self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)
output_list = []
output_seqs_list = []
output_tokens_list = []
while self.request_handler.check_unfinished_seqs():
output_list += self.step()
output_seqs_list += self.step()
return output_list
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
for seq in output_seqs_list:
output_tokens_list.append(seq.input_token_id + seq.output_token_id)
output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)
return output_str
def add_request(
self,
@@ -235,7 +249,6 @@ class InferenceEngine:
List[str]: Decoded finished sequences generated by one step.
"""
output_list = []
batch = self.request_handler.schedule()
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
@@ -251,10 +264,4 @@ class InferenceEngine:
self.request_handler.search_tokens(self.generation_config, logits)
finished_sequences = self.request_handler.update()
# Decode completed sentences.
# TODO : update decoding step
for seq in finished_sequences:
output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
output_list.append(output_str)
return output_list
return finished_sequences