From a6377bb3311e76f9aa4d356e3f5a449f7889f1f5 Mon Sep 17 00:00:00 2001
From: CjhHa1
Date: Wed, 17 Apr 2024 15:56:15 +0800
Subject: [PATCH] fix tests

---
 colossalai/inference/core/async_engine.py     | 33 ++++++++++-------
 colossalai/inference/core/engine.py           |  8 ++---
 colossalai/inference/core/request_handler.py  |  2 +-
 colossalai/inference/server/api_server.py     | 35 +++----------------
 colossalai/shardformer/layer/embedding.py     |  2 +-
 .../test_async_engine/test_async_engine.py    | 16 ++++-----
 ...uest_tracker.py => test_request_tracer.py} | 23 ++++--------
 tests/test_infer/test_server.py               |  4 ++-
 8 files changed, 50 insertions(+), 73 deletions(-)
 rename tests/test_infer/test_async_engine/{test_request_tracker.py => test_request_tracer.py} (73%)

diff --git a/colossalai/inference/core/async_engine.py b/colossalai/inference/core/async_engine.py
index 9c630177d..6f7ab15d8 100644
--- a/colossalai/inference/core/async_engine.py
+++ b/colossalai/inference/core/async_engine.py
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from functools import partial
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
 
 from colossalai.inference.core.engine import InferenceEngine
 
@@ -103,8 +103,8 @@ class Tracer:
             raise KeyError(f"Request {request_id} already exists.")
 
         stream = RequstStream(request_id)
+        logger.info(f"Added request {request_id}.")
         self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
-
         self.new_requests_event.set()
 
         return stream
@@ -118,6 +118,7 @@ class Tracer:
 
         if request_id not in self._request_streams or self._request_streams[request_id].finished:
             # The request has already finished or been aborted.
+            # Requests still in new_requests are aborted later, when get_new_requests fetches them (if marked as aborted).
             return
 
         self._request_streams[request_id].set_result(None)
@@ -127,9 +128,18 @@ class Tracer:
         Get new requests from http server.
         """
        new_requests: List[Dict] = []
+        finished_requests: Set[int] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)
 
         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
+            if new_request["request_id"] in finished_requests:
+                # The request has been aborted.
+                stream.set_result(None)
+                continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
 
@@ -172,22 +182,23 @@ class _AsyncInferenceEngine(InferenceEngine):
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
-        # Return: List[Sequence]
+
         finished_sequences = self.request_handler.update()
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)
 
-        return finished_sequences, self.request_handler.current_requests_in_batch() > 0
+        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
 
 
 class AsyncInferenceEngine:
-    """An asynchronous wrapper for LLMEngine.
+    """An asynchronous wrapper for the InferenceEngine class.
 
     This class is used to wrap the InferenceEngine class to make it asynchronous.
     It uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMEngine is kicked by the generate method when there are
-    requests in the waiting queue. The generate method yields the outputs
-    from the InferenceEngine to the caller.
+    requests. Note that this class does not hold the model directly; when a new request
+    comes in, `add_request` is called first and the Tracer records it, then hands it to
+    the background `InferenceEngine` (run by the background loop) for processing. You can
+    consider this engine as an interface.
     """
 
     _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
@@ -264,7 +275,7 @@ class AsyncInferenceEngine:
         prompt_token_ids: Optional[List[int]] = None,
     ) -> RequstStream:
         """
-        Add a request to the background tracker(waitting queue), start the background loop if needed.
+        Add a request to the background tracker(waiting queue), start the background loop if needed.
         """
         if not self.background_loop_status:
             if self.start_engine_loop:
@@ -287,14 +298,12 @@ class AsyncInferenceEngine:
         """
         Generate output from a request. It receives the request from http server,
         adds it into the waitting queue of Async Engine and streams the output sequence.
-
         """
         try:
             stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
             return await stream.get_result()
 
         except (Exception, asyncio.CancelledError) as e:
-            # If there is an exception or coroutine is cancelled, abort the
-            # request.
+            # If there is an exception or coroutine is cancelled, abort the request.
             self._abort(request_id)
             raise e
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 06b716ea9..8796e7492 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -451,10 +451,6 @@ class InferenceEngine:
         """
         with torch.inference_mode():
             if prompts is not None or prompts_token_ids is not None:
-                if isinstance(prompts, str):
-                    prompts = [prompts]
-                if isinstance(request_ids, int):
-                    request_ids = [request_ids]
                 self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
 
             output_seqs_list = []
@@ -527,6 +523,9 @@ class InferenceEngine:
 
         block_size = self.inference_config.block_size
 
+        if request_ids is not None and not isinstance(request_ids, list):
+            request_ids = [request_ids]
+
         if prompts is not None and not isinstance(prompts, list):
             prompts = [prompts]
 
@@ -536,6 +535,7 @@ class InferenceEngine:
                 "input_ids"
             ]
 
+        # Convert a list of torch Tensors into lists of token ids
         if isinstance(prompts_token_ids, list):
             if isinstance(prompts_token_ids[0], torch.Tensor):
                 prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index d4c4de299..e529c5b65 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -327,7 +327,7 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()
 
-    def current_requests_in_batch(self) -> int:
+    def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
 
     def search_tokens(self, generation_config: GenerationConfig, logits):
diff --git a/colossalai/inference/server/api_server.py b/colossalai/inference/server/api_server.py
index 60ccf15fc..1904581dc 100644
--- a/colossalai/inference/server/api_server.py
+++ b/colossalai/inference/server/api_server.py
@@ -9,6 +9,11 @@ Doc:
         - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \
      -H 'Content-Type: application/json' \
      -d '{"prompt":"hello, who are you? ","stream":"False"}'`
+
+    Version declaration:
+        - This is the first version of the API server for Colossal-Inference
+        - V0 stands for APIs that are still under development, such as models; changes will be made to refine them.
+        - V1 stands for the currently supported APIs, such as completion and chat.
 """
 
 import argparse
@@ -127,14 +132,6 @@ def add_engine_config(parser):
         help="model context length. If unspecified, " "will be automatically derived from the model.",
     )
     # Parallel arguments
-    parser.add_argument(
-        "--worker-use-ray",
-        action="store_true",
-        help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU",
-    )
-
-    parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages")
-
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
 
     # KV cache arguments
@@ -149,28 +146,6 @@
         default=None,
         help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.",
     )
-
-    # Quantization settings.
-    parser.add_argument(
-        "--quantization",
-        "-q",
-        type=str,
-        choices=["awq", "gptq", "squeezellm", None],
-        default=None,
-        help="Method used to quantize the weights. If "
-        "None, we first check the `quantization_config` "
-        "attribute in the model config file. If that is "
-        "None, we assume the model weights are not "
-        "quantized and use `dtype` to determine the data "
-        "type of the weights.",
-    )
-    parser.add_argument(
-        "--enforce-eager",
-        action="store_true",
-        help="Always use eager-mode PyTorch. If False, "
-        "will use eager mode and CUDA graph in hybrid "
-        "for maximal performance and flexibility.",
-    )
     return parser
diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py
index ab2503a60..8321adc5c 100644
--- a/colossalai/shardformer/layer/embedding.py
+++ b/colossalai/shardformer/layer/embedding.py
@@ -175,7 +175,7 @@ class VocabParallelEmbedding1D(ParallelModule):
            he initializer of weight, defaults to normal initializer.
 
     The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
-
+    ::
         max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. Note: this will modify weight in-place.
         norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
diff --git a/tests/test_infer/test_async_engine/test_async_engine.py b/tests/test_infer/test_async_engine/test_async_engine.py
index ebca11c72..bea05a372 100644
--- a/tests/test_infer/test_async_engine/test_async_engine.py
+++ b/tests/test_infer/test_async_engine/test_async_engine.py
@@ -7,7 +7,7 @@ from colossalai.inference.core.async_engine import AsyncInferenceEngine
 
 
 @dataclass
-class SequenceTpye:
+class MockSequence:
     request_id: int
 
 
@@ -20,7 +20,11 @@ class MockEngine:
 
     async def async_step(self):
         self.step_calls += 1
-        return [SequenceTpye(request_id=self.request_id)] if self.request_id else []
+        return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)
+
+    def add_single_request(self, **kwargs):
+        del kwargs
+        self.add_request_calls += 1
 
     def generate(self, request_id):
         self.request_id = request_id
@@ -37,14 +41,14 @@ class MockEngine:
         self.abort_request_calls += 1
 
 
-class MockAsyncLLMEngine(AsyncInferenceEngine):
+class MockAsyncInferenceEngine(AsyncInferenceEngine):
     def _init_engine(self, *args, **kwargs):
         return MockEngine()
 
 
 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
+    engine = MockAsyncInferenceEngine(worker_use_ray=False, engine_use_ray=False)
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -74,7 +78,3 @@ async def test_new_requests_event():
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == 5
-
-
-if __name__ == "__main__":
-    test_new_requests_event()
diff --git a/tests/test_infer/test_async_engine/test_request_tracker.py b/tests/test_infer/test_async_engine/test_request_tracer.py
similarity index 73%
rename from tests/test_infer/test_async_engine/test_request_tracker.py
rename to tests/test_infer/test_async_engine/test_request_tracer.py
index 4b15d46c1..14bcb9628 100644
--- a/tests/test_infer/test_async_engine/test_request_tracker.py
+++ b/tests/test_infer/test_async_engine/test_request_tracer.py
@@ -15,27 +15,25 @@ class SampleEvent:
         self.flag = False
 
 
-def test_request_tracker():
+def test_request_tracer():
     tracker = Tracer()
     tracker.new_requests_event = SampleEvent()
     stream_1 = tracker.add_request(1)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
     assert len(new) == 1
     assert new[0]["request_id"] == 1
-    assert not finished
     assert not stream_1.finished
 
     stream_2 = tracker.add_request(2)
     stream_3 = tracker.add_request(3)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
     assert len(new) == 2
     assert new[0]["request_id"] == 2
     assert new[1]["request_id"] == 3
-    assert not finished
     assert not stream_2.finished
     assert not stream_3.finished
 
@@ -45,28 +43,21 @@ def test_request_tracker():
     assert not tracker.new_requests_event.flag
 
     tracker.abort_request(1)
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert 1 in finished
+    new = tracker.get_new_requests()
     assert not new
-    assert stream_1.finished
 
     stream_4 = tracker.add_request(4)
     tracker.abort_request(4)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert 4 in finished
+    new = tracker.get_new_requests()
     assert not new
     assert stream_4.finished
stream_5 = tracker.add_request(5) assert tracker.new_requests_event.flag tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0)) - new, finished = tracker.get_new_and_finished_requests() + new = tracker.get_new_requests() assert not tracker.new_requests_event.flag - assert len(finished) == 1 - assert 2 in finished assert len(new) == 1 assert new[0]["request_id"] == 5 assert stream_2.finished @@ -74,4 +65,4 @@ def test_request_tracker(): if __name__ == "__main__": - test_request_tracker() + test_request_tracer() diff --git a/tests/test_infer/test_server.py b/tests/test_infer/test_server.py index 05ac5a264..e7cf8d515 100644 --- a/tests/test_infer/test_server.py +++ b/tests/test_infer/test_server.py @@ -4,11 +4,13 @@ import sys import time import pytest -import ray import requests MAX_WAITING_TIME = 300 +ray = pytest.importorskip("ray") + + pytestmark = pytest.mark.asyncio