mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-08 04:23:35 +00:00
feat(model): Support database model registry (#1656)
This commit is contained in:
@@ -1,22 +1,11 @@
|
||||
import importlib.metadata as metadata
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from aioresponses import aioresponses
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from httpx import AsyncClient, HTTPError
|
||||
from httpx import ASGITransport, AsyncClient, HTTPError
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.model.cluster.apiserver.api import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
DeltaMessage,
|
||||
ModelList,
|
||||
UsageInfo,
|
||||
api_settings,
|
||||
initialize_apiserver,
|
||||
)
|
||||
@@ -56,12 +45,13 @@ async def client(request, system_app: SystemApp):
|
||||
if api_settings:
|
||||
# Clear global api keys
|
||||
api_settings.api_keys = []
|
||||
async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app), base_url="http://test", headers=headers
|
||||
) as client:
|
||||
async with _new_cluster(**param) as cluster:
|
||||
worker_manager, model_registry = cluster
|
||||
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
|
||||
system_app.register_instance(model_registry)
|
||||
# print(f"Instances {model_registry.registry}")
|
||||
initialize_apiserver(None, app, system_app, api_keys=api_keys)
|
||||
yield client
|
||||
|
||||
@@ -113,7 +103,11 @@ async def test_chat_completions(client: AsyncClient, expected_messages):
|
||||
"Hello world.",
|
||||
"abc",
|
||||
),
|
||||
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
|
||||
(
|
||||
{"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
|
||||
"你好,我是张三。",
|
||||
"abc",
|
||||
),
|
||||
],
|
||||
indirect=["client"],
|
||||
)
|
||||
@@ -160,7 +154,11 @@ async def test_chat_completions_with_openai_lib_async_no_stream(
|
||||
"Hello world.",
|
||||
"abc",
|
||||
),
|
||||
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
|
||||
(
|
||||
{"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]},
|
||||
"你好,我是张三。",
|
||||
"abc",
|
||||
),
|
||||
],
|
||||
indirect=["client"],
|
||||
)
|
||||
|
Reference in New Issue
Block a user