[Inference] Fix bugs and docs for feat/online-server (#5598)

* fix test bugs

* add do sample test

* del useless lines

* fix comments

* fix tests

* delete version tag

* delete version tag

* add

* del test server

* fix test

* fix

* Revert "add"

This reverts commit b9305fb024.

Author: Jianghai
Date: 2024-05-08 15:14:06 +08:00
Committed by: CjhHa1
Parent: 7bbb28e48b
Commit: 61a1b2e798
12 changed files with 98 additions and 172 deletions

@@ -527,10 +527,15 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
+            if isinstance(prompts, str) and isinstance(request_ids, int):
+                prompts = [prompts]
+                request_ids = [request_ids]
             if prompts is not None or prompts_token_ids is not None:
                 gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
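
The resolved hunk above lets generate() accept a bare string prompt together with a bare int request id. A minimal sketch of that coercion in isolation (the function name and the assert are illustrative, not part of the PR):

    from typing import List, Optional, Tuple, Union

    def normalize_generate_args(
        prompts: Optional[Union[str, List[str]]],
        request_ids: Optional[Union[int, List[int]]],
    ) -> Tuple[Optional[List[str]], Optional[List[int]]]:
        # Wrap a single prompt/request id so downstream code can always
        # iterate over lists, mirroring the check added in generate().
        if isinstance(prompts, str) and isinstance(request_ids, int):
            prompts = [prompts]
            request_ids = [request_ids]
        return prompts, request_ids

    assert normalize_generate_args("Hello", 0) == (["Hello"], [0])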
@@ -612,6 +617,9 @@ class InferenceEngine:
         block_size = self.inference_config.block_size
+        if request_ids is not None and not isinstance(request_ids, list):
+            request_ids = [request_ids]
+
         if prompts is not None and not isinstance(prompts, list):
             prompts = [prompts]
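
add_request() applies the same scalar-to-list pattern to request_ids and prompts independently. A hedged stand-alone sketch of the pattern (the helper name is hypothetical):

    def wrap_scalar(value):
        # Any non-list, non-None value is wrapped in a list, as the
        # request_ids/prompts checks in add_request() do.
        if value is not None and not isinstance(value, list):
            value = [value]
        return value

    assert wrap_scalar(0) == [0]
    assert wrap_scalar(["a"]) == ["a"]
    assert wrap_scalar(None) is None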
@@ -621,9 +629,10 @@ class InferenceEngine:
"input_ids"
]
# list of torch Tensor
if isinstance(prompts_token_ids, list):
if isinstance(prompts_token_ids[0], torch.Tensor):
prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
prompts_token_ids = prompts_token_ids.tolist()
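
With the extra isinstance check, a plain List[List[int]] now passes through untouched while a list of per-prompt Tensors is converted element-wise. A self-contained sketch of the full coercion (the function name is illustrative):

    from typing import List, Union

    import numpy as np
    import torch

    def to_token_id_lists(
        prompts_token_ids: Union[List, torch.Tensor, np.ndarray]
    ) -> List[List[int]]:
        # Coerce every accepted input type to List[List[int]],
        # following the branches in add_request().
        if isinstance(prompts_token_ids, list):
            if isinstance(prompts_token_ids[0], torch.Tensor):
                prompts_token_ids = [t.tolist() for t in prompts_token_ids]
        elif isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
            prompts_token_ids = prompts_token_ids.tolist()
        else:
            raise TypeError(f"Unsupported prompts_token_ids type {type(prompts_token_ids)}")
        return prompts_token_ids

    assert to_token_id_lists([torch.tensor([1, 2])]) == [[1, 2]]
    assert to_token_id_lists(np.array([[3, 4]])) == [[3, 4]]
    assert to_token_id_lists([[5, 6]]) == [[5, 6]]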
else:
@@ -738,8 +747,6 @@ class InferenceEngine:
             logits = logits[:, -1, :]
             next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
             self.request_handler.append_next_tokens(next_tokens)
-            self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()
         return finished_sequences
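
After dropping the redundant search_tokens call, each decoding step samples exactly once from the last position's logits and appends the result. A minimal shape-level sketch (random logits and greedy argmax stand in for the real request handler):

    import torch

    batch, seq_len, vocab = 2, 8, 32000
    logits = torch.randn(batch, seq_len, vocab)
    last_logits = logits[:, -1, :]            # (batch, vocab): one row per sequence
    next_tokens = last_logits.argmax(dim=-1)  # greedy stand-in for search_tokens
    assert next_tokens.shape == (batch,)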