# Source file: DB-GPT/pilot/model/cluster/apiserver/tests/test_api.py
# Snapshot: 2023-11-02 20:39:14 +08:00
#
# Original size: 249 lines, 7.8 KiB
# Language: Python

import pytest
import pytest_asyncio
from aioresponses import aioresponses
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient, HTTPError
from pilot.component import SystemApp
from pilot.utils.openai_utils import chat_completion_stream, chat_completion
from pilot.model.cluster.apiserver.api import (
api_settings,
initialize_apiserver,
ModelList,
UsageInfo,
ChatCompletionResponse,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
ChatCompletionResponseChoice,
DeltaMessage,
)
from pilot.model.cluster.tests.conftest import _new_cluster
from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory
# Module-level FastAPI application used by the fixtures in this file; the
# `client` fixture passes it to `initialize_apiserver` and to `AsyncClient`.
app = FastAPI()

# Permissive CORS configuration: every origin and header is accepted,
# credentials are allowed, and the common HTTP verbs are enabled.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
@pytest_asyncio.fixture
async def system_app() -> SystemApp:
    """Provide a ``SystemApp`` constructed from the module-level ``app``."""
    wrapped_app = SystemApp(app)
    return wrapped_app
@pytest_asyncio.fixture
async def client(request, system_app: SystemApp):
    """Yield an ``AsyncClient`` bound to ``app`` with a test cluster registered.

    The fixture is configured through indirect parametrization:
    ``request.param``, when present, is a dict of options. Two keys are read
    here:

    * ``api_keys``: handed to ``initialize_apiserver`` and removed before the
      remaining options are forwarded to ``_new_cluster``.
    * ``client_api_key``: when truthy, sent by the client as an
      ``Authorization: Bearer <key>`` header.

    ``num_workers`` defaults to 2 when not supplied. Every remaining option
    (``client_api_key`` included) is forwarded to ``_new_cluster`` as a
    keyword argument.

    Args:
        request: The pytest fixture request; ``request.param`` is optional.
        system_app: The ``SystemApp`` the worker manager and model registry
            are registered on.

    Yields:
        The configured ``AsyncClient``; it stays open together with the
        cluster for the duration of the test.
    """
    # Work on a shallow copy so the dict owned by ``pytest.mark.parametrize``
    # is never mutated by the fixture.
    param = dict(getattr(request, "param", {}))
    api_keys = param.pop("api_keys", [])
    client_api_key = param.get("client_api_key")
    param.setdefault("num_workers", 2)

    headers = {}
    if client_api_key:
        headers["Authorization"] = "Bearer " + client_api_key
    print(f"param: {param}")

    if api_settings:
        # Clear global api keys so keys from a previous test do not leak in.
        api_settings.api_keys = []

    async with AsyncClient(
        app=app, base_url="http://test", headers=headers
    ) as http_client:
        async with _new_cluster(**param) as cluster:
            worker_manager, model_registry = cluster
            system_app.register(_DefaultWorkerManagerFactory, worker_manager)
            system_app.register_instance(model_registry)
            initialize_apiserver(None, app, system_app, api_keys=api_keys)
            yield http_client
@pytest.mark.asyncio
async def test_get_all_models(client: AsyncClient):
    """The model listing endpoint answers 200 with a list of two models.

    The ``client`` fixture is used without parameters here, so it runs with
    its default of two workers; two entries are expected in the listing.
    """
    res = await client.get("/api/v1/models")
    # Assert the status explicitly: a bare comparison expression would be
    # evaluated and discarded without ever failing the test.
    assert res.status_code == 200
    payload = res.json()
    print(f"model list json: {payload}")
    model_lists = ModelList.parse_obj(payload)
    assert model_lists.object == "list"
    assert len(model_lists.data) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages",
    [
        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
    ],
    indirect=["client"],
)
async def test_chat_completions(client: AsyncClient, expected_messages):
    """Streaming and non-streaming helpers both return the expected text.

    The same request payload is sent first through ``chat_completion_stream``
    (whose pieces are concatenated) and then through ``chat_completion``.
    """
    endpoint = "/api/v1/chat/completions"
    payload = {
        "model": "test-model-name-0",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }

    # Gather every streamed piece, then compare the concatenation.
    pieces = [
        piece async for piece in chat_completion_stream(endpoint, payload, client)
    ]
    assert "".join(pieces) == expected_messages

    reply = await chat_completion(endpoint, payload, client)
    assert reply == expected_messages
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages, client_api_key",
    [
        (
            {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
            "Hello world.",
            "abc",
        ),
        (
            {"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
            "你好,我是张三。",
            "abc",
        ),
    ],
    indirect=["client"],
)
async def test_chat_completions_with_openai_lib_async_no_stream(
    client: AsyncClient, expected_messages: str, client_api_key: str
):
    """A non-streaming ``openai.ChatCompletion.acreate`` call parses our schema.

    The POST to ``http://test/api/v1/chat/completions`` is answered by the
    ``aioresponses`` mock registered below, whose body is a serialized
    ``ChatCompletionResponse``. The test therefore checks that the response
    model is readable by the openai client library.
    NOTE(review): presumably the request never reaches ``client``/``app``
    because of the mock -- confirm that is intended.
    """
    import openai

    # The openai settings are module-level globals; remember them so they can
    # be restored and do not leak into tests that run afterwards.
    previous_api_key = getattr(openai, "api_key", None)
    previous_api_base = getattr(openai, "api_base", None)
    openai.api_key = client_api_key
    openai.api_base = "http://test/api/v1"
    model_name = "test-model-name-0"
    try:
        with aioresponses() as mocked:
            one_res = ChatCompletionResponseChoice(
                index=0,
                message=ChatMessage(role="assistant", content=expected_messages),
                finish_reason="stop",
            )
            data = ChatCompletionResponse(
                model=model_name, choices=[one_res], usage=UsageInfo()
            )
            mock_message = f"{data.json(exclude_unset=True, ensure_ascii=False)}\n\n"
            # Mock http request
            mocked.post(
                "http://test/api/v1/chat/completions", status=200, body=mock_message
            )
            completion = await openai.ChatCompletion.acreate(
                model=model_name,
                messages=[{"role": "user", "content": "Hello! What is your name?"}],
            )
            assert completion.choices[0].message.content == expected_messages
    finally:
        openai.api_key = previous_api_key
        openai.api_base = previous_api_base
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages, client_api_key",
    [
        (
            {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
            "Hello world.",
            "abc",
        ),
        (
            {"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
            "你好,我是张三。",
            "abc",
        ),
    ],
    indirect=["client"],
)
async def test_chat_completions_with_openai_lib_async_stream(
    client: AsyncClient, expected_messages: str, client_api_key: str
):
    """A streaming ``openai.ChatCompletion.acreate`` call parses our chunks.

    The POST to ``http://test/api/v1/chat/completions`` is answered by the
    ``aioresponses`` mock registered below with a single ``data: ...``
    server-sent event holding a serialized ``ChatCompletionStreamResponse``.
    The content of every received delta is concatenated and compared with the
    expected text.
    NOTE(review): presumably the request never reaches ``client``/``app``
    because of the mock -- confirm that is intended.
    """
    import openai

    # The openai settings are module-level globals; remember them so they can
    # be restored and do not leak into tests that run afterwards.
    previous_api_key = getattr(openai, "api_key", None)
    previous_api_base = getattr(openai, "api_base", None)
    openai.api_key = client_api_key
    openai.api_base = "http://test/api/v1"
    model_name = "test-model-name-0"
    try:
        with aioresponses() as mocked:
            choice_data = ChatCompletionResponseStreamChoice(
                index=0,
                delta=DeltaMessage(content=expected_messages),
                finish_reason="stop",
            )
            chunk = ChatCompletionStreamResponse(
                id=0, choices=[choice_data], model=model_name
            )
            mock_message = (
                f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
            )
            mocked.post(
                "http://test/api/v1/chat/completions",
                status=200,
                body=mock_message,
                content_type="text/event-stream",
            )
            stream_stream_resp = ""
            async for stream_resp in await openai.ChatCompletion.acreate(
                model=model_name,
                messages=[{"role": "user", "content": "Hello! What is your name?"}],
                stream=True,
            ):
                # Accumulate instead of overwriting so a reply split across
                # several chunks is compared as a whole; a missing or None
                # content contributes nothing.
                stream_stream_resp += (
                    stream_resp.choices[0]["delta"].get("content") or ""
                )
            assert stream_stream_resp == expected_messages
    finally:
        openai.api_key = previous_api_key
        openai.api_base = previous_api_base
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages, api_key_is_error",
    [
        (
            {
                "stream_messags": ["Hello", " world."],
                "api_keys": ["abc", "xx"],
                "client_api_key": "abc",
            },
            "Hello world.",
            False,
        ),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。", False),
        (
            {"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc", "xx"]},
            "你好,我是张三。",
            True,
        ),
        (
            {
                "stream_messags": ["你好,我是", "张三。"],
                "api_keys": ["abc", "xx"],
                "client_api_key": "error_api_key",
            },
            "你好,我是张三。",
            True,
        ),
    ],
    indirect=["client"],
)
async def test_chat_completions_with_api_keys(
    client: AsyncClient, expected_messages: str, api_key_is_error: bool
):
    """Check chat completion against the server's API-key configuration.

    Parametrized cases, in order: a client key that is in the server's key
    list (success), no keys configured at all (success), keys configured but
    no client key sent (error), and a client key not in the list (error).
    Error cases expect ``chat_completion`` to raise ``HTTPError``.
    """
    endpoint = "/api/v1/chat/completions"
    request_body = {
        "model": "test-model-name-0",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }

    # Success path first: the reply must match the expected text.
    if not api_key_is_error:
        answer = await chat_completion(endpoint, request_body, client)
        assert answer == expected_messages
        return

    # Rejected key (or missing key): the helper is expected to raise.
    with pytest.raises(HTTPError):
        await chat_completion(endpoint, request_body, client)