fix tests

CjhHa1 2024-04-17 15:56:15 +08:00
parent 402e9918df
commit a6377bb331
8 changed files with 50 additions and 73 deletions

View File

@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from functools import partial
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type

 from colossalai.inference.core.engine import InferenceEngine
@@ -103,8 +103,8 @@ class Tracer:
             raise KeyError(f"Request {request_id} already exists.")

         stream = RequstStream(request_id)
+        logger.info(f"Added request {request_id}.")
         self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
         self.new_requests_event.set()

         return stream
@@ -118,6 +118,7 @@ class Tracer:
         if request_id not in self._request_streams or self._request_streams[request_id].finished:
             # The request has already finished or been aborted.
+            # Requests still sitting in new_requests will be aborted when we try to get them (if marked as aborted).
             return

         self._request_streams[request_id].set_result(None)
@@ -127,9 +128,18 @@ class Tracer:
         Get new requests from http server.
         """
         new_requests: List[Dict] = []
+        finished_requests: Set[int] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)
+
         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
+            if new_request["request_id"] in finished_requests:
+                # The request has been aborted.
+                stream.set_result(None)
+                continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
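A minimal, runnable sketch of the behaviour added above, with simplified stand-ins for RequstStream and the Tracer (not the ColossalAI classes themselves): aborted request ids are drained first, and any queued request that was already aborted is finished with a None result instead of being scheduled.

import asyncio
from typing import Dict, List, Set


class _Stream:
    """Simplified stand-in for RequstStream: only records that a result arrived."""

    def __init__(self, request_id: int) -> None:
        self.request_id = request_id
        self.finished = False

    def set_result(self, result) -> None:
        self.finished = True


class _MiniTracer:
    """Toy tracer reproducing the drain-finished-then-drain-new logic above."""

    def __init__(self) -> None:
        self._new_requests: asyncio.Queue = asyncio.Queue()
        self._finished_requests: asyncio.Queue = asyncio.Queue()
        self._request_streams: Dict[int, _Stream] = {}

    def add_request(self, request_id: int) -> _Stream:
        stream = _Stream(request_id)
        self._new_requests.put_nowait((stream, {"request_id": request_id}))
        return stream

    def abort_request(self, request_id: int) -> None:
        self._finished_requests.put_nowait(request_id)

    def get_new_requests(self) -> List[dict]:
        # Drain aborted ids first, then skip any queued request that was already aborted.
        finished_requests: Set[int] = set()
        while not self._finished_requests.empty():
            finished_requests.add(self._finished_requests.get_nowait())

        new_requests: List[dict] = []
        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if new_request["request_id"] in finished_requests:
                stream.set_result(None)  # finish the aborted request instead of scheduling it
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)
        return new_requests


tracer = _MiniTracer()
stream = tracer.add_request(1)
tracer.abort_request(1)  # aborted before the background loop ever saw it
assert tracer.get_new_requests() == [] and stream.finished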
@@ -172,22 +182,23 @@ class _AsyncInferenceEngine(InferenceEngine):
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
+        # Return: List[Sequence]
         finished_sequences = self.request_handler.update()
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)

-        return finished_sequences, self.request_handler.current_requests_in_batch() > 0
+        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0


 class AsyncInferenceEngine:
-    """An asynchronous wrapper for LLMEngine.
+    """An asynchronous wrapper for the InferenceEngine class.

     This class is used to wrap the InferenceEngine class to make it asynchronous.
     It uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMEngine is kicked by the generate method when there are
-    requests in the waiting queue. The generate method yields the outputs
-    from the InferenceEngine to the caller.
+    requests. Note that this class does not hold the model directly: when a new
+    request comes in, `add_request` is called first, the Tracer records the request,
+    and the background loop hands it to the underlying `InferenceEngine` for
+    processing. You can think of this engine as an interface.
     """

     _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
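To illustrate the contract this hunk settles on, async_step now returning both the finished sequences and a flag for whether the batch bucket still holds work, here is a runnable sketch of a background loop built around a stand-in step function. fake_step and background_loop are illustrative only, not ColossalAI API.

import asyncio
from typing import List, Tuple


async def fake_step(pending: List[str]) -> Tuple[List[str], bool]:
    """Stand-in for _AsyncInferenceEngine.async_step: returns finished outputs
    and whether the batch bucket still holds unfinished requests."""
    await asyncio.sleep(0)  # pretend to run one decoding iteration
    finished = [pending.pop()] if pending else []
    return finished, len(pending) > 0


async def background_loop(pending: List[str]) -> List[str]:
    outputs: List[str] = []
    has_running = True
    while has_running:
        finished, has_running = await fake_step(pending)
        outputs.extend(finished)  # stream or collect the finished sequences
    return outputs


print(asyncio.run(background_loop(["req-1", "req-2", "req-3"])))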
@@ -264,7 +275,7 @@ class AsyncInferenceEngine:
         prompt_token_ids: Optional[List[int]] = None,
     ) -> RequstStream:
         """
-        Add a request to the background tracker(waitting queue), start the background loop if needed.
+        Add a request to the background tracker (waiting queue), and start the background loop if needed.
         """
         if not self.background_loop_status:
             if self.start_engine_loop:
@@ -287,14 +298,12 @@ class AsyncInferenceEngine:
         """
         Generate output from a request. It receives the request from http server, adds it into the
         waiting queue of Async Engine and streams the output sequence.
         """
         try:
             stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
             return await stream.get_result()
         except (Exception, asyncio.CancelledError) as e:
-            # If there is an exception or coroutine is cancelled, abort the
-            # request.
+            # If there is an exception or the coroutine is cancelled, abort the request.
             self._abort(request_id)
             raise e
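A self-contained sketch of the abort-on-cancellation pattern used in generate above. generate_with_cleanup and the aborted set are illustrative stand-ins, not ColossalAI API; the abort bookkeeping stands in for self._abort(request_id).

import asyncio


async def generate_with_cleanup(request_id: int, work: float, aborted: set) -> str:
    """Toy analogue of AsyncInferenceEngine.generate: on cancellation or error,
    record the abort so the tracer can drop the request later."""
    try:
        await asyncio.sleep(work)  # stand-in for waiting on the stream result
        return f"output for {request_id}"
    except (Exception, asyncio.CancelledError):
        aborted.add(request_id)  # stand-in for self._abort(request_id)
        raise


async def main() -> None:
    aborted: set = set()
    task = asyncio.create_task(generate_with_cleanup(7, work=1.0, aborted=aborted))
    await asyncio.sleep(0.01)
    task.cancel()  # the client went away, so cancel the request
    try:
        await task
    except asyncio.CancelledError:
        pass
    assert aborted == {7}


asyncio.run(main())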

View File

@@ -451,10 +451,6 @@ class InferenceEngine:
         """
         with torch.inference_mode():
             if prompts is not None or prompts_token_ids is not None:
-                if isinstance(prompts, str):
-                    prompts = [prompts]
-                if isinstance(request_ids, int):
-                    request_ids = [request_ids]
                 self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)

             output_seqs_list = []
@@ -527,6 +523,9 @@ class InferenceEngine:
         block_size = self.inference_config.block_size

+        if request_ids is not None and not isinstance(request_ids, list):
+            request_ids = [request_ids]
+
         if prompts is not None and not isinstance(prompts, list):
             prompts = [prompts]
@@ -536,6 +535,7 @@ class InferenceEngine:
                 "input_ids"
             ]

+        # list of torch Tensor
         if isinstance(prompts_token_ids, list):
             if isinstance(prompts_token_ids[0], torch.Tensor):
                 prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
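A hedged sketch of the input normalization that add_request now performs, pulled together into one standalone helper. normalize_inputs is hypothetical, not a ColossalAI function; it only mirrors the checks visible in the two hunks above.

from typing import List, Optional, Union

import torch


def normalize_inputs(
    request_ids: Optional[Union[int, List[int]]] = None,
    prompts: Optional[Union[str, List[str]]] = None,
    prompts_token_ids: Optional[Union[List[int], List[List[int]], torch.Tensor]] = None,
):
    """Hypothetical helper mirroring the normalization now done inside add_request."""
    if request_ids is not None and not isinstance(request_ids, list):
        request_ids = [request_ids]  # a single id becomes a one-element list
    if prompts is not None and not isinstance(prompts, list):
        prompts = [prompts]  # a single prompt string becomes a one-element list
    if isinstance(prompts_token_ids, torch.Tensor):
        prompts_token_ids = prompts_token_ids.tolist()
    elif isinstance(prompts_token_ids, list) and isinstance(prompts_token_ids[0], torch.Tensor):
        # list of torch Tensors -> list of plain Python lists
        prompts_token_ids = [ids.tolist() for ids in prompts_token_ids]
    return request_ids, prompts, prompts_token_ids


print(normalize_inputs(3, "hello", [torch.tensor([1, 2, 3])]))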

View File

@@ -327,7 +327,7 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()

-    def current_requests_in_batch(self) -> int:
+    def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

     def search_tokens(self, generation_config: GenerationConfig, logits):
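The renamed helper simply sums the sizes of the two batch buckets. A trivial sketch of that arithmetic with a stand-in bucket type; _Bucket is illustrative, not the real batch-bucket class.

from dataclasses import dataclass


@dataclass
class _Bucket:
    """Stand-in for a batch bucket exposing current_batch_size."""
    current_batch_size: int


def total_requests_in_batch_bucket(prefill_bb: _Bucket, running_bb: _Bucket) -> int:
    # Same arithmetic as RequestHandler.total_requests_in_batch_bucket:
    # requests held in the prefill bucket plus those in the running (decoding) bucket.
    return prefill_bb.current_batch_size + running_bb.current_batch_size


assert total_requests_in_batch_bucket(_Bucket(2), _Bucket(5)) == 7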

View File

@@ -9,6 +9,11 @@ Doc:
  - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \
      -H 'Content-Type: application/json' \
      -d '{"prompt":"hello, who are you? ","stream":"False"}'`
+
+Version declaration:
+ - This is the first version of the API server for Colossal-Inference.
+ - V0 stands for APIs that are still under development, such as models; changes will be made to perfect them.
+ - V1 stands for the currently supported APIs, such as completion and chat; this is their first version.
 """

 import argparse
@@ -127,14 +132,6 @@ def add_engine_config(parser):
         help="model context length. If unspecified, " "will be automatically derived from the model.",
     )
     # Parallel arguments
-    parser.add_argument(
-        "--worker-use-ray",
-        action="store_true",
-        help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU",
-    )
-    parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages")
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")

     # KV cache arguments
@@ -149,28 +146,6 @@ def add_engine_config(parser):
         default=None,
         help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.",
     )
-    # Quantization settings.
-    parser.add_argument(
-        "--quantization",
-        "-q",
-        type=str,
-        choices=["awq", "gptq", "squeezellm", None],
-        default=None,
-        help="Method used to quantize the weights. If "
-        "None, we first check the `quantization_config` "
-        "attribute in the model config file. If that is "
-        "None, we assume the model weights are not "
-        "quantized and use `dtype` to determine the data "
-        "type of the weights.",
-    )
-    parser.add_argument(
-        "--enforce-eager",
-        action="store_true",
-        help="Always use eager-mode PyTorch. If False, "
-        "will use eager mode and CUDA graph in hybrid "
-        "for maximal performance and flexibility.",
-    )
     return parser
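For orientation, a hedged sketch of the parser after the removals above, keeping only the option visible in these hunks; the real add_engine_config defines additional arguments that are not shown here.

import argparse


def add_engine_config(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Parallel arguments: only tensor parallelism remains after this commit.
    parser.add_argument(
        "--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas"
    )
    return parser


parser = add_engine_config(argparse.ArgumentParser(description="Colossal-Inference API server"))
print(parser.parse_args(["-tp", "2"]))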

View File

@@ -175,7 +175,7 @@ class VocabParallelEmbedding1D(ParallelModule):
             The initializer of weight, defaults to normal initializer.

     The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
+    ::

         max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
             renormalized to have norm max_norm. Note: this will modify weight in-place.
         norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
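A small example of the two options that this docstring block documents, using torch.nn.functional.embedding directly (plain PyTorch, not the parallel layer from the diff).

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)
ids = torch.tensor([1, 3, 5])
# max_norm renormalizes (in place) any selected embedding vector whose norm exceeds 1.0;
# norm_type selects the p-norm used for that check (default 2).
out = F.embedding(ids, weight, max_norm=1.0, norm_type=2.0)
assert out.shape == (3, 4)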

View File

@@ -7,7 +7,7 @@ from colossalai.inference.core.async_engine import AsyncInferenceEngine

 @dataclass
-class SequenceTpye:
+class MockSequence:
     request_id: int
@@ -20,7 +20,11 @@ class MockEngine:
     async def async_step(self):
         self.step_calls += 1
-        return [SequenceTpye(request_id=self.request_id)] if self.request_id else []
+        return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)
+
+    def add_single_request(self, **kwargs):
+        del kwargs
+        self.add_request_calls += 1

     def generate(self, request_id):
         self.request_id = request_id
@@ -37,14 +41,14 @@ class MockEngine:
         self.abort_request_calls += 1

-class MockAsyncLLMEngine(AsyncInferenceEngine):
+class MockAsyncInferenceEngine(AsyncInferenceEngine):
     def _init_engine(self, *args, **kwargs):
         return MockEngine()

 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
+    engine = MockAsyncInferenceEngine(worker_use_ray=False, engine_use_ray=False)
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -74,7 +78,3 @@ async def test_new_requests_event():
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == 5
-
-
-if __name__ == "__main__":
-    test_new_requests_event()

View File

@@ -15,27 +15,25 @@ class SampleEvent:
         self.flag = False


-def test_request_tracker():
+def test_request_tracer():
     tracker = Tracer()
     tracker.new_requests_event = SampleEvent()
     stream_1 = tracker.add_request(1)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
     assert len(new) == 1
     assert new[0]["request_id"] == 1
-    assert not finished
     assert not stream_1.finished

     stream_2 = tracker.add_request(2)
     stream_3 = tracker.add_request(3)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
     assert len(new) == 2
     assert new[0]["request_id"] == 2
     assert new[1]["request_id"] == 3
-    assert not finished
     assert not stream_2.finished
     assert not stream_3.finished
@@ -45,28 +43,21 @@ def test_request_tracker():
     assert not tracker.new_requests_event.flag

     tracker.abort_request(1)
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
-    assert len(finished) == 1
-    assert 1 in finished
     assert not new
-    assert stream_1.finished

     stream_4 = tracker.add_request(4)
     tracker.abort_request(4)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
-    assert len(finished) == 1
-    assert 4 in finished
     assert not new
     assert stream_4.finished

     stream_5 = tracker.add_request(5)
     assert tracker.new_requests_event.flag
     tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
-    assert len(finished) == 1
-    assert 2 in finished
     assert len(new) == 1
     assert new[0]["request_id"] == 5
     assert stream_2.finished
@@ -74,4 +65,4 @@ def test_request_tracker():

 if __name__ == "__main__":
-    test_request_tracker()
+    test_request_tracer()

View File

@@ -4,11 +4,13 @@ import sys
 import time

 import pytest
-import ray
 import requests

 MAX_WAITING_TIME = 300

+ray = pytest.importorskip("ray")
+
 pytestmark = pytest.mark.asyncio
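The importorskip line above lets the whole module skip cleanly at collection time when Ray is not installed, instead of failing with an ImportError. A minimal sketch of the same pattern in an isolated test file:

import pytest

# Skip every test in this module if ray cannot be imported.
ray = pytest.importorskip("ray")


def test_ray_available():
    assert ray is not None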