feat(model): Support database model registry (#1656)

This commit is contained in:
Fangyin Cheng
2024-06-24 19:07:10 +08:00
committed by GitHub
parent c57ee0289b
commit 47d205f676
35 changed files with 2014 additions and 792 deletions

View File

@@ -1,22 +1,11 @@
import importlib.metadata as metadata
import pytest
import pytest_asyncio
from aioresponses import aioresponses
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient, HTTPError
from httpx import ASGITransport, AsyncClient, HTTPError
from dbgpt.component import SystemApp
from dbgpt.model.cluster.apiserver.api import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
ModelList,
UsageInfo,
api_settings,
initialize_apiserver,
)
@@ -56,12 +45,13 @@ async def client(request, system_app: SystemApp):
if api_settings:
# Clear global api keys
api_settings.api_keys = []
async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
async with AsyncClient(
transport=ASGITransport(app), base_url="http://test", headers=headers
) as client:
async with _new_cluster(**param) as cluster:
worker_manager, model_registry = cluster
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
system_app.register_instance(model_registry)
# print(f"Instances {model_registry.registry}")
initialize_apiserver(None, app, system_app, api_keys=api_keys)
yield client
@@ -113,7 +103,11 @@ async def test_chat_completions(client: AsyncClient, expected_messages):
"Hello world.",
"abc",
),
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
(
{"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
"你好,我是张三。",
"abc",
),
],
indirect=["client"],
)
@@ -160,7 +154,11 @@ async def test_chat_completions_with_openai_lib_async_no_stream(
"Hello world.",
"abc",
),
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
(
{"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
"你好,我是张三。",
"abc",
),
],
indirect=["client"],
)