[Inference] Fix API server, test and example (#5712)
* fix api server
* fix generation config
* fix api server
* fix comments
* fix infer hanging bug
* resolve comments, change backend to free port
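Note: the last item, "change backend to free port", refers to letting the test backend bind an unused port instead of a hard-coded one. A minimal sketch of that pattern; the helper name find_free_port is illustrative and not necessarily what the test uses:

import socket


def find_free_port() -> int:
    # Bind to port 0 so the OS picks an unused port, then release it for the test server.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]


if __name__ == "__main__":
    port = find_free_port()
    print(f"launch the API server for the test on 127.0.0.1:{port}")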
@@ -4,6 +4,7 @@ from functools import partial
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type

from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.sampler import search_tokens

# CLI logger
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
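Note: this hunk appears to add the import of search_tokens from colossalai.inference.sampler, which the reworked async_step below uses to pick next tokens from the logits. The library's own sampler is not reproduced here; the following is only an illustrative greedy stand-in with a similar call shape:

import torch


def greedy_search_tokens(logits: torch.Tensor) -> torch.Tensor:
    # logits: [batch_size, vocab_size] scores for the last position of each sequence.
    # Greedy decoding simply takes the highest-scoring token id per sequence.
    return torch.argmax(logits, dim=-1)


if __name__ == "__main__":
    fake_logits = torch.randn(2, 32000)  # two sequences, toy vocabulary size
    print(greedy_search_tokens(fake_logits))  # tensor of shape [2]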
@@ -168,26 +169,44 @@ class _AsyncInferenceEngine(InferenceEngine):
            generated results.
        """
        batch = self.request_handler.schedule()
        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)

        loop = asyncio.get_running_loop()

        if input_meta_data.use_cuda_graph:
            model_executable = self.graph_runners[input_meta_data.batch_size]
        else:
            model_executable = self.model

        # Use run_in_executor to asynchronously run the sync method model.forward().
        logits = await loop.run_in_executor(
            None,
            self.model,
            batch,
            model_executable,
            input_token_ids,
            output_tensor,
            input_meta_data,
            self.k_cache,
            self.v_cache,
        )

        if self.inference_config.pad_input:
            logits = logits[:, -1, :]
        self.request_handler.search_tokens(self.generation_config, logits)
        next_tokens = search_tokens(
            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
        )

        self.request_handler.append_next_tokens(next_tokens)
        finished_sequences = self.request_handler.update()

        for sequence in finished_sequences:
            sequence.output = self.tokenizer.decode(sequence.output_token_id)

        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
        return finished_sequences, not self.request_handler.running_list.is_empty()
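Note: async_step keeps the event loop responsive by pushing the blocking forward pass into the default thread-pool executor via loop.run_in_executor, which is likely the fix for the "infer hanging bug" named in the commit message. A minimal self-contained sketch of the pattern, with a dummy blocking_forward standing in for the model call:

import asyncio
import time


def blocking_forward(batch_size: int) -> str:
    # Stand-in for the synchronous model forward pass; it blocks its thread.
    time.sleep(0.5)
    return f"logits for batch of {batch_size}"


async def async_step(batch_size: int) -> str:
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor, so the event loop stays free
    # to schedule other coroutines (e.g. incoming HTTP requests) meanwhile.
    return await loop.run_in_executor(None, blocking_forward, batch_size)


async def main() -> None:
    # Both steps run on executor threads concurrently: roughly 0.5s total, not 1s.
    print(await asyncio.gather(async_step(4), async_step(8)))


if __name__ == "__main__":
    asyncio.run(main())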
    def add_single_request(self, request_id: int, prompt: str, prompt_token_ids, generation_config=None):
        prompts = [prompt]
        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
        self.add_request(request_ids=request_id, prompts=prompts, prompts_token_ids=prompt_token_ids, **gen_config_dict)
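Note: add_single_request wraps one prompt into a batch of size one and flattens the per-request generation config into keyword arguments for the shared add_request path. A sketch of that flattening, assuming the Hugging Face transformers.GenerationConfig (which provides to_dict()); the fake_add_request consumer is illustrative only:

from transformers import GenerationConfig


def fake_add_request(request_ids=None, prompts=None, prompts_token_ids=None, **gen_kwargs):
    # Illustrative consumer: a real engine would merge gen_kwargs into its sampling settings.
    print(f"request {request_ids}: {prompts} with {len(gen_kwargs)} generation kwargs")


def add_single_request(request_id: int, prompt: str, generation_config=None):
    prompts = [prompt]  # batch of one
    gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
    fake_add_request(request_ids=request_id, prompts=prompts, prompts_token_ids=None, **gen_config_dict)


if __name__ == "__main__":
    cfg = GenerationConfig(max_new_tokens=32, do_sample=True, top_k=50)
    add_single_request(0, "Hello, world", generation_config=cfg)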

class AsyncInferenceEngine:
@@ -240,7 +259,6 @@ class AsyncInferenceEngine:
        for new_request in new_requests:
            self.engine.add_single_request(**new_request)
        newly_finished_seqs, has_running_requests = await self.engine.async_step()

        for seq in newly_finished_seqs:
            self._request_tracer.process_finished_request(seq)
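Note: this hunk is the background loop body: drain newly tracked requests into the engine, run one async_step, and hand every finished sequence back to its tracer. A minimal sketch of that loop shape with toy in-memory stand-ins for the tracer and engine (all names here are illustrative):

import asyncio


class ToyEngine:
    def __init__(self):
        self.pending = []

    def add_single_request(self, request_id: int, prompt: str):
        self.pending.append((request_id, prompt))

    async def async_step(self):
        # Pretend every pending request finishes after one step.
        finished, self.pending = self.pending, []
        return finished, bool(self.pending)


async def background_loop(engine: ToyEngine, new_requests: list, results: dict):
    while new_requests or engine.pending:
        # 1) move newly submitted requests into the engine
        while new_requests:
            request_id, prompt = new_requests.pop(0)
            engine.add_single_request(request_id=request_id, prompt=prompt)
        # 2) run one engine step without blocking other coroutines
        finished, has_running = await engine.async_step()
        # 3) route finished sequences back to their callers
        for request_id, prompt in finished:
            results[request_id] = f"output for {prompt!r}"
        if not has_running:
            await asyncio.sleep(0)  # yield control before checking for new work


if __name__ == "__main__":
    results = {}
    asyncio.run(background_loop(ToyEngine(), [(0, "hi"), (1, "bye")], results))
    print(results)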
@@ -273,6 +291,7 @@ class AsyncInferenceEngine:
        request_id: int,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        generation_config=None,
    ) -> RequstStream:
        """
        Add a request to the background tracker (waiting queue), start the background loop if needed.
@@ -286,6 +305,7 @@
            request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            generation_config=generation_config,
        )
        return stream
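Note: add_request registers the prompt with the request tracer and returns a RequstStream handle that the HTTP layer can await. The tracer internals are not shown in this diff; below is only a guessed minimal shape for such a per-request stream, built on asyncio.Future:

import asyncio


class ToyRequestStream:
    """Minimal per-request handle: the background loop sets the result, the caller awaits it."""

    def __init__(self, request_id: int):
        self.request_id = request_id
        self._future = asyncio.get_running_loop().create_future()

    def set_result(self, output: str) -> None:
        if not self._future.done():
            self._future.set_result(output)

    async def get_result(self) -> str:
        return await self._future


async def main() -> None:
    stream = ToyRequestStream(request_id=0)
    # Simulate the background loop finishing the request a bit later.
    asyncio.get_running_loop().call_later(0.1, stream.set_result, "generated text")
    print(await stream.get_result())


if __name__ == "__main__":
    asyncio.run(main())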
@@ -294,13 +314,16 @@
        request_id: int,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        generation_config=None,
    ) -> AsyncIterator[str]:
        """
        Generate output from a request. It receives the request from the http server, adds it into the
        waiting queue of the Async Engine, and streams the output sequence.
        """
        try:
            stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
            stream = await self.add_request(
                request_id, prompt, prompt_token_ids=prompt_token_ids, generation_config=generation_config
            )
            return await stream.get_result()

        except (Exception, asyncio.CancelledError) as e:
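Note: generate now forwards generation_config into add_request and awaits the stream; the except clause (cut off above) handles errors and client cancellation. A sketch of that try/await/cleanup shape; add_request and abort_request here are placeholders, not the engine's actual API:

import asyncio


async def generate(request_id: int, prompt: str, add_request, abort_request) -> str:
    # Submit the request, wait for its result, and clean up if the caller
    # errors out or is cancelled (e.g. the HTTP client disconnects).
    try:
        result_future = await add_request(request_id, prompt)
        return await result_future
    except (Exception, asyncio.CancelledError):
        abort_request(request_id)
        raise


async def main() -> None:
    async def add_request(request_id: int, prompt: str) -> asyncio.Future:
        fut = asyncio.get_running_loop().create_future()
        fut.set_result(f"echo: {prompt}")
        return fut

    def abort_request(request_id: int) -> None:
        print(f"aborted request {request_id}")

    print(await generate(0, "hello", add_request, abort_request))


if __name__ == "__main__":
    asyncio.run(main())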
@@ -154,7 +154,6 @@ class InferenceEngine:
            else:
                model_type = "nopadding_" + self.model_config.model_type
            model_policy = model_policy_map[model_type]()

        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
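Note: the engine resolves a shard policy by prefixing the model type with "nopadding_" and indexing into model_policy_map, then builds its process groups from the pp/tp sizes. The process-group part needs an initialized distributed environment, so the sketch below only shows the registry-lookup pattern, with toy policy classes standing in for the real per-architecture policies:

class NoPaddingLlamaPolicy:
    pass


class NoPaddingBaichuanPolicy:
    pass


# Illustrative registry; the real map lives in ColossalAI's inference package.
model_policy_map = {
    "nopadding_llama": NoPaddingLlamaPolicy,
    "nopadding_baichuan": NoPaddingBaichuanPolicy,
}


def resolve_policy(model_type: str):
    key = "nopadding_" + model_type  # e.g. "llama" -> "nopadding_llama"
    try:
        return model_policy_map[key]()  # instantiate, mirroring model_policy_map[model_type]()
    except KeyError as err:
        raise ValueError(f"no inference policy registered for model type {model_type!r}") from err


if __name__ == "__main__":
    print(type(resolve_policy("llama")).__name__)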
@@ -589,7 +588,7 @@ class InferenceEngine:
    def add_request(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: List[str] = None,
        prompts: Union[List[str], str] = None,
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
        **kwargs,
    ) -> None:
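Note: the widened add_request signature accepts a single request id or a list, a single prompt string or a list, and token ids as a Python list, torch.Tensor, or numpy array. A sketch of the kind of normalization such a signature implies (not the engine's actual implementation):

from typing import List, Union

import numpy as np
import torch


def normalize_request_ids(request_ids: Union[List[int], int, None]) -> List[int]:
    # A bare int becomes a batch of one id; None stays empty.
    if request_ids is None:
        return []
    return [request_ids] if isinstance(request_ids, int) else list(request_ids)


def normalize_prompts(prompts: Union[List[str], str, None]) -> List[str]:
    # A bare string becomes a batch of one prompt.
    if prompts is None:
        return []
    return [prompts] if isinstance(prompts, str) else list(prompts)


def normalize_token_ids(prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray, None]) -> list:
    # Accept tensors/arrays and hand back plain Python lists.
    if prompts_token_ids is None:
        return []
    if isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        return prompts_token_ids.tolist()
    return list(prompts_token_ids)


if __name__ == "__main__":
    print(normalize_request_ids(7))
    print(normalize_prompts("hello"))
    print(normalize_token_ids(torch.tensor([[1, 2, 3]])))
    print(normalize_token_ids(np.array([[4, 5, 6]])))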