From 1572af24329db8a26b979d9d5edced62ada1ee2b Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 1 Mar 2024 14:47:36 +0800 Subject: [PATCH] [Inference] ADD async and sync Api server using FastAPI (#5396) * add api server * fix * add * add completion service and fix bug * add generation config * revise shardformer * fix bugs * add docstrings and fix some bugs * fix bugs and add choices for prompt template --- colossalai/inference/batch_bucket.py | 3 + colossalai/inference/config.py | 19 +- colossalai/inference/core/async_engine.py | 318 ++++++++++++++++++ colossalai/inference/core/engine.py | 34 +- colossalai/inference/core/request_handler.py | 34 +- colossalai/inference/server/__init__.py | 0 colossalai/inference/server/api_server.py | 200 +++++++++++ .../inference/server/completion_service.py | 35 ++ colossalai/inference/server/utils.py | 16 + colossalai/inference/struct.py | 1 + colossalai/shardformer/shard/shardformer.py | 7 +- .../test_async_engine/test_async_engine.py | 80 +++++ .../test_async_engine/test_request_tracker.py | 77 +++++ 13 files changed, 796 insertions(+), 28 deletions(-) create mode 100644 colossalai/inference/core/async_engine.py create mode 100644 colossalai/inference/server/__init__.py create mode 100644 colossalai/inference/server/api_server.py create mode 100644 colossalai/inference/server/completion_service.py create mode 100644 colossalai/inference/server/utils.py create mode 100644 tests/test_infer/test_async_engine/test_async_engine.py create mode 100644 tests/test_infer/test_async_engine/test_request_tracker.py diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index a2a2e74e8..132853770 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -62,6 +62,9 @@ class BatchBucket: def current_batch_size(self): return self._current_batch_size + def __len__(self): + return self._current_batch_size + @property def available_batch_size(self): return self.max_batch_size - self._current_batch_size diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 9d7c2c0ad..9266e6927 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -1,10 +1,10 @@ """ Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference. """ - +import dataclasses import logging from dataclasses import dataclass -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import torch import torch.distributed as dist @@ -211,3 +211,18 @@ class InferenceConfig: meta_config[type] = getattr(model_config, type) return GenerationConfig.from_dict(meta_config) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig": + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + inference_config_args = {} + for attr in attrs: + if attr in config_dict: + inference_config_args[attr] = config_dict[attr] + else: + inference_config_args[attr] = getattr(cls, attr) + + # Set the attributes from the parsed arguments. 
+ inference_config = cls(**inference_config_args) + return inference_config diff --git a/colossalai/inference/core/async_engine.py b/colossalai/inference/core/async_engine.py new file mode 100644 index 000000000..5be36fada --- /dev/null +++ b/colossalai/inference/core/async_engine.py @@ -0,0 +1,318 @@ +import asyncio +from functools import partial +from logging import Logger +from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type + +from colossalai.inference.core.engine import InferenceEngine + + +class AsyncEngineDeadError(RuntimeError): + pass + + +def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None: + msg = "Task finished unexpectedly. This should never happen! " + try: + try: + task.result() + except asyncio.CancelledError: + return + except Exception as exc: + raise AsyncEngineDeadError(msg + " See stack trace above for the actual cause.") from exc + raise AsyncEngineDeadError(msg) + except Exception as exc: + request_tracker.propagate_exception(exc) + raise exc + + +class AsyncStream: + """A stream of Output for a request that can be + iterated over asynchronously.""" + + def __init__(self, request_id: str) -> None: + self.request_id = request_id + self._queue = asyncio.Queue() + self._finished = False + + def put(self, item) -> None: + if self._finished: + return + self._queue.put_nowait(item) + + def finish(self) -> None: + self._queue.put_nowait(StopIteration) + self._finished = True + + @property + def finished(self) -> bool: + return self._finished + + def __aiter__(self): + return self + + async def __anext__(self): + result = await self._queue.get() + if result is StopIteration: + raise StopAsyncIteration + elif isinstance(result, Exception): + raise result + return result + + +class RequestTracker: + """Synchronous abstraction for tracking requests.""" + + def __init__(self) -> None: + self._request_streams: Dict[str, AsyncStream] = {} + self._finished_requests: asyncio.Queue[int] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue() + self.new_requests_event = None + + def __contains__(self, item): + return item in self._request_streams + + def init_event(self): + self.new_requests_event = asyncio.Event() + + def propagate_exception(self, exc: Exception, request_id: Optional[int] = None) -> None: + """ + Propagate an exception to request streams (all if request_id is None). + """ + if request_id is not None: + self._request_streams[request_id].put(exc) + else: + for stream in self._request_streams.values(): + stream.put(exc) + + def process_finished_request(self, finished_request) -> None: + """Process a finished request from the engine.""" + request_id = finished_request.request_id + + self._request_streams[request_id].put(finished_request) + self.abort_request(request_id) + + def add_request(self, request_id: int, **engine_add_request_kwargs) -> AsyncStream: + """ + Add a request to be sent to the engine on the next background + loop iteration. 
+        """
+        if request_id in self._request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
+
+        stream = AsyncStream(request_id)
+        self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
+
+        self.new_requests_event.set()
+
+        return stream
+
+    def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
+        """Abort a request during the next background loop iteration."""
+        if verbose:
+            Logger(__name__).info(f"Aborted request {request_id}.")
+
+        self._finished_requests.put_nowait(request_id)
+
+        if request_id not in self._request_streams or self._request_streams[request_id].finished:
+            # The request has already finished or been aborted.
+            return
+
+        self._request_streams[request_id].finish()
+
+    def get_new_requests(self):
+        """
+        Get new requests from the HTTP server.
+        """
+        new_requests: List[Dict] = []
+
+        while not self._new_requests.empty():
+            stream, new_request = self._new_requests.get_nowait()
+            self._request_streams[stream.request_id] = stream
+            new_requests.append(new_request)
+
+        self.new_requests_event.clear()
+
+        return new_requests
+
+    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[int]]:
+        """Get the new requests and finished requests to be
+        sent to the engine."""
+        new_requests: List[Dict] = []
+        finished_requests: Set[int] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)
+            self._request_streams.pop(request_id, None)
+
+        while not self._new_requests.empty():
+            stream, new_request = self._new_requests.get_nowait()
+            if stream.request_id in finished_requests:
+                # The request has already been aborted.
+                stream.finish()
+                continue
+            self._request_streams[stream.request_id] = stream
+            new_requests.append(new_request)
+
+        self.new_requests_event.clear()
+
+        return new_requests, finished_requests
+
+    async def wait_for_new_requests(self):
+        await self.new_requests_event.wait()
+
+
+class _AsyncInferenceEngine(InferenceEngine):
+    """
+    Async methods for Inference Engine.
+    """
+
+    async def async_step(self) -> List[str]:
+        """
+        The async version of Engine.step().
+        Performs one decoding iteration and returns newly generated results.
+
+        It first schedules the sequences to be executed in the next iteration.
+        Then, it executes the model and updates the scheduler with the model
+        outputs. Finally, it decodes the sequences and returns the newly
+        generated results.
+        """
+        batch = self.request_handler.schedule()
+        loop = asyncio.get_running_loop()
+
+        # Use run_in_executor to asynchronously run the sync method model.forward().
+        logits = await loop.run_in_executor(
+            None,
+            self.model,
+            batch,
+            self.k_cache,
+            self.v_cache,
+        )
+
+        if self.inference_config.pad_input:
+            logits = logits[:, -1, :]
+        self.request_handler.search_tokens(self.generation_config, logits)
+        # Return: List[Sequence]
+        finished_sequences = self.request_handler.update()
+
+        return finished_sequences, self.request_handler.current_requests_in_batch() > 0
+
+
+class AsyncInferenceEngine:
+    """An asynchronous wrapper for InferenceEngine.
+
+    This class wraps the InferenceEngine class to make it asynchronous.
+    It uses asyncio to create a background loop that keeps processing incoming
+    requests. The InferenceEngine is driven by the generate method when there are
+    requests in the waiting queue. The generate method yields the outputs
+    from the InferenceEngine to the caller.
+    """
+
+    _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
+
+    def __init__(self, start_engine_loop: bool = True, **kwargs):
+        self.engine = self._init_engine(**kwargs)
+        self.background_loop = None
+        # reference to the unshielded loop
+        self._background_loop_unshielded = None
+        self.start_engine_loop = start_engine_loop
+        self._request_tracker = RequestTracker()
+
+    @property
+    def background_loop_status(self):
+        return self.background_loop is not None and not self.background_loop.done()
+
+    def start_background_loop(self):
+        if self.background_loop_status:
+            raise RuntimeError("Existing loop is running")
+
+        self._request_tracker.init_event()
+
+        self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
+        self._background_loop_unshielded.add_done_callback(
+            partial(_raise_exception_on_finish, request_tracker=self._request_tracker)
+        )
+        self.background_loop = asyncio.shield(self._background_loop_unshielded)
+
+    def _init_engine(self, **kwargs):
+        return self._engine_class(**kwargs)
+
+    async def step(self):
+        """
+        Run the engine to process requests.
+
+        Returns True if there are in-progress requests.
+        """
+        new_requests = self._request_tracker.get_new_requests()
+        for new_request in new_requests:
+            self.engine.add_single_request(**new_request)
+        newly_finished_seqs, has_running_requests = await self.engine.async_step()
+        for seq in newly_finished_seqs:
+            self._request_tracker.process_finished_request(seq)
+
+        return has_running_requests
+
+    async def _engine_abort(self, request_ids: Iterable[int]):
+        self.engine.abort_request(request_ids)
+
+    async def abort(self, request_id: int):
+        """
+        Abort a single request.
+        """
+        if not self.background_loop_status:
+            raise RuntimeError("Background loop is not running or was not launched correctly.")
+        return self._abort(request_id)
+
+    def _abort(self, request_id: int):
+        self._request_tracker.abort_request(request_id)
+
+    async def run_engine_loop(self):
+        processing_requests = False
+        while True:
+            if not processing_requests:
+                await self._request_tracker.wait_for_new_requests()
+            processing_requests = await self.step()
+            await asyncio.sleep(0)
+
+    async def add_request(
+        self,
+        request_id: int,
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]] = None,
+    ) -> AsyncStream:
+        """
+        Add a request to the background tracker (waiting queue), and start the background loop if needed.
+        """
+        if not self.background_loop_status:
+            if self.start_engine_loop:
+                self.start_background_loop()
+            else:
+                raise RuntimeError("Background loop is not running.")
+        stream = self._request_tracker.add_request(
+            request_id,
+            prompt=prompt,
+            prompt_token_ids=prompt_token_ids,
+        )
+        return stream
+
+    async def generate(
+        self,
+        request_id: int,
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]] = None,
+    ) -> AsyncIterator[str]:
+        """
+        Generate output for a request. It receives the request from the HTTP server, adds it to the
+        waiting queue of the async engine, and streams the output sequence.
+
+        """
+        try:
+            stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
+            async for request_output in stream:
+                yield request_output
+
+        except (Exception, asyncio.CancelledError) as e:
+            # If there is an exception or the coroutine is cancelled, abort the
+            # request.
+ self._abort(request_id) + raise e diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index f6b5a6e79..c4b1f0165 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -1,6 +1,6 @@ import time from itertools import count -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, Iterable import numpy as np import torch @@ -82,13 +82,18 @@ class InferenceEngine: model_type = "nopadding_" + self.model_config.model_type model_policy = model_policy_map[model_type]() - pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size) + world_size = self.inference_config.tp_size * self.inference_config.pp_size + + if world_size > 1: + pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size) + else: + pg_mesh = None self.model = self._shardformer( model, model_policy, None, - pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None, + pg_mesh.get_group_along_axis(TP_AXIS) if pg_mesh else None, ) self.verbose = verbose @@ -425,9 +430,9 @@ class InferenceEngine: def generate( self, - prompts: List[str] = None, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, - request_ids: List[int] = None, return_token_ids: bool = False, generation_config: Optional[GenerationConfig] = None, ) -> List[str]: @@ -445,15 +450,19 @@ class InferenceEngine: List[str]: Inference result returned by one generation. """ with torch.inference_mode(): + if generation_config is not None: + self.generation_config = generation_config + if prompts is not None or prompts_token_ids is not None: + if isinstance(prompts, str) and isinstance(request_ids, int): + prompts = [prompts] + request_ids = [request_ids] self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids) output_seqs_list = [] total_tokens_list = [] # intuition: If user provide a generation config, we should replace the existing one. - if generation_config is not None: - self.generation_config = generation_config if self.use_spec_dec: assert self.drafter is not None, "Drafter Model is not initialized." 
@@ -492,13 +501,13 @@ class InferenceEngine: if isinstance(prompts, (list, tuple)): return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts] elif isinstance(prompts, str): - return self.inference_config.rompt_template.format(input_text=prompts) + return self.inference_config.prompt_template.format(input_text=prompts) else: raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.") def add_request( self, - request_ids: List[int] = None, + request_ids: Union[List[int], int] = None, prompts: List[str] = None, prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, ) -> None: @@ -512,6 +521,7 @@ class InferenceEngine: """ # apply the prompt template to the input prompts + if self.has_prompt_template and prompts is not None: prompts = self.format_prompt(prompts) @@ -525,6 +535,7 @@ class InferenceEngine: prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ "input_ids" ] + print(prompts_token_ids) if isinstance(prompts_token_ids, list): pass @@ -543,8 +554,6 @@ class InferenceEngine: for i in range(prompts_num): if request_ids: - if not isinstance(request_ids, list): - request_ids = [request_ids] assert isinstance( request_ids[0], int ), f"The request_id type must be int, but got {type(request_ids[0])}" @@ -637,6 +646,9 @@ class InferenceEngine: next_tokens = self.request_handler.search_tokens(self.generation_config, logits) self.request_handler.append_next_tokens(next_tokens) + print("in step", logits) + + self.request_handler.search_tokens(self.generation_config, logits) finished_sequences = self.request_handler.update() return finished_sequences diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 327a7e9ce..f5c06b39b 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -263,24 +263,27 @@ class RequestHandler: ), f"Sequence {req.request_id} exceeds input length limit" self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req) - def abort_sequence(self, request_id: str): + def abort_sequence(self, request_id: int): """ Abort the request. """ - seq, priority = self._find_sequence(request_id) - if seq.status == RequestStatus.WAITING: - seq.mark_aborted() - self.waiting_list[priority].remove(seq) - elif seq.status.is_running(): - self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table) - self.running_list.remove(seq) - else: - try: - self.done_list.remove(seq) - except: - return + result = self._find_sequence(request_id) + if result is not None: + seq, priority = result + if seq.status == RequestStatus.WAITING: + seq.mark_aborted() + self.waiting_list[priority].remove(seq) + elif seq.status.is_running(): + self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table) + self.running_list.remove(seq) + else: + try: + self.done_list.remove(seq) + except: + return + return - def _find_sequence(self, request_id: str) -> Sequence: + def _find_sequence(self, request_id: int) -> Sequence: """ Find the request by request_id. 
""" @@ -323,6 +326,9 @@ class RequestHandler: def check_unfinished_seqs(self) -> bool: return self._has_waiting() or not self.running_list.is_empty() + def current_requests_in_batch(self) -> int: + return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size + def search_tokens(self, generation_config: GenerationConfig, logits): """ Sample tokens for finished requests. diff --git a/colossalai/inference/server/__init__.py b/colossalai/inference/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/inference/server/api_server.py b/colossalai/inference/server/api_server.py new file mode 100644 index 000000000..c182c5160 --- /dev/null +++ b/colossalai/inference/server/api_server.py @@ -0,0 +1,200 @@ +""" +Doc: + Feature: + - FastAPI based http server for Colossal-Inference + - Completion Service Supported + Usage: (for local user) + - First, Lauch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model` + - Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api + - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \ + -H 'Content-Type: application/json' \ + -d '{"prompt":"hello, who are you? ","stream":"False"}'` +""" + + +import argparse +import json + +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse +from transformers import AutoModelForCausalLM, AutoTokenizer + +from colossalai.inference.config import InferenceConfig +from colossalai.inference.server.completion_service import CompletionServing +from colossalai.inference.server.utils import id_generator + +from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +app = FastAPI() +engine = None +supported_models_dict = {"Llama_Models": ("llama2-7b",)} +prompt_template_choices = ["llama", "vicuna"] + + +@app.get("/v0/models") +def get_available_models() -> Response: + return JSONResponse(supported_models_dict) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + A request should be a JSON object with the following fields: + - prompts: the prompts to use for the generation. + - stream: whether to stream the results or not. + - other fields: + """ + request_dict = await request.json() + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", None) + + request_id = id_generator() + generation_config = get_generation_config(request_dict) + results = engine.generate(request_id, prompt, generation_config=generation_config) + + # Streaming case + def stream_results(): + for request_output in results: + ret = {"text": request_output} + yield (json.dumps(ret) + "\0").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + for request_output in results: + if request.is_disconnected(): + # Abort the request if the client disconnects. 
+ engine.abort(request_id) + return Response(status_code=499) + final_output = request_output + + assert final_output is not None + ret = {"text": final_output} + return JSONResponse(ret) + + +@app.post("/v1/completion") +async def create_completion(request: Request): + request_dict = await request.json() + generation_config = get_generation_config(request_dict) + generator = await completion_serving.create_completion(request, generation_config) + output = tokenizer.decode(generator.output_token_id) + ret = {"request_id": generator.request_id, "text": output} + return ret + + +def get_generation_config(request): + generation_config = async_engine.engine.generation_config + for arg in request: + if hasattr(generation_config, arg): + generation_config[arg] = request[arg] + return generation_config + + +def add_engine_config(parser): + parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use") + + parser.add_argument( + "--max-model-len", + type=int, + default=None, + help="model context length. If unspecified, " "will be automatically derived from the model.", + ) + # Parallel arguments + parser.add_argument( + "--worker-use-ray", + action="store_true", + help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU", + ) + + parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages") + + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas") + + # KV cache arguments + parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size") + + parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size") + + # generation arguments + parser.add_argument( + "--prompt_template", + choices=prompt_template_choices, + default=None, + help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.", + ) + + # Quantization settings. + parser.add_argument( + "--quantization", + "-q", + type=str, + choices=["awq", "gptq", "squeezellm", None], + default=None, + help="Method used to quantize the weights. If " + "None, we first check the `quantization_config` " + "attribute in the model config file. If that is " + "None, we assume the model weights are not " + "quantized and use `dtype` to determine the data " + "type of the weights.", + ) + parser.add_argument( + "--enforce-eager", + action="store_true", + help="Always use eager-mode PyTorch. If False, " + "will use eager mode and CUDA graph in hybrid " + "for maximal performance and flexibility.", + ) + return parser + + +def parse_args(): + parser = argparse.ArgumentParser(description="Colossal-Inference API server.") + + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument( + "--root-path", type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy" + ) + parser.add_argument( + "--model-name", + type=str, + default=None, + help="The model name used in the API. 
If not " + "specified, the model name will be the same as " + "the huggingface name.", + ) + parser = add_engine_config(parser) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + inference_config = InferenceConfig.from_dict(vars(args)) + model = AutoModelForCausalLM.from_pretrained(args.model) + tokenizer = AutoTokenizer.from_pretrained(args.model) + async_engine = AsyncInferenceEngine( + start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config + ) + engine = async_engine.engine + completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__) + + app.root_path = args.root_path + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ) diff --git a/colossalai/inference/server/completion_service.py b/colossalai/inference/server/completion_service.py new file mode 100644 index 000000000..bb2160009 --- /dev/null +++ b/colossalai/inference/server/completion_service.py @@ -0,0 +1,35 @@ +import asyncio + +from colossalai.inference.core.async_engine import AsyncInferenceEngine + +from .utils import id_generator + + +class CompletionServing: + def __init__(self, engine: AsyncInferenceEngine, served_model: str): + self.engine = engine + self.served_model = served_model + + try: + asyncio.get_running_loop() + except RuntimeError: + pass + + async def create_completion(self, request, generation_config): + request_dict = await request.json() + request_id = id_generator() + prompt = request_dict.pop("prompt") + + # it is not a intuitive way + self.engine.engine.generation_config = generation_config + result_generator = self.engine.generate(request_id, prompt=prompt) + + final_res = None + async for res in result_generator: + if await request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await self.engine.abort(request_id) + return {"error_msg": "Client disconnected"} + final_res = res + + return final_res diff --git a/colossalai/inference/server/utils.py b/colossalai/inference/server/utils.py new file mode 100644 index 000000000..c10826f73 --- /dev/null +++ b/colossalai/inference/server/utils.py @@ -0,0 +1,16 @@ +# make it singleton +class NumericIDGenerator: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(NumericIDGenerator, cls).__new__(cls) + cls._instance.current_id = 0 + return cls._instance + + def __call__(self): + self.current_id += 1 + return self.current_id + + +id_generator = NumericIDGenerator() diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 1fe732df0..9d33c62f8 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -155,6 +155,7 @@ class Sequence: return ( f"(request_id={self.request_id}, " f"prompt={self.prompt}, " + f"output_token_id={self.output_token_id}," f"status={self.status.name}, " f"sample_params={self.sample_params}, " f"input_len={self.input_len}," diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index b132f47fd..ebff0429b 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -1,6 +1,7 @@ import os from typing import Dict, List, Tuple +import torch.distributed as dist import torch.nn as nn from torch import Tensor @@ -36,7 +37,11 @@ class ShardFormer: """ def __init__(self, shard_config: ShardConfig): - self.coordinator = DistCoordinator() + self.is_distributed = dist.is_initialized() + if self.is_distributed: + self.coordinator = DistCoordinator() + else: + self.coordinator = None self.shard_config = shard_config def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]: diff --git a/tests/test_infer/test_async_engine/test_async_engine.py b/tests/test_infer/test_async_engine/test_async_engine.py new file mode 100644 index 000000000..ebca11c72 --- /dev/null +++ b/tests/test_infer/test_async_engine/test_async_engine.py @@ -0,0 +1,80 @@ +import asyncio +from dataclasses import dataclass + +import pytest + +from colossalai.inference.core.async_engine import AsyncInferenceEngine + + +@dataclass +class SequenceTpye: + request_id: int + + +class MockEngine: + def __init__(self): + self.step_calls = 0 + self.add_request_calls = 0 + self.abort_request_calls = 0 + self.request_id = None + + async def async_step(self): + self.step_calls += 1 + return [SequenceTpye(request_id=self.request_id)] if self.request_id else [] + + def generate(self, request_id): + self.request_id = request_id + + def stop_generating(self): + self.request_id = None + + def add_request(self, **kwargs): + del kwargs # Unused + self.add_request_calls += 1 + + def abort_request(self, request_id): + del request_id # Unused + self.abort_request_calls += 1 + + +class MockAsyncLLMEngine(AsyncInferenceEngine): + def _init_engine(self, *args, **kwargs): + return MockEngine() + + +@pytest.mark.asyncio +async def test_new_requests_event(): + engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False) + engine.start_background_loop() + await asyncio.sleep(0.01) + assert engine.engine.step_calls == 0 + + await engine.add_request(1, "", None) + await asyncio.sleep(0.01) + assert engine.engine.add_request_calls == 1 + assert engine.engine.step_calls == 1 + + await engine.add_request(2, "", None) + engine.engine.generate(2) 
+ await asyncio.sleep(0) + assert engine.engine.add_request_calls == 2 + assert engine.engine.step_calls == 2 + await asyncio.sleep(0) + assert engine.engine.step_calls == 3 + engine.engine.stop_generating() + await asyncio.sleep(0) + assert engine.engine.step_calls == 4 + await asyncio.sleep(0) + assert engine.engine.step_calls == 4 + + await engine.add_request(3, "", None) + await asyncio.sleep(0.01) + assert engine.engine.add_request_calls == 3 + assert engine.engine.step_calls == 5 + await asyncio.sleep(0.01) + assert engine.engine.add_request_calls == 3 + assert engine.engine.step_calls == 5 + + +if __name__ == "__main__": + test_new_requests_event() diff --git a/tests/test_infer/test_async_engine/test_request_tracker.py b/tests/test_infer/test_async_engine/test_request_tracker.py new file mode 100644 index 000000000..9a797a862 --- /dev/null +++ b/tests/test_infer/test_async_engine/test_request_tracker.py @@ -0,0 +1,77 @@ +import pytest + +from colossalai.inference.core.async_engine import RequestTracker +from colossalai.inference.struct import Sequence + + +class SampleEvent: + def __init__(self): + self.flag = False + + def set(self): + self.flag = True + + def clear(self): + self.flag = False + + +def test_request_tracker(): + tracker = RequestTracker() + tracker.new_requests_event = SampleEvent() + stream_1 = tracker.add_request(1) + assert tracker.new_requests_event.flag + new, finished = tracker.get_new_and_finished_requests() + assert not tracker.new_requests_event.flag + assert len(new) == 1 + assert new[0]["request_id"] == 1 + assert not finished + assert not stream_1.finished + + stream_2 = tracker.add_request(2) + stream_3 = tracker.add_request(3) + assert tracker.new_requests_event.flag + new, finished = tracker.get_new_and_finished_requests() + assert not tracker.new_requests_event.flag + assert len(new) == 2 + assert new[0]["request_id"] == 2 + assert new[1]["request_id"] == 3 + assert not finished + assert not stream_2.finished + assert not stream_3.finished + + # request_ids must be unique + with pytest.raises(KeyError): + tracker.add_request(1) + assert not tracker.new_requests_event.flag + + tracker.abort_request(1) + new, finished = tracker.get_new_and_finished_requests() + assert len(finished) == 1 + assert 1 in finished + assert not new + assert stream_1.finished + + stream_4 = tracker.add_request(4) + tracker.abort_request(4) + assert tracker.new_requests_event.flag + new, finished = tracker.get_new_and_finished_requests() + assert len(finished) == 1 + assert 4 in finished + assert not new + assert stream_4.finished + + stream_5 = tracker.add_request(5) + assert tracker.new_requests_event.flag + tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0)) + new, finished = tracker.get_new_and_finished_requests() + assert not tracker.new_requests_event.flag + assert len(finished) == 1 + assert 2 in finished + assert len(new) == 1 + assert new[0]["request_id"] == 5 + assert stream_2.finished + assert not stream_5.finished + + +if __name__ == "__main__": + test_request_tracker()
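As a quick usage sketch for the /generate endpoint added in api_server.py: the minimal client below assumes the server has been launched locally on the default host and port (for example with `python3 -m colossalai.inference.server.api_server --model <model path>`) and that the third-party `requests` package is available. The endpoint name and the NUL-delimited JSON chunk format come straight from the route and its `stream_results()` helper above; the prompt text is only an illustration.

import json

import requests

# Assumed local server address; adjust to match the uvicorn host/port settings above.
API_URL = "http://127.0.0.1:8000/generate"


def generate_streaming(prompt: str):
    # "stream": True selects the StreamingResponse branch of the /generate route.
    payload = {"prompt": prompt, "stream": True}
    with requests.post(API_URL, json=payload, stream=True) as resp:
        resp.raise_for_status()
        # The server separates JSON chunks with a NUL byte ("\0").
        for chunk in resp.iter_lines(delimiter=b"\0"):
            if chunk:
                yield json.loads(chunk.decode("utf-8"))["text"]


if __name__ == "__main__":
    for text in generate_streaming("hello, who are you?"):
        print(text)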
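The new AsyncInferenceEngine can also be driven directly from Python rather than through FastAPI. The sketch below mirrors the `__main__` block of api_server.py (Hugging Face `AutoModelForCausalLM`/`AutoTokenizer`, an `InferenceConfig`, and `start_engine_loop=True`); the model path is a placeholder, the `max_batch_size`/`block_size` values simply echo the server's CLI defaults, and the whole thing is an illustration of the add_request/generate flow rather than an officially supported entry point.

import asyncio

from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.async_engine import AsyncInferenceEngine


async def main():
    model_path = "path/to/llama2-7b"  # placeholder; point this at a local checkpoint
    model = AutoModelForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    inference_config = InferenceConfig(max_batch_size=8, block_size=16)

    # start_engine_loop=True lets the first add_request() call spawn the background loop.
    engine = AsyncInferenceEngine(
        start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
    )

    # generate() registers the request with the RequestTracker and yields outputs
    # as the background loop finishes the sequence.
    async for output in engine.generate(request_id=1, prompt="hello, who are you?"):
        print(output)


if __name__ == "__main__":
    asyncio.run(main())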
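Relatedly, the new `InferenceConfig.from_dict` classmethod is what lets the server build its config straight from parsed CLI arguments (`InferenceConfig.from_dict(vars(args))` in the `__main__` block): keys that match dataclass fields are copied over, unknown keys are ignored, and missing fields are filled from the class-level defaults. A small sketch of the same pattern, with hypothetical argument values and field names taken from the server's CLI flags:

from colossalai.inference.config import InferenceConfig

# Hypothetical parsed-argument dict; extra keys such as "host" are simply ignored.
cli_args = {"max_batch_size": 8, "block_size": 16, "host": "127.0.0.1"}
config = InferenceConfig.from_dict(cli_args)
print(config.max_batch_size, config.block_size)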