mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[Online Server] Chat Api for streaming and not streaming response (#5470)
* fix bugs * fix bugs * fix api server * fix api server * add chat api and test * del request.n
This commit is contained in:
79
tests/test_infer/test_server.py
Normal file
79
tests/test_infer/test_server.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# inspired by vLLM
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import requests
|
||||
|
||||
MAX_WAITING_TIME = 300
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class ServerRunner:
|
||||
def __init__(self, args):
|
||||
self.proc = subprocess.Popen(
|
||||
["python3", "-m", "colossalai.inference.server.api_server"] + args,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
self._wait_for_server()
|
||||
|
||||
def ready(self):
|
||||
return True
|
||||
|
||||
def _wait_for_server(self):
|
||||
# run health check
|
||||
start = time.time()
|
||||
while True:
|
||||
try:
|
||||
if requests.get("http://localhost:8000/v0/models").status_code == 200:
|
||||
break
|
||||
except Exception as err:
|
||||
if self.proc.poll() is not None:
|
||||
raise RuntimeError("Server exited unexpectedly.") from err
|
||||
|
||||
time.sleep(0.5)
|
||||
if time.time() - start > MAX_WAITING_TIME:
|
||||
raise RuntimeError("Server failed to start in time.") from err
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "proc"):
|
||||
self.proc.terminate()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server():
|
||||
ray.init()
|
||||
server_runner = ServerRunner.remote(
|
||||
[
|
||||
"--model",
|
||||
"/home/chenjianghai/data/llama-7b-hf",
|
||||
]
|
||||
)
|
||||
ray.get(server_runner.ready.remote())
|
||||
yield server_runner
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
async def test_completion(server):
|
||||
data = {"prompt": "How are you?", "stream": "False"}
|
||||
response = await server.post("v1/completion", json=data)
|
||||
assert response is not None
|
||||
|
||||
|
||||
async def test_chat(server):
|
||||
messages = [
|
||||
{"role": "system", "content": "you are a helpful assistant"},
|
||||
{"role": "user", "content": "what is 1+1?"},
|
||||
]
|
||||
data = {"messages": messages, "stream": "False"}
|
||||
response = await server.post("v1/chat", data)
|
||||
assert response is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
Reference in New Issue
Block a user