[Inference] Fix bugs and docs for feat/online-server (#5598)

* fix test bugs

* add do sample test

* del useless lines

* fix comments

* fix tests

* delete version tag

* delete version tag

* add

* del test server

* fix test

* fix

* Revert "add"

This reverts commit b9305fb024.
Jianghai authored on 2024-05-08 15:14:06 +08:00, committed by CjhHa1
parent 7bbb28e48b, commit 61a1b2e798
12 changed files with 98 additions and 172 deletions

View File

@@ -7,7 +7,7 @@ from colossalai.inference.core.async_engine import AsyncInferenceEngine
 @dataclass
-class SequenceTpye:
+class MockSequence:
     request_id: int
@@ -20,7 +20,11 @@ class MockEngine:
     async def async_step(self):
         self.step_calls += 1
-        return [SequenceTpye(request_id=self.request_id)] if self.request_id else []
+        return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)
+
+    def add_single_request(self, **kwargs):
+        del kwargs
+        self.add_request_calls += 1
     def generate(self, request_id):
         self.request_id = request_id
@@ -37,14 +41,14 @@ class MockEngine:
         self.abort_request_calls += 1
-class MockAsyncLLMEngine(AsyncInferenceEngine):
+class MockAsyncInferenceEngine(AsyncInferenceEngine):
     def _init_engine(self, *args, **kwargs):
         return MockEngine()
 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
+    engine = MockAsyncInferenceEngine()
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -74,7 +78,3 @@ async def test_new_requests_event():
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == 5
-if __name__ == "__main__":
-    test_new_requests_event()

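For reference, a self-contained sketch of the contract the updated mock models: async_step now returns a pair of (finished outputs, has-running flag) instead of a bare list, so a driver loop can stop when nothing is left to run. The names FakeEngine, FakeSequence, and drive are illustrative only, not ColossalAI APIs.

import asyncio
from dataclasses import dataclass


@dataclass
class FakeSequence:
    request_id: int


class FakeEngine:
    def __init__(self, total_steps: int):
        self.remaining = total_steps
        self.step_calls = 0

    async def async_step(self):
        # Mirror the new contract: (finished outputs, whether work remains).
        self.step_calls += 1
        self.remaining -= 1
        return [FakeSequence(request_id=self.step_calls)], self.remaining > 0


async def drive(engine):
    # Keep stepping until the engine reports nothing left to run.
    while True:
        outputs, has_running = await engine.async_step()
        if not has_running:
            return outputs


if __name__ == "__main__":
    engine = FakeEngine(total_steps=3)
    asyncio.run(drive(engine))
    assert engine.step_calls == 3
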
View File

@@ -1,6 +1,6 @@
 import pytest
-from colossalai.inference.core.async_engine import RequestTracker
+from colossalai.inference.core.async_engine import Tracer
 from colossalai.inference.struct import Sequence
@@ -15,27 +15,25 @@ class SampleEvent:
         self.flag = False
-def test_request_tracker():
-    tracker = RequestTracker()
+def test_request_tracer():
+    tracker = Tracer()
     tracker.new_requests_event = SampleEvent()
     stream_1 = tracker.add_request(1)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
     assert len(new) == 1
     assert new[0]["request_id"] == 1
-    assert not finished
     assert not stream_1.finished
     stream_2 = tracker.add_request(2)
     stream_3 = tracker.add_request(3)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
     assert len(new) == 2
     assert new[0]["request_id"] == 2
     assert new[1]["request_id"] == 3
-    assert not finished
     assert not stream_2.finished
     assert not stream_3.finished
@@ -45,28 +43,21 @@ def test_request_tracker():
     assert not tracker.new_requests_event.flag
     tracker.abort_request(1)
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert 1 in finished
+    new = tracker.get_new_requests()
     assert not new
+    assert stream_1.finished
     stream_4 = tracker.add_request(4)
     tracker.abort_request(4)
    assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert 4 in finished
+    new = tracker.get_new_requests()
     assert not new
+    assert stream_4.finished
     stream_5 = tracker.add_request(5)
     assert tracker.new_requests_event.flag
     tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
-    assert len(finished) == 1
-    assert 2 in finished
     assert len(new) == 1
     assert new[0]["request_id"] == 5
+    assert stream_2.finished
@@ -74,4 +65,4 @@ def test_request_tracker():
 if __name__ == "__main__":
-    test_request_tracker()
+    test_request_tracer()

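The renamed Tracer is exercised purely through its public surface here. A minimal sketch of a tracer that would satisfy these assertions, inferred from the test rather than copied from the actual colossalai.inference implementation, and assuming the injected event exposes set()/clear():

class _Stream:
    def __init__(self, request_id):
        self.request_id = request_id
        self.finished = False


class SimpleTracer:
    def __init__(self):
        self.new_requests_event = None  # the test injects a SampleEvent here
        self._new = []                  # requests not yet handed to the engine
        self._streams = {}              # request_id -> per-request stream

    def add_request(self, request_id, **engine_kwargs):
        stream = _Stream(request_id)
        self._streams[request_id] = stream
        self._new.append({"request_id": request_id, **engine_kwargs})
        self.new_requests_event.set()
        return stream

    def get_new_requests(self):
        # Hand pending requests to the engine and clear the wake-up event.
        self.new_requests_event.clear()
        new, self._new = self._new, []
        return new

    def abort_request(self, request_id):
        # Drop the request if it was never scheduled and mark its stream done.
        self._new = [r for r in self._new if r["request_id"] != request_id]
        self._streams[request_id].finished = True

    def process_finished_request(self, sequence):
        # `sequence.request_id` names the stream that just completed.
        self._streams[sequence.request_id].finished = True

The key change versus the old RequestTracker interface is that completion is now reported through per-request stream objects (stream.finished) instead of a second return value from get_new_and_finished_requests().
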
View File

@@ -29,10 +29,24 @@ def generate_inputs(num_sequences, min_length, max_length):
 @parameterize(
-    "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
+    "test_config",
+    [
+        {
+            "max_batch_size": 8,
+            "max_output_len": 512,
+            "max_input_len": 64,
+            "do_sample": False,
+        }
+    ],
 )
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(test_config, use_engine=False, prompt_template=None):
     setup_seed(20)
+    max_batch_size = test_config["max_batch_size"]
+    max_input_len = test_config["max_input_len"]
+    max_output_len = test_config["max_output_len"]
+    do_sample = test_config["do_sample"]
+    top_p = 0.5
+    top_k = 50
     tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
     model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
     model = model.eval()

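The rewritten decorator now passes a whole config dict per test case. A hedged sketch of how a parameterize-style decorator can expand such a list into repeated calls; colossalai.testing.parameterize may differ in details, and parameterize_sketch/check_config below are illustrative names:

import functools


def parameterize_sketch(arg_name, values):
    # Call the wrapped function once per value, injected under `arg_name`.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:
                func(*args, **{arg_name: value, **kwargs})
        return wrapper
    return decorator


@parameterize_sketch(
    "test_config",
    [{"max_batch_size": 8, "max_output_len": 512, "max_input_len": 64, "do_sample": False}],
)
def check_config(test_config, use_engine=False):
    # Each listed dict arrives as `test_config`, one call per entry.
    assert test_config["max_batch_size"] == 8
    assert test_config["do_sample"] is False


if __name__ == "__main__":
    check_config()

Grouping the knobs into one dict keeps related settings (batch size, lengths, sampling mode) varying together instead of as a cross-product of separate arguments.
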
View File

@@ -37,7 +37,6 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         )
     ).cuda()
     model = model.eval()
-
     inputs = [
         "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
         "介绍一下武汉,",
@@ -60,7 +59,9 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        generation_config = GenerationConfig(
+            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
+        )
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
         if prompt_template:
@@ -72,6 +73,7 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
             do_sample=do_sample,
+            dtype="fp32",
             top_p=top_p,
             top_k=top_k,
             pad_token_id=tokenizer.pad_token_id,

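Both branches now pin the same decoding settings, so the engine output and the plain HuggingFace generate() output stay comparable. A hedged sketch of building such a config from standard transformers GenerationConfig fields; the dtype kwarg seen in the diff is not a standard GenerationConfig field and is left out here:

from transformers import GenerationConfig


def build_generation_config(output_len, do_sample=False, top_p=0.5, top_k=50, pad_token_id=None):
    # With do_sample=False both paths reduce to greedy decoding; the sampling
    # knobs are carried along but unused, so the same helper also covers the
    # do_sample=True variant.
    return GenerationConfig(
        max_new_tokens=output_len,
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=pad_token_id,
    )


if __name__ == "__main__":
    cfg = build_generation_config(output_len=128)
    assert cfg.max_new_tokens == 128 and cfg.do_sample is False
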
View File

@@ -1,79 +0,0 @@
-# inspired by vLLM
-import subprocess
-import sys
-import time
-
-import pytest
-import ray
-import requests
-
-MAX_WAITING_TIME = 300
-
-pytestmark = pytest.mark.asyncio
-
-
-@ray.remote(num_gpus=1)
-class ServerRunner:
-    def __init__(self, args):
-        self.proc = subprocess.Popen(
-            ["python3", "-m", "colossalai.inference.server.api_server"] + args,
-            stdout=sys.stdout,
-            stderr=sys.stderr,
-        )
-        self._wait_for_server()
-
-    def ready(self):
-        return True
-
-    def _wait_for_server(self):
-        # run health check
-        start = time.time()
-        while True:
-            try:
-                if requests.get("http://localhost:8000/v0/models").status_code == 200:
-                    break
-            except Exception as err:
-                if self.proc.poll() is not None:
-                    raise RuntimeError("Server exited unexpectedly.") from err
-                time.sleep(0.5)
-                if time.time() - start > MAX_WAITING_TIME:
-                    raise RuntimeError("Server failed to start in time.") from err
-
-    def __del__(self):
-        if hasattr(self, "proc"):
-            self.proc.terminate()
-
-
-@pytest.fixture(scope="session")
-def server():
-    ray.init()
-    server_runner = ServerRunner.remote(
-        [
-            "--model",
-            "/home/chenjianghai/data/llama-7b-hf",
-        ]
-    )
-    ray.get(server_runner.ready.remote())
-    yield server_runner
-    ray.shutdown()
-
-
-async def test_completion(server):
-    data = {"prompt": "How are you?", "stream": "False"}
-    response = await server.post("v1/completion", json=data)
-    assert response is not None
-
-
-async def test_chat(server):
-    messages = [
-        {"role": "system", "content": "you are a helpful assistant"},
-        {"role": "user", "content": "what is 1+1?"},
-    ]
-    data = {"messages": messages, "stream": "False"}
-    response = await server.post("v1/chat", data)
-    assert response is not None
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
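
The deleted file wrapped the server in a ray actor mainly to reserve a GPU and poll a health endpoint before running the request tests. For reference, the same launch-and-poll pattern without ray, as a standalone sketch; the module path and the /v0/models URL are copied from the deleted test, and whether the current api_server still exposes them is not verified here:

import subprocess
import sys
import time

import requests

MAX_WAITING_TIME = 300


def start_server_and_wait(args, health_url="http://localhost:8000/v0/models"):
    # Launch the API server as a child process and block until it answers.
    proc = subprocess.Popen(
        ["python3", "-m", "colossalai.inference.server.api_server"] + args,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )
    start = time.time()
    while True:
        try:
            if requests.get(health_url).status_code == 200:
                return proc
        except requests.RequestException:
            pass  # server not reachable yet, keep polling
        if proc.poll() is not None:
            raise RuntimeError("Server exited unexpectedly.")
        if time.time() - start > MAX_WAITING_TIME:
            proc.terminate()
            raise RuntimeError("Server failed to start in time.")
        time.sleep(0.5)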