Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-08 12:54:19 +00:00)

commit a6377bb331
parent 402e9918df

    fix tests
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from functools import partial
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type

 from colossalai.inference.core.engine import InferenceEngine

@@ -103,8 +103,8 @@ class Tracer:
             raise KeyError(f"Request {request_id} already exists.")

         stream = RequstStream(request_id)
+        logger.info(f"Added request {request_id}.")
         self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
-
         self.new_requests_event.set()

         return stream
@@ -118,6 +118,7 @@ class Tracer:

         if request_id not in self._request_streams or self._request_streams[request_id].finished:
             # The request has already finished or been aborted.
+            # The requests in new_requests will be aborted when try to get them(if marked aborted)
             return

         self._request_streams[request_id].set_result(None)
@@ -127,9 +128,18 @@ class Tracer:
         Get new requests from http server.
         """
         new_requests: List[Dict] = []
+        finished_requests: Set[int] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)
+
         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
+            if new_request["request_id"] in finished_requests:
+                # The request has been aborted.
+                stream.set_result(None)
+                continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)

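The draining pattern introduced above can be exercised in isolation. The sketch below is not the actual Tracer class; it is a minimal stand-in (the name `MiniTracer` is illustrative) that reproduces the same idea: aborted request ids are collected from the finished queue first, and any queued new request whose id is in that set is resolved immediately instead of being handed to the engine.

```python
import asyncio
from typing import Dict, List, Set


class MiniTracer:
    """Illustrative stand-in for the Tracer's queue-draining logic."""

    def __init__(self):
        self._new_requests: asyncio.Queue = asyncio.Queue()
        self._finished_requests: asyncio.Queue = asyncio.Queue()

    def add_request(self, request_id: int) -> None:
        self._new_requests.put_nowait({"request_id": request_id})

    def abort_request(self, request_id: int) -> None:
        self._finished_requests.put_nowait(request_id)

    def get_new_requests(self) -> List[Dict]:
        new_requests: List[Dict] = []
        finished_requests: Set[int] = set()

        # Drain aborted ids first, then skip any queued request that matches one.
        while not self._finished_requests.empty():
            finished_requests.add(self._finished_requests.get_nowait())

        while not self._new_requests.empty():
            new_request = self._new_requests.get_nowait()
            if new_request["request_id"] in finished_requests:
                continue  # aborted before the engine ever saw it
            new_requests.append(new_request)
        return new_requests


tracer = MiniTracer()
tracer.add_request(1)
tracer.add_request(2)
tracer.abort_request(1)
assert [r["request_id"] for r in tracer.get_new_requests()] == [2]
```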
@@ -172,22 +182,23 @@ class _AsyncInferenceEngine(InferenceEngine):
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
-        # Return: List[Sequence]
         finished_sequences = self.request_handler.update()
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)

-        return finished_sequences, self.request_handler.current_requests_in_batch() > 0
+        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0


 class AsyncInferenceEngine:
-    """An asynchronous wrapper for LLMEngine.
+    """An asynchronous wrapper for the InferenceEngine class.

     This class is used to wrap the InferenceEngine class to make it asynchronous.
     It uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMEngine is kicked by the generate method when there are
-    requests in the waiting queue. The generate method yields the outputs
-    from the InferenceEngine to the caller.
+    requests. Note that this class does not hold model directly, when incoming a new
+    request, it first called `add_request` and the Tracer will record the request, putting
+    it to the background `InferenceEngine`(done in background loop) to process. You can
+    consider this engine as an interface.
     """

     _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
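The stepping coroutine now returns a pair: the finished sequences plus a flag derived from `total_requests_in_batch_bucket()`, which is what the background loop uses to decide whether to keep stepping. A rough, hypothetical sketch of how a driver might consume that pair, assuming the stepping coroutine is named `async_step` as in the mock engine used by the tests further down this diff:

```python
import asyncio


async def background_loop(engine) -> None:
    """Hypothetical driver: keep stepping while the engine reports live requests."""
    has_running = True
    while has_running:
        # async_step() returns (finished_sequences, has_running); the flag stays
        # True while the prefill/running batch buckets still hold requests.
        finished, has_running = await engine.async_step()
        for sequence in finished:
            print(sequence.output)  # the real loop hands these back to the Tracer streams
        await asyncio.sleep(0)  # cooperatively yield to other coroutines
```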
@@ -264,7 +275,7 @@ class AsyncInferenceEngine:
         prompt_token_ids: Optional[List[int]] = None,
     ) -> RequstStream:
         """
-        Add a request to the background tracker(waitting queue), start the background loop if needed.
+        Add a request to the background tracker(waiting queue), start the background loop if needed.
         """
         if not self.background_loop_status:
             if self.start_engine_loop:
@@ -287,14 +298,12 @@ class AsyncInferenceEngine:
         """
         Generate output from a request. It receives the request from http server, adds it into the
         waitting queue of Async Engine and streams the output sequence.
-
         """
         try:
             stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
             return await stream.get_result()

         except (Exception, asyncio.CancelledError) as e:
-            # If there is an exception or coroutine is cancelled, abort the
-            # request.
+            # If there is an exception or coroutine is cancelled, abort the request.
             self._abort(request_id)
             raise e
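A hedged usage sketch of this path, assuming an already-constructed `AsyncInferenceEngine` instance named `engine` and the `(request_id, prompt, prompt_token_ids)` parameters used in the body above:

```python
import asyncio


async def run_one_request(engine, request_id: int, prompt: str):
    # generate() adds the request to the Tracer's waiting queue, starts the
    # background loop if needed, and resolves once the sequence has finished.
    try:
        return await engine.generate(request_id, prompt, prompt_token_ids=None)
    except asyncio.CancelledError:
        # The engine already calls _abort() on cancellation before re-raising,
        # so the caller only decides how to surface the error.
        raise


# Example (requires a real engine instance):
# output = asyncio.run(run_one_request(engine, request_id=0, prompt="hello, who are you? "))
```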
@@ -451,10 +451,6 @@ class InferenceEngine:
         """
         with torch.inference_mode():
             if prompts is not None or prompts_token_ids is not None:
-                if isinstance(prompts, str):
-                    prompts = [prompts]
-                if isinstance(request_ids, int):
-                    request_ids = [request_ids]
                 self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)

             output_seqs_list = []
@@ -527,6 +523,9 @@ class InferenceEngine:

         block_size = self.inference_config.block_size

+        if request_ids is not None and not isinstance(request_ids, list):
+            request_ids = [request_ids]
+
         if prompts is not None and not isinstance(prompts, list):
             prompts = [prompts]

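Taken together, these two hunks move the scalar-to-list normalization out of `generate` and into `add_request`, so both entry points accept either a single value or a list. A minimal sketch of the pattern, with a hypothetical helper name:

```python
from typing import Optional, Union


def _as_list(value: Optional[Union[int, str, list]]) -> Optional[list]:
    """Wrap a bare request id or prompt string in a list; leave lists and None untouched."""
    if value is not None and not isinstance(value, list):
        return [value]
    return value


assert _as_list(7) == [7]
assert _as_list("hello") == ["hello"]
assert _as_list(["hello"]) == ["hello"]
assert _as_list(None) is None
```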
@@ -536,6 +535,7 @@ class InferenceEngine:
                 "input_ids"
             ]

+        # list of torch Tensor
         if isinstance(prompts_token_ids, list):
            if isinstance(prompts_token_ids[0], torch.Tensor):
                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
@@ -327,7 +327,7 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()

-    def current_requests_in_batch(self) -> int:
+    def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

     def search_tokens(self, generation_config: GenerationConfig, logits):
|
@ -9,6 +9,11 @@ Doc:
|
|||||||
- For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \
|
- For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \
|
||||||
-H 'Content-Type: application/json' \
|
-H 'Content-Type: application/json' \
|
||||||
-d '{"prompt":"hello, who are you? ","stream":"False"}'`
|
-d '{"prompt":"hello, who are you? ","stream":"False"}'`
|
||||||
|
|
||||||
|
Version declaration:
|
||||||
|
- This is the first version of the API server for Colossal-Inference
|
||||||
|
- V0 stands for the under development api, such as models, changes should be made to perfect it.
|
||||||
|
- V1 stands for the currently supported api, such as completion and chat, this is the first version.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
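Equivalently, the completion route documented above can be called from Python with `requests`, using the same payload as the curl example (host and port assume the default local server):

```python
import requests

# Mirrors the curl invocation from the docstring above.
response = requests.post(
    "http://127.0.0.1:8000/v1/completion",
    json={"prompt": "hello, who are you? ", "stream": "False"},
    timeout=300,
)
print(response.status_code, response.text)
```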
@@ -127,14 +132,6 @@ def add_engine_config(parser):
         help="model context length. If unspecified, " "will be automatically derived from the model.",
     )
     # Parallel arguments
-    parser.add_argument(
-        "--worker-use-ray",
-        action="store_true",
-        help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU",
-    )
-
-    parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages")
-
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")

     # KV cache arguments
@@ -149,28 +146,6 @@ def add_engine_config(parser):
         default=None,
         help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.",
     )

-    # Quantization settings.
-    parser.add_argument(
-        "--quantization",
-        "-q",
-        type=str,
-        choices=["awq", "gptq", "squeezellm", None],
-        default=None,
-        help="Method used to quantize the weights. If "
-        "None, we first check the `quantization_config` "
-        "attribute in the model config file. If that is "
-        "None, we assume the model weights are not "
-        "quantized and use `dtype` to determine the data "
-        "type of the weights.",
-    )
-    parser.add_argument(
-        "--enforce-eager",
-        action="store_true",
-        help="Always use eager-mode PyTorch. If False, "
-        "will use eager mode and CUDA graph in hybrid "
-        "for maximal performance and flexibility.",
-    )
     return parser


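After these two hunks, `add_engine_config` no longer registers the Ray, pipeline-parallel, quantization, or eager-mode flags. A hedged sketch of the reduced parser, limited to the flags visible in this diff (help strings abbreviated, other arguments omitted):

```python
import argparse


def add_engine_config(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Parallel arguments: only tensor parallelism remains configurable here.
    parser.add_argument(
        "--tensor-parallel-size", "-tp", type=int, default=1,
        help="number of tensor parallel replicas",
    )
    return parser


parser = add_engine_config(argparse.ArgumentParser(description="Colossal-Inference API server"))
args = parser.parse_args(["-tp", "2"])
assert args.tensor_parallel_size == 2
```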
@@ -175,7 +175,7 @@ class VocabParallelEmbedding1D(ParallelModule):
            he initializer of weight, defaults to normal initializer.

         The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
-
+        ::
         max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
             renormalized to have norm max_norm. Note: this will modify weight in-place.
         norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
@@ -7,7 +7,7 @@ from colossalai.inference.core.async_engine import AsyncInferenceEngine


 @dataclass
-class SequenceTpye:
+class MockSequence:
     request_id: int


@@ -20,7 +20,11 @@ class MockEngine:

     async def async_step(self):
         self.step_calls += 1
-        return [SequenceTpye(request_id=self.request_id)] if self.request_id else []
+        return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)

+    def add_single_request(self, **kwargs):
+        del kwargs
+        self.add_request_calls += 1
+
     def generate(self, request_id):
         self.request_id = request_id
@@ -37,14 +41,14 @@ class MockEngine:
         self.abort_request_calls += 1


-class MockAsyncLLMEngine(AsyncInferenceEngine):
+class MockAsyncInferenceEngine(AsyncInferenceEngine):
     def _init_engine(self, *args, **kwargs):
         return MockEngine()


 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
+    engine = MockAsyncInferenceEngine(worker_use_ray=False, engine_use_ray=False)
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -74,7 +78,3 @@ async def test_new_requests_event():
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == 5
-
-
-if __name__ == "__main__":
-    test_new_requests_event()
@@ -15,27 +15,25 @@ class SampleEvent:
         self.flag = False


-def test_request_tracker():
+def test_request_tracer():
     tracker = Tracer()
     tracker.new_requests_event = SampleEvent()
     stream_1 = tracker.add_request(1)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
     assert len(new) == 1
     assert new[0]["request_id"] == 1
-    assert not finished
     assert not stream_1.finished

     stream_2 = tracker.add_request(2)
     stream_3 = tracker.add_request(3)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
     assert len(new) == 2
     assert new[0]["request_id"] == 2
     assert new[1]["request_id"] == 3
-    assert not finished
     assert not stream_2.finished
     assert not stream_3.finished

@@ -45,28 +43,21 @@ def test_request_tracker():
     assert not tracker.new_requests_event.flag

     tracker.abort_request(1)
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert 1 in finished
+    new = tracker.get_new_requests()
     assert not new
-    assert stream_1.finished

     stream_4 = tracker.add_request(4)
     tracker.abort_request(4)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert 4 in finished
+    new = tracker.get_new_requests()
     assert not new
     assert stream_4.finished

     stream_5 = tracker.add_request(5)
     assert tracker.new_requests_event.flag
     tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
-    assert len(finished) == 1
-    assert 2 in finished
     assert len(new) == 1
     assert new[0]["request_id"] == 5
     assert stream_2.finished
@@ -74,4 +65,4 @@ def test_request_tracker():


 if __name__ == "__main__":
-    test_request_tracker()
+    test_request_tracer()
@@ -4,11 +4,13 @@ import sys
 import time

 import pytest
-import ray
 import requests

 MAX_WAITING_TIME = 300

+ray = pytest.importorskip("ray")
+
+
 pytestmark = pytest.mark.asyncio


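Replacing the hard `import ray` with `pytest.importorskip("ray")` makes the whole test module skip cleanly when Ray is not installed, instead of failing at collection time with an ImportError. The same pattern in isolation:

```python
import pytest

# Collection-time guard: if the optional dependency is missing, every test in
# this module is reported as skipped rather than raising ImportError.
ray = pytest.importorskip("ray")


def test_ray_is_available():
    assert ray is not None
```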