Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 04:50:17 +00:00)
[Inference] Fix bugs and docs for feat/online-server (#5598)
* fix test bugs
* add do sample test
* del useless lines
* fix comments
* fix tests
* delete version tag
* delete version tag
* add
* del test server
* fix test
* fix
* Revert "add"
  This reverts commit b9305fb024.
@@ -527,10 +527,15 @@ class InferenceEngine:
            List[str]: Inference result returned by one generation.
        """

        with torch.inference_mode():
<<<<<<< HEAD
            if isinstance(prompts, str) and isinstance(request_ids, int):
                prompts = [prompts]
                request_ids = [request_ids]
=======
            if prompts is not None or prompts_token_ids is not None:
                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
>>>>>>> [Inference] Fix bugs and docs for feat/online-server (#5598)

            if prompts is not None or prompts_token_ids is not None:
                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
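For orientation, the resolved code wraps scalar inputs into lists before queueing work. A minimal, self-contained sketch of that pattern (normalize_inputs is an illustrative helper, not part of InferenceEngine's API):

from typing import List, Optional, Union

def normalize_inputs(
    prompts: Optional[Union[str, List[str]]],
    request_ids: Optional[Union[int, List[int]]],
):
    # A bare string prompt with a bare int id is wrapped into one-element
    # lists, so the rest of the pipeline can assume batched inputs.
    if isinstance(prompts, str) and isinstance(request_ids, int):
        prompts = [prompts]
        request_ids = [request_ids]
    return prompts, request_ids

prompts, request_ids = normalize_inputs("What is deep learning?", 0)
assert prompts == ["What is deep learning?"]
assert request_ids == [0]

Wrapping scalars up front lets add_request and the batching code downstream assume list inputs throughout.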
@@ -612,6 +617,9 @@ class InferenceEngine:

        block_size = self.inference_config.block_size

        if request_ids is not None and not isinstance(request_ids, list):
            request_ids = [request_ids]

        if prompts is not None and not isinstance(prompts, list):
            prompts = [prompts]
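This hunk applies the same wrapping to each argument independently, so a caller may pass a scalar for one and a list for the other. A tiny sketch of the idiom (listify is a hypothetical helper name):

def listify(value):
    # Wrap any non-None scalar into a one-element list; lists and None pass through.
    if value is not None and not isinstance(value, list):
        return [value]
    return value

assert listify(7) == [7]
assert listify(["a", "b"]) == ["a", "b"]
assert listify(None) is None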
@@ -621,9 +629,10 @@ class InferenceEngine:
                "input_ids"
            ]

        # list of torch Tensor
        if isinstance(prompts_token_ids, list):
            if isinstance(prompts_token_ids[0], torch.Tensor):
-                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
        elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
            prompts_token_ids = prompts_token_ids.tolist()
        else:
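The -/+ pair renames the comprehension variable so it is no longer confusingly similar to the argument name; the surrounding branch accepts several container types for prompts_token_ids. A self-contained sketch of that normalization (to_token_id_lists is a hypothetical stand-alone version of the branch above, not the engine's API):

import numpy as np
import torch

def to_token_id_lists(prompts_token_ids):
    # Accept a list of tensors, a single tensor, or an ndarray, and return
    # plain Python lists of token ids.
    if isinstance(prompts_token_ids, list):
        if isinstance(prompts_token_ids[0], torch.Tensor):
            return [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
        return prompts_token_ids
    elif isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        return prompts_token_ids.tolist()
    else:
        raise TypeError(f"Unsupported prompts_token_ids type: {type(prompts_token_ids)}")

assert to_token_id_lists([torch.tensor([7, 8])]) == [[7, 8]]
assert to_token_id_lists(torch.tensor([[1, 2, 3]])) == [[1, 2, 3]]
assert to_token_id_lists(np.array([[4, 5, 6]])) == [[4, 5, 6]]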
@@ -738,8 +747,6 @@ class InferenceEngine:
            logits = logits[:, -1, :]
-            next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
-            self.request_handler.append_next_tokens(next_tokens)

            self.request_handler.search_tokens(self.generation_config, logits)
            finished_sequences = self.request_handler.update()

            return finished_sequences
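The final hunk removes a duplicated sampling call (the "del useless lines" item in the commit message): tokens were previously searched and appended, then searched again. A toy sketch of the cleaned-up per-step flow, with ToyRequestHandler standing in for the engine's real RequestHandler:

import torch

class ToyRequestHandler:
    # Stand-in for the engine's request handler; not ColossalAI's actual class.
    def __init__(self):
        self.generated = []

    def search_tokens(self, generation_config, logits):
        # Greedy argmax as a placeholder for the real sampling logic; the
        # handler records the token itself, so no second call is needed.
        self.generated.append(int(logits.argmax(dim=-1)[0]))

    def update(self):
        # Pretend every sequence finishes after one step.
        return [self.generated]

handler = ToyRequestHandler()
logits = torch.randn(1, 8, 32)       # (batch, seq_len, vocab)
logits = logits[:, -1, :]            # keep only the last position, as in the hunk
handler.search_tokens(None, logits)  # called exactly once per step now
finished_sequences = handler.update()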