mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-16 22:51:24 +00:00)
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
0    dbgpt/model/cluster/apiserver/__init__.py    Normal file
438  dbgpt/model/cluster/apiserver/api.py         Normal file
@@ -0,0 +1,438 @@
"""A server that provides OpenAI-compatible RESTful APIs. It supports:

- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)

Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
"""
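# Illustrative sketch, not part of this commit: once the server is running, the
# official ``openai`` client (0.x API, the same one exercised in the tests below)
# can be pointed at it. The address, key, and model name here are assumptions.
#
#   import openai
#
#   openai.api_key = "EMPTY"  # or one of the configured api_keys
#   openai.api_base = "http://127.0.0.1:8100/api/v1"
#   completion = openai.ChatCompletion.create(
#       model="vicuna-13b-v1.5",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(completion.choices[0].message.content)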
from typing import Optional, List, Dict, Any, AsyncGenerator

import logging
import asyncio
import shortuuid
import json
from fastapi import APIRouter, FastAPI
from fastapi import Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer

from pydantic import BaseSettings

from fastchat.protocol.openai_api_protocol import (
    ChatCompletionResponse,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    DeltaMessage,
    ModelCard,
    ModelList,
    ModelPermission,
    UsageInfo,
)
from fastchat.protocol.api_protocol import (
    APIChatCompletionRequest,
)
from fastchat.serve.openai_api_server import create_error_response, check_requests
from fastchat.constants import ErrorCode

from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.util.parameter_utils import EnvArgumentParser
from dbgpt.core import ModelOutput
from dbgpt.core.interface.message import ModelMessage
from dbgpt.model.base import ModelInstance
from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType
from dbgpt.model.cluster import ModelRegistry
from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from dbgpt.util.utils import setup_logging

logger = logging.getLogger(__name__)

# Set by initialize_apiserver(); read by get_api_server().
global_system_app: Optional[SystemApp] = None

class APIServerException(Exception):
    def __init__(self, code: int, message: str):
        self.code = code
        self.message = message


class APISettings(BaseSettings):
    api_keys: Optional[List[str]] = None


api_settings = APISettings()
get_bearer_token = HTTPBearer(auto_error=False)


async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
) -> Optional[str]:
    if api_settings.api_keys:
        if auth is None or (token := auth.credentials) not in api_settings.api_keys:
            raise HTTPException(
                status_code=401,
                detail={
                    "error": {
                        "message": "",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_api_key",
                    }
                },
            )
        return token
    else:
        # api_keys not set; allow all requests
        return None
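# Illustrative sketch, not part of this commit: when api_keys is configured,
# clients must send the key as a standard bearer token (the key value and
# address below are assumptions):
#
#   import httpx
#
#   resp = httpx.get(
#       "http://127.0.0.1:8100/api/v1/models",
#       headers={"Authorization": "Bearer abc"},
#   )
#   print(resp.json())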

class APIServer(BaseComponent):
    name = ComponentType.MODEL_API_SERVER

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def get_worker_manager(self) -> WorkerManager:
        """Get the worker manager component instance.

        Raises:
            APIServerException: If the worker manager component instance can't be obtained.
        """
        worker_manager = self.system_app.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        if not worker_manager:
            raise APIServerException(
                ErrorCode.INTERNAL_ERROR,
                f"Could not get component {ComponentType.WORKER_MANAGER_FACTORY} from system_app",
            )
        return worker_manager

    def get_model_registry(self) -> ModelRegistry:
        """Get the model registry component instance.

        Raises:
            APIServerException: If the model registry component instance can't be obtained.
        """
        registry = self.system_app.get_component(
            ComponentType.MODEL_REGISTRY, ModelRegistry
        )
        if not registry:
            raise APIServerException(
                ErrorCode.INTERNAL_ERROR,
                f"Could not get component {ComponentType.MODEL_REGISTRY} from system_app",
            )
        return registry

    async def get_model_instances_or_raise(
        self, model_name: str
    ) -> List[ModelInstance]:
        """Get healthy model instances for the requested model name.

        Args:
            model_name (str): Model name

        Raises:
            APIServerException: If no healthy model instance matches the requested model name.
        """
        registry = self.get_model_registry()
        registry_model_name = f"{model_name}@llm"
        model_instances = await registry.get_all_instances(
            registry_model_name, healthy_only=True
        )
        if not model_instances:
            all_instances = await registry.get_all_model_instances(healthy_only=True)
            models = [
                ins.model_name.split("@llm")[0]
                for ins in all_instances
                if ins.model_name.endswith("@llm")
            ]
            if models:
                allowed = "&&".join(models)
                message = f"Only {allowed} allowed now, your model {model_name}"
            else:
                message = f"No models allowed now, your model {model_name}"
            raise APIServerException(ErrorCode.INVALID_MODEL, message)
        return model_instances

    async def get_available_models(self) -> ModelList:
        """Return the available models.

        Only LLM and embedding models are included.

        Returns:
            ModelList: The list of models.
        """
        registry = self.get_model_registry()
        model_instances = await registry.get_all_model_instances(healthy_only=True)
        model_name_set = set()
        for inst in model_instances:
            name, worker_type = WorkerType.parse_worker_key(inst.model_name)
            if worker_type == WorkerType.LLM or worker_type == WorkerType.TEXT2VEC:
                model_name_set.add(name)
        models = sorted(model_name_set)
        # TODO: return real model permission details
        model_cards = []
        for m in models:
            model_cards.append(
                ModelCard(
                    id=m, root=m, owned_by="DB-GPT", permission=[ModelPermission()]
                )
            )
        return ModelList(data=model_cards)
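    # Illustrative note, not part of this commit: registry entries are worker
    # keys of the form "<model_name>@<worker_type>", which is why the lookup in
    # get_model_instances_or_raise appends "@llm". For example (model name assumed):
    #
    #   name, worker_type = WorkerType.parse_worker_key("vicuna-13b-v1.5@llm")
    #   # name == "vicuna-13b-v1.5", worker_type == WorkerType.LLM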
    async def chat_completion_stream_generator(
        self, model_name: str, params: Dict[str, Any], n: int
    ) -> AsyncGenerator[str, None]:
        """Chat completion stream generator.

        Args:
            model_name (str): Model name
            params (Dict[str, Any]): The parameters passed to the model worker
            n (int): How many completions to generate for each prompt.
        """
        worker_manager = self.get_worker_manager()
        id = f"chatcmpl-{shortuuid.random()}"
        finish_stream_events = []
        for i in range(n):
            # First chunk with role
            choice_data = ChatCompletionResponseStreamChoice(
                index=i,
                delta=DeltaMessage(role="assistant"),
                finish_reason=None,
            )
            chunk = ChatCompletionStreamResponse(
                id=id, choices=[choice_data], model=model_name
            )
            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

            previous_text = ""
            async for model_output in worker_manager.generate_stream(params):
                model_output: ModelOutput = model_output
                if model_output.error_code != 0:
                    yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
                decoded_unicode = model_output.text.replace("\ufffd", "")
                delta_text = decoded_unicode[len(previous_text) :]
                previous_text = (
                    decoded_unicode
                    if len(decoded_unicode) > len(previous_text)
                    else previous_text
                )

                if len(delta_text) == 0:
                    delta_text = None
                choice_data = ChatCompletionResponseStreamChoice(
                    index=i,
                    delta=DeltaMessage(content=delta_text),
                    finish_reason=model_output.finish_reason,
                )
                chunk = ChatCompletionStreamResponse(
                    id=id, choices=[choice_data], model=model_name
                )
                if delta_text is None:
                    if model_output.finish_reason is not None:
                        finish_stream_events.append(chunk)
                    continue
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
        # The last delta message has no "content" field, so use exclude_none to drop it.
        for finish_chunk in finish_stream_events:
            yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"
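    # Illustrative sketch, not part of this commit: the generator above emits
    # standard server-sent events, one "data: " line per chunk, terminated by
    # "data: [DONE]". A client can consume it roughly like this (the address and
    # model name are assumptions):
    #
    #   import json
    #   import httpx
    #
    #   async with httpx.AsyncClient() as client:
    #       async with client.stream(
    #           "POST",
    #           "http://127.0.0.1:8100/api/v1/chat/completions",
    #           json={
    #               "model": "vicuna-13b-v1.5",
    #               "stream": True,
    #               "messages": [{"role": "user", "content": "Hello"}],
    #           },
    #       ) as resp:
    #           async for line in resp.aiter_lines():
    #               if line.startswith("data: ") and line != "data: [DONE]":
    #                   chunk = json.loads(line[len("data: "):])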
    async def chat_completion_generate(
        self, model_name: str, params: Dict[str, Any], n: int
    ) -> ChatCompletionResponse:
        """Generate a (non-streaming) chat completion.

        Args:
            model_name (str): Model name
            params (Dict[str, Any]): The parameters passed to the model worker
            n (int): How many completions to generate for each prompt.
        """
        worker_manager: WorkerManager = self.get_worker_manager()
        choices = []
        chat_completions = []
        for i in range(n):
            model_output = asyncio.create_task(worker_manager.generate(params))
            chat_completions.append(model_output)
        try:
            all_tasks = await asyncio.gather(*chat_completions)
        except Exception as e:
            return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
        usage = UsageInfo()
        for i, model_output in enumerate(all_tasks):
            model_output: ModelOutput = model_output
            if model_output.error_code != 0:
                return create_error_response(model_output.error_code, model_output.text)
            choices.append(
                ChatCompletionResponseChoice(
                    index=i,
                    message=ChatMessage(role="assistant", content=model_output.text),
                    finish_reason=model_output.finish_reason or "stop",
                )
            )
            if model_output.usage:
                task_usage = UsageInfo.parse_obj(model_output.usage)
                for usage_key, usage_value in task_usage.dict().items():
                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

        return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)

def get_api_server() -> APIServer:
    api_server = global_system_app.get_component(
        ComponentType.MODEL_API_SERVER, APIServer, default_component=None
    )
    if not api_server:
        global_system_app.register(APIServer)
    return global_system_app.get_component(ComponentType.MODEL_API_SERVER, APIServer)


router = APIRouter()


@router.get("/v1/models", dependencies=[Depends(check_api_key)])
async def get_available_models(api_server: APIServer = Depends(get_api_server)):
    return await api_server.get_available_models()


@router.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
    request: APIChatCompletionRequest, api_server: APIServer = Depends(get_api_server)
):
    await api_server.get_model_instances_or_raise(request.model)
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret
    params = {
        "model": request.model,
        "messages": ModelMessage.to_dict_list(
            ModelMessage.from_openai_messages(request.messages)
        ),
        "echo": False,
    }
    if request.temperature:
        params["temperature"] = request.temperature
    if request.top_p:
        params["top_p"] = request.top_p
    if request.max_tokens:
        params["max_new_tokens"] = request.max_tokens
    if request.stop:
        params["stop"] = request.stop
    if request.user:
        params["user"] = request.user

    # TODO: check token length
    if request.stream:
        generator = api_server.chat_completion_stream_generator(
            request.model, params, request.n
        )
        return StreamingResponse(generator, media_type="text/event-stream")
    return await api_server.chat_completion_generate(request.model, params, request.n)
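# Illustrative note, not part of this commit: the endpoint accepts the standard
# OpenAI chat payload and maps "max_tokens" onto the worker's "max_new_tokens"
# parameter. A minimal request body looks like this (model name assumed):
#
#   {
#       "model": "vicuna-13b-v1.5",
#       "messages": [{"role": "user", "content": "Hello"}],
#       "max_tokens": 256,
#       "stream": false
#   }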

def _initialize_all(controller_addr: str, system_app: SystemApp):
    from dbgpt.model.cluster import RemoteWorkerManager, ModelRegistryClient
    from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory

    if not system_app.get_component(
        ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
    ):
        # Register the model registry if it does not exist
        registry = ModelRegistryClient(controller_addr)
        registry.name = ComponentType.MODEL_REGISTRY.value
        system_app.register_instance(registry)

    registry = system_app.get_component(
        ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
    )
    worker_manager = RemoteWorkerManager(registry)

    # Register the worker manager component if it does not exist
    system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY,
        WorkerManagerFactory,
        or_register_component=_DefaultWorkerManagerFactory,
        worker_manager=worker_manager,
    )
    # Register the API server component if it does not exist
    system_app.get_component(
        ComponentType.MODEL_API_SERVER, APIServer, or_register_component=APIServer
    )


def initialize_apiserver(
    controller_addr: str,
    app=None,
    system_app: SystemApp = None,
    host: str = None,
    port: int = None,
    api_keys: List[str] = None,
):
    global global_system_app
    global api_settings
    embedded_mod = True
    if not app:
        embedded_mod = False
        app = FastAPI()
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
            allow_headers=["*"],
        )

    if not system_app:
        system_app = SystemApp(app)
    global_system_app = system_app

    if api_keys:
        api_settings.api_keys = api_keys

    app.include_router(router, prefix="/api", tags=["APIServer"])

    @app.exception_handler(APIServerException)
    async def validation_apiserver_exception_handler(request, exc: APIServerException):
        return create_error_response(exc.code, exc.message)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))

    _initialize_all(controller_addr, system_app)

    if not embedded_mod:
        import uvicorn

        uvicorn.run(app, host=host, port=port, log_level="info")
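# Illustrative sketch, not part of this commit: besides running standalone, the
# router can be embedded into an existing FastAPI app (the controller address
# below is an assumption):
#
#   from fastapi import FastAPI
#
#   app = FastAPI()
#   initialize_apiserver("http://127.0.0.1:8000", app=app)
#   # The OpenAI-compatible endpoints are now mounted under /api/v1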

def run_apiserver():
    parser = EnvArgumentParser()
    env_prefix = "apiserver_"
    apiserver_params: ModelAPIServerParameters = parser.parse_args_into_dataclass(
        ModelAPIServerParameters,
        env_prefixes=[env_prefix],
    )
    setup_logging(
        "dbgpt",
        logging_level=apiserver_params.log_level,
        logger_filename=apiserver_params.log_file,
    )
    api_keys = None
    if apiserver_params.api_keys:
        api_keys = apiserver_params.api_keys.strip().split(",")

    initialize_apiserver(
        apiserver_params.controller_addr,
        host=apiserver_params.host,
        port=apiserver_params.port,
        api_keys=api_keys,
    )


if __name__ == "__main__":
    run_apiserver()
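A note on configuration (editorial, not part of the diff): run_apiserver() fills ModelAPIServerParameters from the environment using the "apiserver_" prefix, so fields such as controller_addr, host, port, and api_keys can be supplied as prefixed environment variables. A minimal standalone launch, assuming the conventional upper-cased form of those variable names:

    import os

    # Assumed env-var names, derived from the "apiserver_" prefix and the
    # parameter fields referenced in run_apiserver() above.
    os.environ["APISERVER_CONTROLLER_ADDR"] = "http://127.0.0.1:8000"
    os.environ["APISERVER_API_KEYS"] = "abc,xx"

    from dbgpt.model.cluster.apiserver.api import run_apiserver

    run_apiserver()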
0    dbgpt/model/cluster/apiserver/tests/__init__.py    Normal file
248  dbgpt/model/cluster/apiserver/tests/test_api.py    Normal file
@@ -0,0 +1,248 @@
import pytest
import pytest_asyncio
from aioresponses import aioresponses
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient, HTTPError

from dbgpt.component import SystemApp
from dbgpt.util.openai_utils import chat_completion_stream, chat_completion

from dbgpt.model.cluster.apiserver.api import (
    api_settings,
    initialize_apiserver,
    ModelList,
    UsageInfo,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    DeltaMessage,
)
from dbgpt.model.cluster.tests.conftest import _new_cluster

from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
    allow_headers=["*"],
)


@pytest_asyncio.fixture
async def system_app():
    return SystemApp(app)

@pytest_asyncio.fixture
async def client(request, system_app: SystemApp):
    param = getattr(request, "param", {})
    api_keys = param.get("api_keys", [])
    client_api_key = param.get("client_api_key")
    if "num_workers" not in param:
        param["num_workers"] = 2
    if "api_keys" in param:
        del param["api_keys"]
    headers = {}
    if client_api_key:
        headers["Authorization"] = "Bearer " + client_api_key
    print(f"param: {param}")
    if api_settings:
        # Clear global api keys
        api_settings.api_keys = []
    async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
        async with _new_cluster(**param) as cluster:
            worker_manager, model_registry = cluster
            system_app.register(_DefaultWorkerManagerFactory, worker_manager)
            system_app.register_instance(model_registry)
            # print(f"Instances {model_registry.registry}")
            initialize_apiserver(None, app, system_app, api_keys=api_keys)
            yield client
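# Editorial note, not part of this commit: the dicts passed below via
# pytest.mark.parametrize(..., indirect=["client"]) arrive here as request.param;
# everything except "api_keys" and "client_api_key" is forwarded to
# _new_cluster(**param), so "stream_messags" (spelling as defined by those
# cluster test fixtures) is the canned stream output of the mock workers.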

@pytest.mark.asyncio
async def test_get_all_models(client: AsyncClient):
    res = await client.get("/api/v1/models")
    assert res.status_code == 200
    model_lists = ModelList.parse_obj(res.json())
    print(f"model list json: {res.json()}")
    assert model_lists.object == "list"
    assert len(model_lists.data) == 2

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages",
    [
        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
    ],
    indirect=["client"],
)
async def test_chat_completions(client: AsyncClient, expected_messages):
    chat_data = {
        "model": "test-model-name-0",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
    full_text = ""
    async for text in chat_completion_stream(
        "/api/v1/chat/completions", chat_data, client
    ):
        full_text += text
    assert full_text == expected_messages

    assert (
        await chat_completion("/api/v1/chat/completions", chat_data, client)
        == expected_messages
    )

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages, client_api_key",
    [
        (
            {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
            "Hello world.",
            "abc",
        ),
        ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
    ],
    indirect=["client"],
)
async def test_chat_completions_with_openai_lib_async_no_stream(
    client: AsyncClient, expected_messages: str, client_api_key: str
):
    import openai

    openai.api_key = client_api_key
    openai.api_base = "http://test/api/v1"

    model_name = "test-model-name-0"

    with aioresponses() as mocked:
        one_res = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content=expected_messages),
            finish_reason="stop",
        )
        data = ChatCompletionResponse(
            model=model_name, choices=[one_res], usage=UsageInfo()
        )
        mock_message = f"{data.json(exclude_unset=True, ensure_ascii=False)}\n\n"
        # Mock the http request
        mocked.post(
            "http://test/api/v1/chat/completions", status=200, body=mock_message
        )
        completion = await openai.ChatCompletion.acreate(
            model=model_name,
            messages=[{"role": "user", "content": "Hello! What is your name?"}],
        )
        assert completion.choices[0].message.content == expected_messages

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages, client_api_key",
    [
        (
            {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
            "Hello world.",
            "abc",
        ),
        ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
    ],
    indirect=["client"],
)
async def test_chat_completions_with_openai_lib_async_stream(
    client: AsyncClient, expected_messages: str, client_api_key: str
):
    import openai

    openai.api_key = client_api_key
    openai.api_base = "http://test/api/v1"

    model_name = "test-model-name-0"

    with aioresponses() as mocked:
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=expected_messages),
            finish_reason="stop",
        )
        chunk = ChatCompletionStreamResponse(
            id=0, choices=[choice_data], model=model_name
        )
        mock_message = f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
        mocked.post(
            "http://test/api/v1/chat/completions",
            status=200,
            body=mock_message,
            content_type="text/event-stream",
        )

        stream_stream_resp = ""
        async for stream_resp in await openai.ChatCompletion.acreate(
            model=model_name,
            messages=[{"role": "user", "content": "Hello! What is your name?"}],
            stream=True,
        ):
            stream_stream_resp = stream_resp.choices[0]["delta"].get("content", "")
        assert stream_stream_resp == expected_messages

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages, api_key_is_error",
    [
        (
            {
                "stream_messags": ["Hello", " world."],
                "api_keys": ["abc", "xx"],
                "client_api_key": "abc",
            },
            "Hello world.",
            False,
        ),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。", False),
        (
            {"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc", "xx"]},
            "你好,我是张三。",
            True,
        ),
        (
            {
                "stream_messags": ["你好,我是", "张三。"],
                "api_keys": ["abc", "xx"],
                "client_api_key": "error_api_key",
            },
            "你好,我是张三。",
            True,
        ),
    ],
    indirect=["client"],
)
async def test_chat_completions_with_api_keys(
    client: AsyncClient, expected_messages: str, api_key_is_error: bool
):
    chat_data = {
        "model": "test-model-name-0",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
    if api_key_is_error:
        with pytest.raises(HTTPError):
            await chat_completion("/api/v1/chat/completions", chat_data, client)
    else:
        assert (
            await chat_completion("/api/v1/chat/completions", chat_data, client)
            == expected_messages
        )
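The suite above relies on pytest-asyncio fixtures and the cluster test helpers. As a quick sketch (equivalent to invoking pytest on the test file path from the listing above), it can also be run programmatically:

    import pytest

    pytest.main(["dbgpt/model/cluster/apiserver/tests/test_api.py"])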