fix tests

This commit is contained in:
CjhHa1 2024-04-17 15:56:15 +08:00
parent 402e9918df
commit a6377bb331
8 changed files with 50 additions and 73 deletions

View File

@ -1,7 +1,7 @@
import asyncio
import logging
from functools import partial
from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
from colossalai.inference.core.engine import InferenceEngine
@ -103,8 +103,8 @@ class Tracer:
raise KeyError(f"Request {request_id} already exists.")
stream = RequstStream(request_id)
logger.info(f"Added request {request_id}.")
self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
self.new_requests_event.set()
return stream
@ -118,6 +118,7 @@ class Tracer:
if request_id not in self._request_streams or self._request_streams[request_id].finished:
# The request has already finished or been aborted.
# The requests in new_requests will be aborted when try to get them(if marked aborted)
return
self._request_streams[request_id].set_result(None)
@ -127,9 +128,18 @@ class Tracer:
Get new requests from http server.
"""
new_requests: List[Dict] = []
finished_requests: Set[int] = set()
while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
finished_requests.add(request_id)
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if new_request["request_id"] in finished_requests:
# The request has been aborted.
stream.set_result(None)
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
@ -172,22 +182,23 @@ class _AsyncInferenceEngine(InferenceEngine):
if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)
# Return: List[Sequence]
finished_sequences = self.request_handler.update()
for sequence in finished_sequences:
sequence.output = self.tokenizer.decode(sequence.output_token_id)
return finished_sequences, self.request_handler.current_requests_in_batch() > 0
return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
class AsyncInferenceEngine:
"""An asynchronous wrapper for LLMEngine.
"""An asynchronous wrapper for the InferenceEngine class.
This class is used to wrap the InferenceEngine class to make it asynchronous.
It uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there are
requests in the waiting queue. The generate method yields the outputs
from the InferenceEngine to the caller.
requests. Note that this class does not hold model directly, when incoming a new
request, it first called `add_request` and the Tracer will record the request, putting
it to the background `InferenceEngine`(done in background loop) to process. You can
consider this engine as an interface.
"""
_engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
@ -264,7 +275,7 @@ class AsyncInferenceEngine:
prompt_token_ids: Optional[List[int]] = None,
) -> RequstStream:
"""
Add a request to the background tracker(waitting queue), start the background loop if needed.
Add a request to the background tracker(waiting queue), start the background loop if needed.
"""
if not self.background_loop_status:
if self.start_engine_loop:
@ -287,14 +298,12 @@ class AsyncInferenceEngine:
"""
Generate output from a request. It receives the request from http server, adds it into the
waitting queue of Async Engine and streams the output sequence.
"""
try:
stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
return await stream.get_result()
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the
# request.
# If there is an exception or coroutine is cancelled, abort the request.
self._abort(request_id)
raise e

View File

@ -451,10 +451,6 @@ class InferenceEngine:
"""
with torch.inference_mode():
if prompts is not None or prompts_token_ids is not None:
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(request_ids, int):
request_ids = [request_ids]
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
output_seqs_list = []
@ -527,6 +523,9 @@ class InferenceEngine:
block_size = self.inference_config.block_size
if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]
if prompts is not None and not isinstance(prompts, list):
prompts = [prompts]
@ -536,6 +535,7 @@ class InferenceEngine:
"input_ids"
]
# list of torch Tensor
if isinstance(prompts_token_ids, list):
if isinstance(prompts_token_ids[0], torch.Tensor):
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]

View File

@ -327,7 +327,7 @@ class RequestHandler:
def check_unfinished_seqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()
def current_requests_in_batch(self) -> int:
def total_requests_in_batch_bucket(self) -> int:
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
def search_tokens(self, generation_config: GenerationConfig, logits):

View File

@ -9,6 +9,11 @@ Doc:
- For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \
-H 'Content-Type: application/json' \
-d '{"prompt":"hello, who are you? ","stream":"False"}'`
Version declaration:
- This is the first version of the API server for Colossal-Inference
- V0 stands for the under development api, such as models, changes should be made to perfect it.
- V1 stands for the currently supported api, such as completion and chat, this is the first version.
"""
import argparse
@ -127,14 +132,6 @@ def add_engine_config(parser):
help="model context length. If unspecified, " "will be automatically derived from the model.",
)
# Parallel arguments
parser.add_argument(
"--worker-use-ray",
action="store_true",
help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU",
)
parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages")
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
# KV cache arguments
@ -149,28 +146,6 @@ def add_engine_config(parser):
default=None,
help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.",
)
# Quantization settings.
parser.add_argument(
"--quantization",
"-q",
type=str,
choices=["awq", "gptq", "squeezellm", None],
default=None,
help="Method used to quantize the weights. If "
"None, we first check the `quantization_config` "
"attribute in the model config file. If that is "
"None, we assume the model weights are not "
"quantized and use `dtype` to determine the data "
"type of the weights.",
)
parser.add_argument(
"--enforce-eager",
action="store_true",
help="Always use eager-mode PyTorch. If False, "
"will use eager mode and CUDA graph in hybrid "
"for maximal performance and flexibility.",
)
return parser

View File

@ -175,7 +175,7 @@ class VocabParallelEmbedding1D(ParallelModule):
he initializer of weight, defaults to normal initializer.
The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
::
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
renormalized to have norm max_norm. Note: this will modify weight in-place.
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.

View File

@ -7,7 +7,7 @@ from colossalai.inference.core.async_engine import AsyncInferenceEngine
@dataclass
class SequenceTpye:
class MockSequence:
request_id: int
@ -20,7 +20,11 @@ class MockEngine:
async def async_step(self):
self.step_calls += 1
return [SequenceTpye(request_id=self.request_id)] if self.request_id else []
return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)
def add_single_request(self, **kwargs):
del kwargs
self.add_request_calls += 1
def generate(self, request_id):
self.request_id = request_id
@ -37,14 +41,14 @@ class MockEngine:
self.abort_request_calls += 1
class MockAsyncLLMEngine(AsyncInferenceEngine):
class MockAsyncInferenceEngine(AsyncInferenceEngine):
def _init_engine(self, *args, **kwargs):
return MockEngine()
@pytest.mark.asyncio
async def test_new_requests_event():
engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
engine = MockAsyncInferenceEngine(worker_use_ray=False, engine_use_ray=False)
engine.start_background_loop()
await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0
@ -74,7 +78,3 @@ async def test_new_requests_event():
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5
if __name__ == "__main__":
test_new_requests_event()

View File

@ -15,27 +15,25 @@ class SampleEvent:
self.flag = False
def test_request_tracker():
def test_request_tracer():
tracker = Tracer()
tracker.new_requests_event = SampleEvent()
stream_1 = tracker.add_request(1)
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
new = tracker.get_new_requests()
assert not tracker.new_requests_event.flag
assert len(new) == 1
assert new[0]["request_id"] == 1
assert not finished
assert not stream_1.finished
stream_2 = tracker.add_request(2)
stream_3 = tracker.add_request(3)
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
new = tracker.get_new_requests()
assert not tracker.new_requests_event.flag
assert len(new) == 2
assert new[0]["request_id"] == 2
assert new[1]["request_id"] == 3
assert not finished
assert not stream_2.finished
assert not stream_3.finished
@ -45,28 +43,21 @@ def test_request_tracker():
assert not tracker.new_requests_event.flag
tracker.abort_request(1)
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert 1 in finished
new = tracker.get_new_requests()
assert not new
assert stream_1.finished
stream_4 = tracker.add_request(4)
tracker.abort_request(4)
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert 4 in finished
new = tracker.get_new_requests()
assert not new
assert stream_4.finished
stream_5 = tracker.add_request(5)
assert tracker.new_requests_event.flag
tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
new, finished = tracker.get_new_and_finished_requests()
new = tracker.get_new_requests()
assert not tracker.new_requests_event.flag
assert len(finished) == 1
assert 2 in finished
assert len(new) == 1
assert new[0]["request_id"] == 5
assert stream_2.finished
@ -74,4 +65,4 @@ def test_request_tracker():
if __name__ == "__main__":
test_request_tracker()
test_request_tracer()

View File

@ -4,11 +4,13 @@ import sys
import time
import pytest
import ray
import requests
MAX_WAITING_TIME = 300
ray = pytest.importorskip("ray")
pytestmark = pytest.mark.asyncio