mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-07 12:29:09 +00:00
fix tests
This commit is contained in:
parent
402e9918df
commit
a6377bb331
@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
|
||||
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
|
||||
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
|
||||
@ -103,8 +103,8 @@ class Tracer:
|
||||
raise KeyError(f"Request {request_id} already exists.")
|
||||
|
||||
stream = RequstStream(request_id)
|
||||
logger.info(f"Added request {request_id}.")
|
||||
self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
|
||||
|
||||
self.new_requests_event.set()
|
||||
|
||||
return stream
|
||||
@ -118,6 +118,7 @@ class Tracer:
|
||||
|
||||
if request_id not in self._request_streams or self._request_streams[request_id].finished:
|
||||
# The request has already finished or been aborted.
|
||||
# The requests in new_requests will be aborted when try to get them(if marked aborted)
|
||||
return
|
||||
|
||||
self._request_streams[request_id].set_result(None)
|
||||
@ -127,9 +128,18 @@ class Tracer:
|
||||
Get new requests from http server.
|
||||
"""
|
||||
new_requests: List[Dict] = []
|
||||
finished_requests: Set[int] = set()
|
||||
|
||||
while not self._finished_requests.empty():
|
||||
request_id = self._finished_requests.get_nowait()
|
||||
finished_requests.add(request_id)
|
||||
|
||||
while not self._new_requests.empty():
|
||||
stream, new_request = self._new_requests.get_nowait()
|
||||
if new_request["request_id"] in finished_requests:
|
||||
# The request has been aborted.
|
||||
stream.set_result(None)
|
||||
continue
|
||||
self._request_streams[stream.request_id] = stream
|
||||
new_requests.append(new_request)
|
||||
|
||||
@ -172,22 +182,23 @@ class _AsyncInferenceEngine(InferenceEngine):
|
||||
if self.inference_config.pad_input:
|
||||
logits = logits[:, -1, :]
|
||||
self.request_handler.search_tokens(self.generation_config, logits)
|
||||
# Return: List[Sequence]
|
||||
|
||||
finished_sequences = self.request_handler.update()
|
||||
for sequence in finished_sequences:
|
||||
sequence.output = self.tokenizer.decode(sequence.output_token_id)
|
||||
|
||||
return finished_sequences, self.request_handler.current_requests_in_batch() > 0
|
||||
return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
|
||||
|
||||
|
||||
class AsyncInferenceEngine:
|
||||
"""An asynchronous wrapper for LLMEngine.
|
||||
"""An asynchronous wrapper for the InferenceEngine class.
|
||||
|
||||
This class is used to wrap the InferenceEngine class to make it asynchronous.
|
||||
It uses asyncio to create a background loop that keeps processing incoming
|
||||
requests. The LLMEngine is kicked by the generate method when there are
|
||||
requests in the waiting queue. The generate method yields the outputs
|
||||
from the InferenceEngine to the caller.
|
||||
requests. Note that this class does not hold model directly, when incoming a new
|
||||
request, it first called `add_request` and the Tracer will record the request, putting
|
||||
it to the background `InferenceEngine`(done in background loop) to process. You can
|
||||
consider this engine as an interface.
|
||||
"""
|
||||
|
||||
_engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
|
||||
@ -264,7 +275,7 @@ class AsyncInferenceEngine:
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
) -> RequstStream:
|
||||
"""
|
||||
Add a request to the background tracker(waitting queue), start the background loop if needed.
|
||||
Add a request to the background tracker(waiting queue), start the background loop if needed.
|
||||
"""
|
||||
if not self.background_loop_status:
|
||||
if self.start_engine_loop:
|
||||
@ -287,14 +298,12 @@ class AsyncInferenceEngine:
|
||||
"""
|
||||
Generate output from a request. It receives the request from http server, adds it into the
|
||||
waitting queue of Async Engine and streams the output sequence.
|
||||
|
||||
"""
|
||||
try:
|
||||
stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
|
||||
return await stream.get_result()
|
||||
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
# If there is an exception or coroutine is cancelled, abort the
|
||||
# request.
|
||||
# If there is an exception or coroutine is cancelled, abort the request.
|
||||
self._abort(request_id)
|
||||
raise e
|
||||
|
@ -451,10 +451,6 @@ class InferenceEngine:
|
||||
"""
|
||||
with torch.inference_mode():
|
||||
if prompts is not None or prompts_token_ids is not None:
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
if isinstance(request_ids, int):
|
||||
request_ids = [request_ids]
|
||||
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
|
||||
|
||||
output_seqs_list = []
|
||||
@ -527,6 +523,9 @@ class InferenceEngine:
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
if request_ids is not None and not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
|
||||
if prompts is not None and not isinstance(prompts, list):
|
||||
prompts = [prompts]
|
||||
|
||||
@ -536,6 +535,7 @@ class InferenceEngine:
|
||||
"input_ids"
|
||||
]
|
||||
|
||||
# list of torch Tensor
|
||||
if isinstance(prompts_token_ids, list):
|
||||
if isinstance(prompts_token_ids[0], torch.Tensor):
|
||||
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
|
||||
|
@ -327,7 +327,7 @@ class RequestHandler:
|
||||
def check_unfinished_seqs(self) -> bool:
|
||||
return self._has_waiting() or not self.running_list.is_empty()
|
||||
|
||||
def current_requests_in_batch(self) -> int:
|
||||
def total_requests_in_batch_bucket(self) -> int:
|
||||
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
|
||||
|
||||
def search_tokens(self, generation_config: GenerationConfig, logits):
|
||||
|
@ -9,6 +9,11 @@ Doc:
|
||||
- For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"prompt":"hello, who are you? ","stream":"False"}'`
|
||||
|
||||
Version declaration:
|
||||
- This is the first version of the API server for Colossal-Inference
|
||||
- V0 stands for the under development api, such as models, changes should be made to perfect it.
|
||||
- V1 stands for the currently supported api, such as completion and chat, this is the first version.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@ -127,14 +132,6 @@ def add_engine_config(parser):
|
||||
help="model context length. If unspecified, " "will be automatically derived from the model.",
|
||||
)
|
||||
# Parallel arguments
|
||||
parser.add_argument(
|
||||
"--worker-use-ray",
|
||||
action="store_true",
|
||||
help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU",
|
||||
)
|
||||
|
||||
parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages")
|
||||
|
||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
|
||||
|
||||
# KV cache arguments
|
||||
@ -149,28 +146,6 @@ def add_engine_config(parser):
|
||||
default=None,
|
||||
help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.",
|
||||
)
|
||||
|
||||
# Quantization settings.
|
||||
parser.add_argument(
|
||||
"--quantization",
|
||||
"-q",
|
||||
type=str,
|
||||
choices=["awq", "gptq", "squeezellm", None],
|
||||
default=None,
|
||||
help="Method used to quantize the weights. If "
|
||||
"None, we first check the `quantization_config` "
|
||||
"attribute in the model config file. If that is "
|
||||
"None, we assume the model weights are not "
|
||||
"quantized and use `dtype` to determine the data "
|
||||
"type of the weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enforce-eager",
|
||||
action="store_true",
|
||||
help="Always use eager-mode PyTorch. If False, "
|
||||
"will use eager mode and CUDA graph in hybrid "
|
||||
"for maximal performance and flexibility.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -175,7 +175,7 @@ class VocabParallelEmbedding1D(ParallelModule):
|
||||
he initializer of weight, defaults to normal initializer.
|
||||
|
||||
The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
|
||||
|
||||
::
|
||||
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
|
||||
renormalized to have norm max_norm. Note: this will modify weight in-place.
|
||||
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
|
||||
|
@ -7,7 +7,7 @@ from colossalai.inference.core.async_engine import AsyncInferenceEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class SequenceTpye:
|
||||
class MockSequence:
|
||||
request_id: int
|
||||
|
||||
|
||||
@ -20,7 +20,11 @@ class MockEngine:
|
||||
|
||||
async def async_step(self):
|
||||
self.step_calls += 1
|
||||
return [SequenceTpye(request_id=self.request_id)] if self.request_id else []
|
||||
return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)
|
||||
|
||||
def add_single_request(self, **kwargs):
|
||||
del kwargs
|
||||
self.add_request_calls += 1
|
||||
|
||||
def generate(self, request_id):
|
||||
self.request_id = request_id
|
||||
@ -37,14 +41,14 @@ class MockEngine:
|
||||
self.abort_request_calls += 1
|
||||
|
||||
|
||||
class MockAsyncLLMEngine(AsyncInferenceEngine):
|
||||
class MockAsyncInferenceEngine(AsyncInferenceEngine):
|
||||
def _init_engine(self, *args, **kwargs):
|
||||
return MockEngine()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_requests_event():
|
||||
engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
|
||||
engine = MockAsyncInferenceEngine(worker_use_ray=False, engine_use_ray=False)
|
||||
engine.start_background_loop()
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.step_calls == 0
|
||||
@ -74,7 +78,3 @@ async def test_new_requests_event():
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == 5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_new_requests_event()
|
||||
|
@ -15,27 +15,25 @@ class SampleEvent:
|
||||
self.flag = False
|
||||
|
||||
|
||||
def test_request_tracker():
|
||||
def test_request_tracer():
|
||||
tracker = Tracer()
|
||||
tracker.new_requests_event = SampleEvent()
|
||||
stream_1 = tracker.add_request(1)
|
||||
assert tracker.new_requests_event.flag
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
new = tracker.get_new_requests()
|
||||
assert not tracker.new_requests_event.flag
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == 1
|
||||
assert not finished
|
||||
assert not stream_1.finished
|
||||
|
||||
stream_2 = tracker.add_request(2)
|
||||
stream_3 = tracker.add_request(3)
|
||||
assert tracker.new_requests_event.flag
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
new = tracker.get_new_requests()
|
||||
assert not tracker.new_requests_event.flag
|
||||
assert len(new) == 2
|
||||
assert new[0]["request_id"] == 2
|
||||
assert new[1]["request_id"] == 3
|
||||
assert not finished
|
||||
assert not stream_2.finished
|
||||
assert not stream_3.finished
|
||||
|
||||
@ -45,28 +43,21 @@ def test_request_tracker():
|
||||
assert not tracker.new_requests_event.flag
|
||||
|
||||
tracker.abort_request(1)
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(finished) == 1
|
||||
assert 1 in finished
|
||||
new = tracker.get_new_requests()
|
||||
assert not new
|
||||
assert stream_1.finished
|
||||
|
||||
stream_4 = tracker.add_request(4)
|
||||
tracker.abort_request(4)
|
||||
assert tracker.new_requests_event.flag
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(finished) == 1
|
||||
assert 4 in finished
|
||||
new = tracker.get_new_requests()
|
||||
assert not new
|
||||
assert stream_4.finished
|
||||
|
||||
stream_5 = tracker.add_request(5)
|
||||
assert tracker.new_requests_event.flag
|
||||
tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
new = tracker.get_new_requests()
|
||||
assert not tracker.new_requests_event.flag
|
||||
assert len(finished) == 1
|
||||
assert 2 in finished
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == 5
|
||||
assert stream_2.finished
|
||||
@ -74,4 +65,4 @@ def test_request_tracker():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_request_tracker()
|
||||
test_request_tracer()
|
@ -4,11 +4,13 @@ import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import requests
|
||||
|
||||
MAX_WAITING_TIME = 300
|
||||
|
||||
ray = pytest.importorskip("ray")
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user