Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] Fix bugs and docs for feat/online-server (#5598)
* fix test bugs
* add do sample test
* del useless lines
* fix comments
* fix tests
* delete version tag
* delete version tag
* add
* del test server
* fix test
* fix
* Revert "add"
This reverts commit b9305fb024.
@@ -7,7 +7,7 @@ from colossalai.inference.core.async_engine import AsyncInferenceEngine
 @dataclass
-class SequenceTpye:
+class MockSequence:
     request_id: int
@@ -20,7 +20,11 @@ class MockEngine:
     async def async_step(self):
         self.step_calls += 1
-        return [SequenceTpye(request_id=self.request_id)] if self.request_id else []
+        return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)

     def add_single_request(self, **kwargs):
         del kwargs
         self.add_request_calls += 1

     def generate(self, request_id):
         self.request_id = request_id
@@ -37,14 +41,14 @@ class MockEngine:
         self.abort_request_calls += 1


-class MockAsyncLLMEngine(AsyncInferenceEngine):
+class MockAsyncInferenceEngine(AsyncInferenceEngine):
     def _init_engine(self, *args, **kwargs):
         return MockEngine()


 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
+    engine = MockAsyncInferenceEngine()
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -74,7 +78,3 @@ async def test_new_requests_event():
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == 5
-
-
-if __name__ == "__main__":
-    test_new_requests_event()
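Note on the hunks above: the mock's async_step now returns a (finished_sequences, has_running) tuple instead of a bare list, and the test engine was renamed to MockAsyncInferenceEngine with a no-argument constructor. A minimal, self-contained sketch of that return contract (names mirror the test; this is illustrative, not ColossalAI's production engine):

import asyncio
from dataclasses import dataclass


@dataclass
class MockSequence:
    request_id: int


class MockEngine:
    def __init__(self):
        self.step_calls = 0
        self.request_id = None

    async def async_step(self):
        # One scheduling step: return finished sequences plus a flag telling
        # the background loop whether anything is still running.
        self.step_calls += 1
        if self.request_id is not None:
            return [MockSequence(request_id=self.request_id)], True
        return [], False


async def _demo():
    engine = MockEngine()
    assert await engine.async_step() == ([], False)
    engine.request_id = 1
    seqs, has_running = await engine.async_step()
    assert has_running and seqs[0].request_id == 1


if __name__ == "__main__":
    asyncio.run(_demo())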
@@ -1,6 +1,6 @@
 import pytest

-from colossalai.inference.core.async_engine import RequestTracker
+from colossalai.inference.core.async_engine import Tracer
 from colossalai.inference.struct import Sequence
@@ -15,27 +15,25 @@ class SampleEvent:
         self.flag = False


-def test_request_tracker():
-    tracker = RequestTracker()
+def test_request_tracer():
+    tracker = Tracer()
     tracker.new_requests_event = SampleEvent()
     stream_1 = tracker.add_request(1)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
     assert len(new) == 1
     assert new[0]["request_id"] == 1
-    assert not finished
     assert not stream_1.finished

     stream_2 = tracker.add_request(2)
     stream_3 = tracker.add_request(3)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
     assert len(new) == 2
     assert new[0]["request_id"] == 2
     assert new[1]["request_id"] == 3
-    assert not finished
     assert not stream_2.finished
     assert not stream_3.finished
@@ -45,28 +43,21 @@ def test_request_tracker():
     assert not tracker.new_requests_event.flag

     tracker.abort_request(1)
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert 1 in finished
+    new = tracker.get_new_requests()
+    assert not new
     assert stream_1.finished

     stream_4 = tracker.add_request(4)
     tracker.abort_request(4)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert 4 in finished
+    new = tracker.get_new_requests()
+    assert not new
     assert stream_4.finished

     stream_5 = tracker.add_request(5)
     assert tracker.new_requests_event.flag
     tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
-    assert len(finished) == 1
-    assert 2 in finished
     assert len(new) == 1
     assert new[0]["request_id"] == 5
     assert stream_2.finished
@@ -74,4 +65,4 @@ def test_request_tracker():


 if __name__ == "__main__":
-    test_request_tracker()
+    test_request_tracer()
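The rename above tracks an API change in colossalai.inference.core.async_engine: get_new_and_finished_requests(), which returned a (new, finished) pair, is gone; get_new_requests() hands back only new requests, and abort/completion is observed on the per-request stream. Restating the flow the test asserts, as a usage sketch (requires colossalai; only behaviour exercised by the test above is assumed):

from colossalai.inference.core.async_engine import Tracer

tracer = Tracer()

stream_1 = tracer.add_request(1)        # registering a request returns its stream
new = tracer.get_new_requests()         # only new requests come back now
assert new[0]["request_id"] == 1
assert not stream_1.finished

tracer.abort_request(1)
new = tracer.get_new_requests()         # the aborted request is not handed out again
assert not new
assert stream_1.finished                # abort/finish shows up on the stream itself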
@@ -29,10 +29,24 @@ def generate_inputs(num_sequences, min_length, max_length):
 @parameterize(
-    "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
+    "test_config",
+    [
+        {
+            "max_batch_size": 8,
+            "max_output_len": 512,
+            "max_input_len": 64,
+            "do_sample": False,
+        }
+    ],
 )
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(test_config, use_engine=False, prompt_template=None):
     setup_seed(20)
+    max_batch_size = test_config["max_batch_size"]
+    max_input_len = test_config["max_input_len"]
+    max_output_len = test_config["max_output_len"]
+    do_sample = test_config["do_sample"]
+    top_p = 0.5
+    top_k = 50
     tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
     model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
     model = model.eval()
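The decorator change above switches from a flat argument list to a single test_config dict per case: colossalai.testing's parameterize runs the decorated function once for every element of the list, binding it to the named keyword. A rough illustration of that expansion (hypothetical toy function, not part of the PR):

from colossalai.testing import parameterize


@parameterize("test_config", [{"do_sample": False}, {"do_sample": True}])
def check(test_config, use_engine=False):
    # Called once per dict in the list; other keyword defaults stay available.
    print(test_config["do_sample"], use_engine)


if __name__ == "__main__":
    check()  # expected to print "False False" then "True False"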
@@ -37,7 +37,6 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         )
     ).cuda()
     model = model.eval()
-
     inputs = [
         "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
         "介绍一下武汉,",
@@ -60,7 +59,9 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        generation_config = GenerationConfig(
+            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
+        )
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
         if prompt_template:
@@ -72,6 +73,7 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
             do_sample=do_sample,
+            dtype="fp32",
             top_p=top_p,
             top_k=top_k,
             pad_token_id=tokenizer.pad_token_id,
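Both branches now route sampling settings through a single GenerationConfig, adding dtype="fp32" and, in the engine branch, an explicit max_new_tokens. For reference, a standalone construction in the same style; transformers' GenerationConfig stores unrecognized keyword arguments such as dtype as extra attributes, which is presumably what the test relies on (values here are placeholders, not the test's):

from transformers import GenerationConfig

generation_config = GenerationConfig(
    max_new_tokens=128,   # output_len in the test
    do_sample=False,
    dtype="fp32",         # extra field carried on the config, as in the diff above
    top_p=0.5,
    top_k=50,
)
assert generation_config.dtype == "fp32"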
@@ -1,79 +0,0 @@
-# inspired by vLLM
-import subprocess
-import sys
-import time
-
-import pytest
-import ray
-import requests
-
-MAX_WAITING_TIME = 300
-
-pytestmark = pytest.mark.asyncio
-
-
-@ray.remote(num_gpus=1)
-class ServerRunner:
-    def __init__(self, args):
-        self.proc = subprocess.Popen(
-            ["python3", "-m", "colossalai.inference.server.api_server"] + args,
-            stdout=sys.stdout,
-            stderr=sys.stderr,
-        )
-        self._wait_for_server()
-
-    def ready(self):
-        return True
-
-    def _wait_for_server(self):
-        # run health check
-        start = time.time()
-        while True:
-            try:
-                if requests.get("http://localhost:8000/v0/models").status_code == 200:
-                    break
-            except Exception as err:
-                if self.proc.poll() is not None:
-                    raise RuntimeError("Server exited unexpectedly.") from err
-
-                time.sleep(0.5)
-                if time.time() - start > MAX_WAITING_TIME:
-                    raise RuntimeError("Server failed to start in time.") from err
-
-    def __del__(self):
-        if hasattr(self, "proc"):
-            self.proc.terminate()
-
-
-@pytest.fixture(scope="session")
-def server():
-    ray.init()
-    server_runner = ServerRunner.remote(
-        [
-            "--model",
-            "/home/chenjianghai/data/llama-7b-hf",
-        ]
-    )
-    ray.get(server_runner.ready.remote())
-    yield server_runner
-    ray.shutdown()
-
-
-async def test_completion(server):
-    data = {"prompt": "How are you?", "stream": "False"}
-    response = await server.post("v1/completion", json=data)
-    assert response is not None
-
-
-async def test_chat(server):
-    messages = [
-        {"role": "system", "content": "you are a helpful assistant"},
-        {"role": "user", "content": "what is 1+1?"},
-    ]
-    data = {"messages": messages, "stream": "False"}
-    response = await server.post("v1/chat", data)
-    assert response is not None
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
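The file removed above ("del test server" in the commit list) wrapped colossalai.inference.server.api_server in a ray actor and polled it over HTTP. If an ad-hoc smoke check is ever needed against a manually started server, the same routes can be hit directly; the paths, port, and payloads below are copied from the deleted test and assumed to still match the server:

import requests

BASE_URL = "http://localhost:8000"  # address used by the deleted test

# Health check the old harness polled while waiting for startup.
assert requests.get(f"{BASE_URL}/v0/models").status_code == 200

# Completion route.
completion = requests.post(
    f"{BASE_URL}/v1/completion",
    json={"prompt": "How are you?", "stream": "False"},
)
assert completion.status_code == 200

# Chat route.
chat = requests.post(
    f"{BASE_URL}/v1/chat",
    json={
        "messages": [
            {"role": "system", "content": "you are a helpful assistant"},
            {"role": "user", "content": "what is 1+1?"},
        ],
        "stream": "False",
    },
)
assert chat.status_code == 200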