Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 01:28:31 +00:00
[Inference] Fix bugs and docs for feat/online-server (#5598)
* fix test bugs
* add do sample test
* del useless lines
* fix comments
* fix tests
* delete version tag
* delete version tag
* add
* del test server
* fix test
* fix
* Revert "add"
This reverts commit b9305fb024.
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from functools import partial
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
 
 from colossalai.inference.core.engine import InferenceEngine
 
@@ -10,7 +10,7 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve
 logger = logging.getLogger("colossalai-inference")
 
 
-def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
+def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
     msg = "Task finished unexpectedly. This should never happen! "
     try:
         try:
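The `_raise_exception_on_finish` callback above follows the usual asyncio pattern: it is attached to the background task with `add_done_callback`, so an exception raised inside the loop surfaces loudly instead of being silently swallowed. A minimal, self-contained sketch of that pattern, with illustrative names rather than the repository's exact code:

```python
import asyncio


def raise_on_finish(task: asyncio.Task) -> None:
    # Re-raise the task's stored exception so a dying background loop is
    # reported loudly (asyncio logs exceptions raised from done-callbacks).
    try:
        task.result()
    except asyncio.CancelledError:
        pass  # normal shutdown
    except Exception as exc:
        raise RuntimeError("Background loop finished unexpectedly.") from exc


async def background_loop() -> None:
    await asyncio.sleep(0.1)
    raise ValueError("simulated engine failure")


async def main() -> None:
    task = asyncio.create_task(background_loop())
    task.add_done_callback(raise_on_finish)
    await asyncio.sleep(0.3)


asyncio.run(main())
```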
@@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac
 
 
 class RequstStream:
-    """A stream of Output for a request that can be
-    iterated over asynchronously."""
+    """
+    A stream of Output for a request that can be iterated over asynchronously.
+    Attributes: 1.request_id: The id of the request.
+                2._future: A future that will be set when the request is finished.
+    Methods: set_result and get_result, results will be set when finished, for once, and
+             the `self.future` will be set to done.
+
+    """
 
     def __init__(self, request_id: int) -> None:
         self.request_id = request_id
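The expanded docstring describes a one-shot, future-backed result holder: `set_result` is called once when the request finishes, and `get_result` awaits that future. A minimal sketch of a class with that behaviour, written as an assumption-laden illustration rather than the repository's implementation:

```python
import asyncio
from typing import Any


class OneShotStream:
    """Holds the single final output of one request (hypothetical stand-in)."""

    def __init__(self, request_id: int) -> None:
        self.request_id = request_id
        self._future: asyncio.Future = asyncio.get_running_loop().create_future()

    def set_result(self, result: Any) -> None:
        # Publish the final output exactly once; later calls are ignored.
        if not self._future.done():
            self._future.set_result(result)

    async def get_result(self) -> Any:
        # Suspend the caller until the engine publishes the result.
        return await self._future

    @property
    def finished(self) -> bool:
        return self._future.done()


async def demo() -> None:
    stream = OneShotStream(request_id=0)
    asyncio.get_running_loop().call_later(0.1, stream.set_result, "generated text")
    print(await stream.get_result())


asyncio.run(demo())
```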
@@ -51,6 +57,10 @@ class RequstStream:
 class Tracer:
     """
     Recording new requests and finished requests.
+    Attributes: 1._request_streams: We create one stream for each request to trace the output.
+                2._finished_requests: A queue to store the finished requests.
+                3._new_requests: New requests will be stored in this queue first, before sending them to the engine.
+                4.new_requests_event: An event to notify the engine that there are new requests.
     """
 
     def __init__(self) -> None:
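The four attributes listed in the docstring map naturally onto a dict of per-request streams, two `asyncio.Queue`s, and an `asyncio.Event`. The sketch below shows how such a tracer could be initialized and how adding a request might wake the engine; names and structure are assumptions for illustration, not the exact repository code:

```python
import asyncio
from typing import Dict


class SimpleTracer:
    def __init__(self) -> None:
        # One stream (future) per in-flight request, keyed by request id.
        self._request_streams: Dict[int, asyncio.Future] = {}
        # Ids of requests that finished or were aborted.
        self._finished_requests: asyncio.Queue = asyncio.Queue()
        # Newly added requests wait here until the engine drains them.
        self._new_requests: asyncio.Queue = asyncio.Queue()
        # Wakes the background loop when there is new work.
        self.new_requests_event = asyncio.Event()

    def add_request(self, request_id: int, **engine_kwargs) -> asyncio.Future:
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")
        stream = asyncio.get_running_loop().create_future()
        # The stream is registered in _request_streams later, when the engine
        # actually pulls the request out of the queue.
        self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_kwargs}))
        self.new_requests_event.set()
        return stream
```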
@@ -93,8 +103,8 @@ class Tracer:
             raise KeyError(f"Request {request_id} already exists.")
 
         stream = RequstStream(request_id)
+        logger.info(f"Added request {request_id}.")
         self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
-
         self.new_requests_event.set()
 
         return stream
@@ -108,6 +118,7 @@ class Tracer:
 
         if request_id not in self._request_streams or self._request_streams[request_id].finished:
             # The request has already finished or been aborted.
+            # The requests in new_requests will be aborted when try to get them(if marked aborted)
             return
 
         self._request_streams[request_id].set_result(None)
@@ -117,9 +128,18 @@ class Tracer:
         Get new requests from http server.
         """
         new_requests: List[Dict] = []
+        finished_requests: Set[int] = set()
 
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)
+
         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
+            if new_request["request_id"] in finished_requests:
+                # The request has been aborted.
+                stream.set_result(None)
+                continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
 
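The new logic drains the finished-request queue first and uses that set to filter out aborted entries while draining the new-request queue. A condensed, hypothetical sketch of that drain-and-filter pattern:

```python
import asyncio
from typing import Dict, List, Set, Tuple


def drain_new_requests(
    new_q: "asyncio.Queue[Tuple[asyncio.Future, Dict]]",
    finished_q: "asyncio.Queue[int]",
    request_streams: Dict[int, asyncio.Future],
) -> List[Dict]:
    """Return requests that should be scheduled, skipping ones already aborted."""
    finished: Set[int] = set()
    while not finished_q.empty():
        finished.add(finished_q.get_nowait())

    new_requests: List[Dict] = []
    while not new_q.empty():
        stream, request = new_q.get_nowait()
        if request["request_id"] in finished:
            # Aborted before the engine ever saw it: resolve its stream and drop it.
            if not stream.done():
                stream.set_result(None)
            continue
        request_streams[request["request_id"]] = stream
        new_requests.append(request)
    return new_requests
```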
@@ -133,7 +153,8 @@ class Tracer:
 
 class _AsyncInferenceEngine(InferenceEngine):
     """
-    Async methods for Inference Engine.
+    Async methods for Inference Engine. This engine is an extension for InferenceEngine, and the additional methods will only be used for
+    Methods: 1. async_step: The async version of Engine.step()
     """
 
     async def async_step(self) -> List[str]:
@@ -161,22 +182,23 @@ class _AsyncInferenceEngine(InferenceEngine):
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
-
+        # Return: List[Sequence]
         finished_sequences = self.request_handler.update()
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)
 
-        return finished_sequences, self.request_handler.current_requests_in_batch() > 0
+        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
 
 
 class AsyncInferenceEngine:
-    """An asynchronous wrapper for LLMEngine.
+    """An asynchronous wrapper for the InferenceEngine class.
 
     This class is used to wrap the InferenceEngine class to make it asynchronous.
     It uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMEngine is kicked by the generate method when there are
-    requests in the waiting queue. The generate method yields the outputs
-    from the InferenceEngine to the caller.
+    requests. Note that this class does not hold model directly, when incoming a new
+    request, it first called `add_request` and the Tracer will record the request, putting
+    it to the background `InferenceEngine`(done in background loop) to process. You can
+    consider this engine as an interface.
     """
 
     _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
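The rewritten docstring frames the wrapper as an interface: `add_request` hands work to the tracer, while a background asyncio loop feeds tracked requests into the real engine and keeps stepping it until the batch drains. A rough sketch of that loop shape, with all names being illustrative assumptions:

```python
import asyncio
from typing import Awaitable, Callable


async def run_engine_loop(
    new_requests_event: asyncio.Event,
    engine_step: Callable[[], Awaitable[bool]],
) -> None:
    """Keep stepping the engine while it reports unfinished work."""
    while True:
        # Sleep until the tracer signals that at least one request arrived.
        await new_requests_event.wait()
        new_requests_event.clear()

        has_running = True
        while has_running:
            # engine_step returns True while sequences remain in the batch,
            # mirroring the `total_requests_in_batch_bucket() > 0` flag above.
            has_running = await engine_step()
            await asyncio.sleep(0)  # yield so new requests can be enqueued
```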
@@ -253,7 +275,7 @@ class AsyncInferenceEngine:
         prompt_token_ids: Optional[List[int]] = None,
     ) -> RequstStream:
         """
-        Add a request to the background tracker(waitting queue), start the background loop if needed.
+        Add a request to the background tracker(waiting queue), start the background loop if needed.
         """
         if not self.background_loop_status:
             if self.start_engine_loop:
@@ -276,14 +298,12 @@ class AsyncInferenceEngine:
         """
         Generate output from a request. It receives the request from http server, adds it into the
         waitting queue of Async Engine and streams the output sequence.
-
         """
         try:
             stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
             return await stream.get_result()
 
         except (Exception, asyncio.CancelledError) as e:
-            # If there is an exception or coroutine is cancelled, abort the
-            # request.
+            # If there is an exception or coroutine is cancelled, abort the request.
             self._abort(request_id)
             raise e
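With that change, `generate` resolves to the final result of one request and aborts the request if the awaiting coroutine fails or is cancelled, for example when an HTTP client disconnects. A hedged usage sketch with a stub engine standing in for the real one:

```python
import asyncio


class StubAsyncEngine:
    """Stand-in exposing the generate/abort surface discussed in this diff."""

    async def generate(self, request_id: int, prompt: str) -> str:
        try:
            await asyncio.sleep(0.1)  # pretend the background loop is decoding
            return f"[{request_id}] completion for: {prompt}"
        except (Exception, asyncio.CancelledError):
            self._abort(request_id)  # mirror the abort-on-cancel behaviour
            raise

    def _abort(self, request_id: int) -> None:
        print(f"aborted request {request_id}")


async def main() -> None:
    engine = StubAsyncEngine()
    print(await engine.generate(0, "What is tensor parallelism?"))


asyncio.run(main())
```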
@@ -527,10 +527,15 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
+<<<<<<< HEAD
 
             if isinstance(prompts, str) and isinstance(request_ids, int):
                 prompts = [prompts]
                 request_ids = [request_ids]
+=======
+            if prompts is not None or prompts_token_ids is not None:
+                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
+>>>>>>> [Inference] Fix bugs and docs for feat/online-server (#5598)
 
             if prompts is not None or prompts_token_ids is not None:
                 gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
@@ -612,6 +617,9 @@ class InferenceEngine:
 
         block_size = self.inference_config.block_size
 
+        if request_ids is not None and not isinstance(request_ids, list):
+            request_ids = [request_ids]
+
         if prompts is not None and not isinstance(prompts, list):
             prompts = [prompts]
 
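The added lines make the single-request path symmetric with the single-prompt path: a bare request id is wrapped into a list just like a bare prompt string. The tiny sketch below shows that normalization step in isolation (hypothetical helper, not a repository function):

```python
from typing import List, Optional, Tuple, Union


def normalize_inputs(
    request_ids: Optional[Union[int, List[int]]],
    prompts: Optional[Union[str, List[str]]],
) -> Tuple[Optional[List[int]], Optional[List[str]]]:
    # Wrap scalars so downstream batching code only ever sees lists.
    if request_ids is not None and not isinstance(request_ids, list):
        request_ids = [request_ids]
    if prompts is not None and not isinstance(prompts, list):
        prompts = [prompts]
    return request_ids, prompts


print(normalize_inputs(7, "hello"))    # ([7], ['hello'])
print(normalize_inputs([1, 2], None))  # ([1, 2], None)
```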
@@ -621,9 +629,10 @@ class InferenceEngine:
                 "input_ids"
             ]
 
+        # list of torch Tensor
         if isinstance(prompts_token_ids, list):
             if isinstance(prompts_token_ids[0], torch.Tensor):
-                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:
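The corrected comprehension converts each per-prompt tensor to a plain Python list, while a single tensor or NumPy array is converted with `.tolist()` directly. A small sketch of those conversion rules; the `else` branch raising a TypeError is an assumption, since the hunk does not show it:

```python
import numpy as np
import torch


def to_token_id_lists(prompts_token_ids):
    # list of torch Tensors -> list of lists of ints
    if isinstance(prompts_token_ids, list):
        if isinstance(prompts_token_ids[0], torch.Tensor):
            prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
    # a single Tensor / ndarray -> (nested) lists via .tolist()
    elif isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        prompts_token_ids = prompts_token_ids.tolist()
    else:
        raise TypeError(f"Unsupported prompts_token_ids type: {type(prompts_token_ids)}")
    return prompts_token_ids


print(to_token_id_lists([torch.tensor([1, 2, 3]), torch.tensor([4, 5])]))  # [[1, 2, 3], [4, 5]]
print(to_token_id_lists(np.array([[1, 2], [3, 4]])))                       # [[1, 2], [3, 4]]
```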
@@ -738,8 +747,6 @@ class InferenceEngine:
             logits = logits[:, -1, :]
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
-        self.request_handler.append_next_tokens(next_tokens)
 
         self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()
 
         return finished_sequences
@@ -328,7 +328,7 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()
 
-    def current_requests_in_batch(self) -> int:
+    def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
 
     def search_tokens(self, generation_config: GenerationConfig, logits):
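The rename makes the intent explicit: the method reports the total number of sequences currently held across the prefill and running batch buckets, which the async engine uses to decide whether another step is needed. A toy sketch of that bookkeeping; the bucket class below is a stand-in, not the repository's BatchBucket:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyBatchBucket:
    sequences: List[int] = field(default_factory=list)

    @property
    def current_batch_size(self) -> int:
        return len(self.sequences)


@dataclass
class ToyRequestHandler:
    prefill_bb: ToyBatchBucket = field(default_factory=ToyBatchBucket)
    running_bb: ToyBatchBucket = field(default_factory=ToyBatchBucket)

    def total_requests_in_batch_bucket(self) -> int:
        # Work remains as long as either bucket still holds sequences.
        return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size


handler = ToyRequestHandler()
handler.prefill_bb.sequences = [0, 1]
handler.running_bb.sequences = [2]
print(handler.total_requests_in_batch_bucket())  # 3
```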