Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 20:40:34 +00:00
[Inference] ADD async and sync Api server using FastAPI (#5396)
* add api server
* fix
* add
* add completion service and fix bug
* add generation config
* revise shardformer
* fix bugs
* add docstrings and fix some bugs
* fix bugs and add choices for prompt template
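The commit message describes an async/sync API server and a completion service built on FastAPI, but the server code itself is not part of the diff below. As a rough, hypothetical sketch of how such a completion endpoint could wrap the engine's generate() call (the engine construction, the /completion route, and the PromptRequest fields are assumptions for illustration, not the PR's actual server):

```python
# Hypothetical sketch only; route and model names are illustrative, not the PR's actual server.
from typing import List, Optional

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
engine = None  # assume an already-initialized colossalai InferenceEngine is injected here


class PromptRequest(BaseModel):
    prompt: str
    request_id: Optional[int] = None


@app.post("/completion")
async def completion(req: PromptRequest) -> List[str]:
    # After this PR, generate() accepts a single str prompt and a single int
    # request_id and normalizes them to lists internally.
    return engine.generate(request_ids=req.request_id, prompts=req.prompt)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```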
@@ -1,6 +1,6 @@
 import time
 from itertools import count
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union, Iterable

 import numpy as np
 import torch
@@ -507,9 +507,9 @@ class InferenceEngine:

     def generate(
         self,
-        prompts: List[str] = None,
+        request_ids: Union[List[int], int] = None,
+        prompts: Union[List[str], str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
-        request_ids: List[int] = None,
         return_token_ids: bool = False,
         generation_config: Optional[GenerationConfig] = None,
     ) -> List[str]:
@@ -527,6 +527,11 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
+
+            if isinstance(prompts, str) and isinstance(request_ids, int):
+                prompts = [prompts]
+                request_ids = [request_ids]
+
             if prompts is not None or prompts_token_ids is not None:
                 gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                 self.add_request(
@@ -535,7 +540,7 @@ class InferenceEngine:
                     prompts_token_ids=prompts_token_ids,
                     **gen_config_dict,
                 )

             output_seqs_list = []
             total_tokens_list = []

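With the revised signature, generate() accepts either batched lists or a single prompt string with a single integer request id, normalizing the latter to lists before scheduling. A minimal usage sketch (the engine construction from a model, tokenizer, and InferenceConfig is assumed and omitted; only the generate() calls follow the signature shown above):

```python
from transformers import GenerationConfig

# Assumes `engine` is an already-initialized ColossalAI InferenceEngine.
gen_cfg = GenerationConfig(max_new_tokens=64, do_sample=False)

# Batched call: lists of prompts and request ids.
outputs = engine.generate(
    request_ids=[0, 1],
    prompts=["What is tensor parallelism?", "Explain the KV cache."],
    generation_config=gen_cfg,
)

# Single-prompt call: a bare str and int are accepted and wrapped into lists internally.
single_output = engine.generate(
    request_ids=2,
    prompts="Summarize the commit.",
    generation_config=gen_cfg,
)
print(outputs, single_output)
```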
@@ -580,13 +585,13 @@ class InferenceEngine:
         if isinstance(prompts, (list, tuple)):
             return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
         elif isinstance(prompts, str):
-            return self.inference_config.rompt_template.format(input_text=prompts)
+            return self.inference_config.prompt_template.format(input_text=prompts)
         else:
             raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")

     def add_request(
         self,
-        request_ids: List[int] = None,
+        request_ids: Union[List[int], int] = None,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         **kwargs,
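format_prompt() fills the configured prompt_template, which is expected to contain an {input_text} placeholder. A standalone illustration of that substitution (the template string below is an example, not a default shipped with the library):

```python
# Standalone illustration of the {input_text} substitution used by format_prompt().
prompt_template = "<|user|>\n{input_text}\n<|assistant|>\n"  # example template, not a bundled default


def format_prompt(prompts, template=prompt_template):
    if isinstance(prompts, (list, tuple)):
        return [template.format(input_text=p) for p in prompts]
    elif isinstance(prompts, str):
        return template.format(input_text=prompts)
    else:
        raise TypeError(f"Expected list, tuple, or str, but got {type(prompts)}.")


print(format_prompt("Hello"))         # single string -> single formatted prompt
print(format_prompt(["Hi", "Bye"]))   # list -> list of formatted prompts
```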
@@ -601,6 +606,7 @@ class InferenceEngine:
         """

         # apply the prompt template to the input prompts

         if self.has_prompt_template and prompts is not None:
             prompts = self.format_prompt(prompts)

@@ -614,6 +620,7 @@ class InferenceEngine:
             prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
                 "input_ids"
             ]
+            print(prompts_token_ids)

         if isinstance(prompts_token_ids, list):
             pass
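When only raw prompts are given, add_request() tokenizes them with the Hugging Face tokenizer's batch_encode_plus and keeps the "input_ids" lists. A small sketch of that step in isolation (the checkpoint name is arbitrary; padding=True stands in for inference_config.pad_input):

```python
from transformers import AutoTokenizer

# Arbitrary example checkpoint; any causal-LM tokenizer behaves the same way here.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default

prompts = ["Hello, world!", "ColossalAI inference"]
# padding=True mirrors inference_config.pad_input; with False each sequence keeps its own length.
prompts_token_ids = tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
print(prompts_token_ids)  # e.g. [[15496, 11, 995, 0], [...]] -- one list of token ids per prompt
```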
@@ -632,8 +639,6 @@ class InferenceEngine:

         for i in range(prompts_num):
             if request_ids:
-                if not isinstance(request_ids, list):
-                    request_ids = [request_ids]
                 assert isinstance(
                     request_ids[0], int
                 ), f"The request_id type must be int, but got {type(request_ids[0])}"
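The per-iteration wrapping of request_ids is dropped here; with the Union[List[int], int] hint, a bare int can be normalized to a list once before the loop and merely validated inside it. A sketch of that normalize-once pattern (the helper name is illustrative, not part of the engine API):

```python
from typing import List, Optional, Union


def normalize_request_ids(request_ids: Union[List[int], int, None]) -> Optional[List[int]]:
    """Wrap a bare int into a list once, so the per-prompt loop only needs to validate."""
    if request_ids is not None and not isinstance(request_ids, list):
        request_ids = [request_ids]
    if request_ids:
        assert isinstance(
            request_ids[0], int
        ), f"The request_id type must be int, but got {type(request_ids[0])}"
    return request_ids


print(normalize_request_ids(7))       # [7]
print(normalize_request_ids([1, 2]))  # [1, 2]
print(normalize_request_ids(None))    # None
```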
@@ -734,6 +739,9 @@ class InferenceEngine:
         next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
         self.request_handler.append_next_tokens(next_tokens)

+        print("in step", logits)

         self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()

         return finished_sequences
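search_tokens() itself is outside this hunk; conceptually it selects the next token for each running sequence from the step's logits according to the generation config. A generic illustration of that selection with plain PyTorch (greedy vs. sampling; this is not the request handler's actual implementation, which also applies top-k/top-p and per-request state):

```python
import torch


def pick_next_tokens(logits: torch.Tensor, do_sample: bool = False, temperature: float = 1.0) -> torch.Tensor:
    """Select one next token per sequence from [batch, vocab] logits (illustration only)."""
    if do_sample:
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)
    return torch.argmax(logits, dim=-1)


logits = torch.randn(2, 32000)  # fake logits for a batch of 2 over a 32k vocabulary
print(pick_next_tokens(logits))                                   # greedy
print(pick_next_tokens(logits, do_sample=True, temperature=0.8))  # sampled
```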