[Inference] Add async and sync API server using FastAPI (#5396)

* add api server

* fix

* add

* add completion service and fix bug

* add generation config

* revise shardformer

* fix bugs

* add docstrings and fix some bugs

* fix bugs and add choices for prompt template
Jianghai
2024-03-01 14:47:36 +08:00
committed by CjhHa1
parent d482922035
commit 69cd7e069d
13 changed files with 789 additions and 25 deletions
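
Note: the FastAPI server itself is added in new files of this commit that are not expanded in the excerpt below. A minimal sketch of what such a completion endpoint could look like, assuming the import path, route name, and request fields (none of which are taken from this diff):

    # Hypothetical sketch only, not the code added by this commit.
    from fastapi import FastAPI
    from pydantic import BaseModel

    from colossalai.inference.core.engine import InferenceEngine  # import path assumed

    app = FastAPI()
    engine: InferenceEngine  # assumed to be constructed at server startup

    class CompletionRequest(BaseModel):
        prompt: str

    @app.post("/completion")
    async def completion(req: CompletionRequest):
        # After this commit, generate() accepts a single str prompt directly.
        output = engine.generate(prompts=req.prompt)
        return {"text": output}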


@@ -1,6 +1,6 @@
 import time
 from itertools import count
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union, Iterable
 
 import numpy as np
 import torch
@@ -507,9 +507,9 @@ class InferenceEngine:
     def generate(
         self,
-        prompts: List[str] = None,
+        request_ids: Union[List[int], int] = None,
+        prompts: Union[List[str], str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
-        request_ids: List[int] = None,
         return_token_ids: bool = False,
         generation_config: Optional[GenerationConfig] = None,
     ) -> List[str]:
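
With the widened signature above, callers no longer need to wrap a single prompt in a list; a usage sketch (engine construction elided):

    # Equivalent calls after this change; the str/int forms are normalized
    # to lists inside generate(), as the next hunk shows.
    output = engine.generate(request_ids=0, prompts="What is the capital of France?")
    output = engine.generate(request_ids=[0], prompts=["What is the capital of France?"])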
@@ -527,6 +527,11 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
+            if isinstance(prompts, str) and isinstance(request_ids, int):
+                prompts = [prompts]
+                request_ids = [request_ids]
+
             if prompts is not None or prompts_token_ids is not None:
                 gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                 self.add_request(
@@ -535,7 +540,7 @@ class InferenceEngine:
                     prompts_token_ids=prompts_token_ids,
                     **gen_config_dict,
                 )
 
             output_seqs_list = []
             total_tokens_list = []
@@ -580,13 +585,13 @@ class InferenceEngine:
         if isinstance(prompts, (list, tuple)):
             return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
         elif isinstance(prompts, str):
-            return self.inference_config.rompt_template.format(input_text=prompts)
+            return self.inference_config.prompt_template.format(input_text=prompts)
         else:
             raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
 
     def add_request(
         self,
-        request_ids: List[int] = None,
+        request_ids: Union[List[int], int] = None,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         **kwargs,
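
format_prompt above relies on inference_config.prompt_template being a format string with an {input_text} placeholder. An illustrative value (the concrete template is an assumption, not part of this diff):

    # Any format string containing {input_text} works as a prompt template.
    prompt_template = "[INST] {input_text} [/INST]"
    prompt_template.format(input_text="Hello")  # -> "[INST] Hello [/INST]"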
@@ -601,6 +606,7 @@ class InferenceEngine:
         """
 
         # apply the prompt template to the input prompts
         if self.has_prompt_template and prompts is not None:
             prompts = self.format_prompt(prompts)
@@ -614,6 +620,7 @@ class InferenceEngine:
             prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
                 "input_ids"
             ]
+        print(prompts_token_ids)
 
         if isinstance(prompts_token_ids, list):
             pass
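
batch_encode_plus returns a mapping whose "input_ids" entry is a list of token-id lists, so prompts_token_ids arrives at the isinstance check above as List[List[int]]. For example:

    # Token-id values are illustrative; they depend on the tokenizer in use.
    ids = tokenizer.batch_encode_plus(["Hello world", "Hi"], padding=True)["input_ids"]
    # e.g. [[1, 15043, 3186], [1, 6324, 0]] for a Llama-style tokenizer with padding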
@@ -632,8 +639,6 @@ class InferenceEngine:
         for i in range(prompts_num):
             if request_ids:
-                if not isinstance(request_ids, list):
-                    request_ids = [request_ids]
                 assert isinstance(
                     request_ids[0], int
                 ), f"The request_id type must be int, but got {type(request_ids[0])}"
@@ -734,6 +739,9 @@ class InferenceEngine:
+        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
+        self.request_handler.append_next_tokens(next_tokens)
+        print("in step", logits)
         self.request_handler.search_tokens(self.generation_config, logits)
 
         finished_sequences = self.request_handler.update()
 
         return finished_sequences
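
step() now both samples the next tokens (search_tokens) and appends them to the running sequences (append_next_tokens) before collecting finished ones. generate() drives it in a loop along these lines (the loop condition is assumed from the request handler's API, not shown in this diff):

    # Sketch of the driving loop around step().
    while self.request_handler.check_unfinished_seqs():
        finished_sequences = self.step()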