[Inference] Fix bugs and docs for feat/online-server (#5598)

* fix test bugs

* add do sample test

* del useless lines

* fix comments

* fix tests

* delete version tag

* delete version tag

* add

* del test server

* fix test

* fix

* Revert "add"

This reverts commit b9305fb024.
Author: Jianghai
Date: 2024-05-08 15:14:06 +08:00
Committed by: CjhHa1
Parent: 7bbb28e48b
Commit: 61a1b2e798

12 changed files with 98 additions and 172 deletions


@@ -1,7 +1,7 @@
import asyncio
import logging
from functools import partial
from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
from colossalai.inference.core.engine import InferenceEngine
@@ -10,7 +10,7 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve
logger = logging.getLogger("colossalai-inference")
def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
msg = "Task finished unexpectedly. This should never happen! "
try:
try:
@@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac
class RequstStream:
"""A stream of Output for a request that can be
iterated over asynchronously."""
"""
A stream of Output for a request that can be iterated over asynchronously.
Attributes: 1.request_id: The id of the request.
2._future: A future that will be set when the request is finished.
Methods: set_result and get_result, results will be set when finished, for once, and
the `self.future` will be set to done.
"""
def __init__(self, request_id: int) -> None:
self.request_id = request_id
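For readers unfamiliar with the pattern the docstring above describes, here is a minimal, self-contained sketch of a future-backed request stream. The class and demo names (SimpleRequestStream, main) are illustrative, not part of the codebase; only the set_result / get_result behaviour mirrors the diff.

    import asyncio


    class SimpleRequestStream:
        # One asyncio.Future per request; the result may only be set once.
        def __init__(self, request_id: int) -> None:
            self.request_id = request_id
            self._future: asyncio.Future = asyncio.get_running_loop().create_future()

        def set_result(self, result) -> None:
            # Ignore repeated calls so the result is only ever set once.
            if not self._future.done():
                self._future.set_result(result)

        async def get_result(self):
            # Suspends the caller until the engine marks the request finished.
            return await self._future

        @property
        def finished(self) -> bool:
            return self._future.done()


    async def main() -> None:
        stream = SimpleRequestStream(request_id=0)
        # Pretend the engine finishes the request a bit later on the same loop.
        asyncio.get_running_loop().call_later(0.1, stream.set_result, "generated text")
        print(await stream.get_result())  # -> "generated text"


    asyncio.run(main())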
@@ -51,6 +57,10 @@ class RequstStream:
class Tracer:
"""
Recording new requests and finished requests.
Attributes:
    1. _request_streams: One stream is created per request to trace its output.
    2. _finished_requests: A queue to store the finished requests.
    3. _new_requests: New requests are stored in this queue first, before being sent to the engine.
    4. new_requests_event: An event to notify the engine that there are new requests.
"""
def __init__(self) -> None:
@@ -93,8 +103,8 @@ class Tracer:
raise KeyError(f"Request {request_id} already exists.")
stream = RequstStream(request_id)
logger.info(f"Added request {request_id}.")
self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
self.new_requests_event.set()
return stream
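A rough, self-contained sketch of the queue-and-event bookkeeping the Tracer docstring describes, including the abort-before-dispatch case handled later in this file. MiniTracer and its demo calls are hypothetical stand-ins; only the roles of the queues and the event follow the diff.

    import asyncio
    from typing import Dict, List, Set, Tuple


    class MiniTracer:
        def __init__(self) -> None:
            self._finished_requests: asyncio.Queue = asyncio.Queue()
            self._new_requests: asyncio.Queue = asyncio.Queue()
            self.new_requests_event = asyncio.Event()

        def add_request(self, request_id: int, **engine_add_request_kwargs) -> Dict:
            # Queue the request and wake up the engine loop.
            request = {"request_id": request_id, **engine_add_request_kwargs}
            self._new_requests.put_nowait(request)
            self.new_requests_event.set()
            return request

        def abort_request(self, request_id: int) -> None:
            self._finished_requests.put_nowait(request_id)

        def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[int]]:
            # Drain finished ids first so aborted requests are dropped below.
            finished: Set[int] = set()
            while not self._finished_requests.empty():
                finished.add(self._finished_requests.get_nowait())

            new_requests: List[Dict] = []
            while not self._new_requests.empty():
                request = self._new_requests.get_nowait()
                if request["request_id"] in finished:
                    continue  # aborted before the engine ever saw it
                new_requests.append(request)
            self.new_requests_event.clear()
            return new_requests, finished


    tracer = MiniTracer()
    tracer.add_request(0, prompt="hello")
    tracer.add_request(1, prompt="world")
    tracer.abort_request(1)
    print(tracer.get_new_and_finished_requests())  # ([{'request_id': 0, 'prompt': 'hello'}], {1})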
@@ -108,6 +118,7 @@ class Tracer:
if request_id not in self._request_streams or self._request_streams[request_id].finished:
# The request has already finished or been aborted.
# Requests still in new_requests will be aborted when we try to get them (if marked as aborted)
return
self._request_streams[request_id].set_result(None)
@@ -117,9 +128,18 @@ class Tracer:
Get new requests from the http server.
"""
new_requests: List[Dict] = []
finished_requests: Set[int] = set()
while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
finished_requests.add(request_id)
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if new_request["request_id"] in finished_requests:
# The request has been aborted.
stream.set_result(None)
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
@@ -133,7 +153,8 @@ class Tracer:
class _AsyncInferenceEngine(InferenceEngine):
"""
Async methods for Inference Engine.
Async methods for Inference Engine. This engine is an extension of InferenceEngine, and the additional methods are only used by the async wrapper (AsyncInferenceEngine).
Methods: 1. async_step: The async version of Engine.step()
"""
async def async_step(self) -> List[str]:
@@ -161,22 +182,23 @@ class _AsyncInferenceEngine(InferenceEngine):
if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)
# Return: List[Sequence]
finished_sequences = self.request_handler.update()
for sequence in finished_sequences:
sequence.output = self.tokenizer.decode(sequence.output_token_id)
return finished_sequences, self.request_handler.current_requests_in_batch() > 0
return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
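The diff does not show how the synchronous step is made awaitable, so the following is only a generic sketch of exposing a blocking step() as a coroutine via run_in_executor, assuming a trivial synchronous engine (BlockingEngine, AsyncEngineSketch are made-up names); it is not ColossalAI's actual implementation.

    import asyncio
    from typing import List, Tuple


    class BlockingEngine:
        # Stand-in for a synchronous engine with a blocking step().
        def __init__(self) -> None:
            self._pending = ["world", "hello"]

        def step(self) -> List[str]:
            # Pretend one decoding iteration finishes one sequence.
            return [self._pending.pop()] if self._pending else []


    class AsyncEngineSketch(BlockingEngine):
        async def async_step(self) -> Tuple[List[str], bool]:
            # Run the blocking step in the default executor so the event loop
            # stays responsive, and report whether any work is left.
            loop = asyncio.get_running_loop()
            finished = await loop.run_in_executor(None, self.step)
            return finished, bool(self._pending)


    async def main() -> None:
        engine = AsyncEngineSketch()
        has_work = True
        while has_work:
            finished, has_work = await engine.async_step()
            print(finished)  # ['hello'] then ['world']


    asyncio.run(main())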
class AsyncInferenceEngine:
"""An asynchronous wrapper for LLMEngine.
"""An asynchronous wrapper for the InferenceEngine class.
This class is used to wrap the InferenceEngine class to make it asynchronous.
It uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there are
requests in the waiting queue. The generate method yields the outputs
from the InferenceEngine to the caller.
requests. Note that this class does not hold the model directly; when a new request
comes in, `add_request` is called first and the Tracer records the request, then hands
it to the background `InferenceEngine` (driven by the background loop) for processing.
You can consider this engine as an interface.
"""
_engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
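To make the interface-plus-background-loop split described above concrete, here is a toy, self-contained example; every name in it (new_requests, results, background_loop, generate) is hypothetical. Callers only await a per-request future while a background task does the actual work.

    import asyncio
    from typing import Dict


    async def main() -> None:
        new_requests: asyncio.Queue = asyncio.Queue()
        results: Dict[int, asyncio.Future] = {}

        async def background_loop() -> None:
            # The front-end only records requests; this loop "runs" the engine.
            while True:
                request_id, prompt = await new_requests.get()
                await asyncio.sleep(0.05)  # stand-in for engine steps
                results[request_id].set_result(f"echo: {prompt}")

        async def generate(request_id: int, prompt: str) -> str:
            # Record the request, then wait for the background loop to finish it.
            results[request_id] = asyncio.get_running_loop().create_future()
            await new_requests.put((request_id, prompt))
            return await results[request_id]

        loop_task = asyncio.create_task(background_loop())
        print(await generate(0, "hello"))
        print(await generate(1, "world"))
        loop_task.cancel()


    asyncio.run(main())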
@@ -253,7 +275,7 @@ class AsyncInferenceEngine:
prompt_token_ids: Optional[List[int]] = None,
) -> RequstStream:
"""
Add a request to the background tracker(waitting queue), start the background loop if needed.
Add a request to the background tracker (waiting queue), and start the background loop if needed.
"""
if not self.background_loop_status:
if self.start_engine_loop:
@@ -276,14 +298,12 @@ class AsyncInferenceEngine:
"""
Generate output from a request. It receives the request from the http server, adds it into the
waiting queue of the Async Engine and streams the output sequence.
"""
try:
stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
return await stream.get_result()
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the
# request.
# If there is an exception or coroutine is cancelled, abort the request.
self._abort(request_id)
raise e
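A hypothetical FastAPI route showing how an HTTP server might call such a generate(); the DummyAsyncEngine, route path, and request model are illustrative stand-ins, not the project's actual server code, and only the generate(request_id, prompt) call shape follows the diff.

    import asyncio

    from fastapi import FastAPI
    from pydantic import BaseModel


    class DummyAsyncEngine:
        # Stand-in for AsyncInferenceEngine so the route runs without a model.
        async def generate(self, request_id: int, prompt: str) -> str:
            await asyncio.sleep(0.01)  # pretend to wait for the background loop
            return f"[{request_id}] {prompt} ..."


    app = FastAPI()
    engine = DummyAsyncEngine()


    class GenerateRequest(BaseModel):
        request_id: int
        prompt: str


    @app.post("/generate")
    async def generate(req: GenerateRequest) -> dict:
        # The handler only forwards to the engine and awaits the finished result;
        # aborting on error happens inside the engine's generate().
        output = await engine.generate(req.request_id, req.prompt)
        return {"output": output}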


@@ -527,10 +527,15 @@ class InferenceEngine:
List[str]: Inference result returned by one generation.
"""
with torch.inference_mode():
if isinstance(prompts, str) and isinstance(request_ids, int):
prompts = [prompts]
request_ids = [request_ids]
if prompts is not None or prompts_token_ids is not None:
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
@@ -612,6 +617,9 @@ class InferenceEngine:
block_size = self.inference_config.block_size
if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]
if prompts is not None and not isinstance(prompts, list):
prompts = [prompts]
@@ -621,9 +629,10 @@ class InferenceEngine:
"input_ids"
]
# list of torch Tensor
if isinstance(prompts_token_ids, list):
if isinstance(prompts_token_ids[0], torch.Tensor):
prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
prompts_token_ids = prompts_token_ids.tolist()
else:
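As a standalone illustration of the token-id normalization in the hunk above: whatever container arrives, the engine ends up with List[List[int]]. The helper name and the trailing error branch are assumptions, since the hunk cuts off at the final else.

    from typing import List, Union

    import numpy as np
    import torch


    def normalize_token_ids(
        prompts_token_ids: Union[List[torch.Tensor], List[List[int]], torch.Tensor, np.ndarray],
    ) -> List[List[int]]:
        # Mirror of the branching above: always return plain Python int lists.
        if isinstance(prompts_token_ids, list):
            if isinstance(prompts_token_ids[0], torch.Tensor):
                return [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
            return prompts_token_ids
        if isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
            return prompts_token_ids.tolist()
        raise TypeError(f"Unsupported prompts_token_ids type: {type(prompts_token_ids)}")


    print(normalize_token_ids(torch.tensor([[1, 2, 3]])))  # [[1, 2, 3]]
    print(normalize_token_ids([torch.tensor([4, 5])]))     # [[4, 5]]
    print(normalize_token_ids(np.array([[6, 7]])))         # [[6, 7]]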
@@ -738,8 +747,6 @@ class InferenceEngine:
logits = logits[:, -1, :]
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
self.request_handler.append_next_tokens(next_tokens)
self.request_handler.search_tokens(self.generation_config, logits)
finished_sequences = self.request_handler.update()
return finished_sequences


@@ -328,7 +328,7 @@ class RequestHandler:
def check_unfinished_seqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()
def current_requests_in_batch(self) -> int:
def total_requests_in_batch_bucket(self) -> int:
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
def search_tokens(self, generation_config: GenerationConfig, logits):