(feat): dbgpts module add

This commit is contained in:
途杨
2024-08-28 15:41:45 +08:00
parent 1cb7e35295
commit 5f1b142f4e
437 changed files with 2636 additions and 82673 deletions

View File

@@ -264,6 +264,7 @@ class MultiAgents(BaseComponent, ABC):
# init agent memory
agent_memory = self.get_or_build_agent_memory(conv_id, gpts_name)
task = None
try:
task = asyncio.create_task(
multi_agents.agent_team_chat_new(
@@ -364,6 +365,7 @@ class MultiAgents(BaseComponent, ABC):
current_message.add_user_message(user_query)
agent_conv_id = None
agent_task = None
default_final_message = None
try:
async for task, chunk, agent_conv_id in multi_agents.agent_chat_v2(
conv_uid,
@@ -377,6 +379,7 @@ class MultiAgents(BaseComponent, ABC):
**ext_info,
):
agent_task = task
default_final_message = chunk
yield chunk
except asyncio.CancelledError:
@@ -390,10 +393,13 @@ class MultiAgents(BaseComponent, ABC):
raise
finally:
logger.info(f"save agent chat info{conv_uid}")
if agent_conv_id:
if agent_task:
final_message = await self.stable_message(agent_conv_id)
if final_message:
current_message.add_view_message(final_message)
else:
current_message.add_view_message(default_final_message)
current_message.end_current_round()
current_message.save_to_storage()

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve dbgpts_hub`

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve dbgpts_hub`

View File

@@ -0,0 +1,232 @@
import logging
from functools import cache
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result, blocking_func_to_async
from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import ServeRequest, ServerResponse
router = APIRouter()
# Add your API endpoints here
global_system_app: Optional[SystemApp] = None
logger = logging.getLogger(__name__)
def get_service() -> Service:
"""Get the service instance"""
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
Your can pass the token in you request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key}
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
@router.post(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def create(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Create a new DbgptsHub entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.create(request))
@router.put(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def update(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Update a DbgptsHub entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.update(request))
@router.post(
"/query",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Query DbgptsHub entities
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.get(request))
@router.post(
"/query_page",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),
page_size: Optional[int] = Query(default=20, description="page size"),
service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
"""Query DbgptsHub entities
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
service (Service): The service
Returns:
ServerResponse: The response
"""
try:
return Result.succ(service.get_list_by_page(request, page, page_size))
except Exception as e:
logger.exception("query_page exception!")
return Result.failed(msg=str(e))
@router.post("/source/refresh", response_model=Result[str])
async def source_refresh(
service: Service = Depends(get_service),
):
logger.info(f"source_refresh")
try:
await blocking_func_to_async(
global_system_app,
service.refresh_hub_from_git,
)
return Result.succ(None)
except Exception as e:
logger.error("Dbgpts hub source refresh Error!", e)
return Result.failed(err_code="E0020", msg=f"Dbgpts Hub refresh Error! {e}")
@router.post("/install", response_model=Result[str])
async def install(request: ServeRequest):
logger.info(f"dbgpts install:{request.name},{request.type}")
try:
from dbgpt.serve.dbgpts.my.config import (
SERVE_SERVICE_COMPONENT_NAME as MY_GPTS_SERVICE_COMPONENT,
)
from dbgpt.serve.dbgpts.my.service.service import Service as MyGptsService
mygpts_service: MyGptsService = global_system_app.get_component(
MY_GPTS_SERVICE_COMPONENT, MyGptsService
)
await blocking_func_to_async(
global_system_app,
mygpts_service.install_gpts,
name=request.name,
type=request.type,
repo=request.storage_channel,
dbgpt_path=request.storage_url,
user_name=None,
sys_code=None,
)
return Result.succ(None)
except Exception as e:
logger.error("Plugin Install Error!", e)
return Result.failed(err_code="E0021", msg=f"Plugin Install Error {e}")
def init_endpoints(system_app: SystemApp) -> None:
"""Initialize the endpoints"""
global global_system_app
system_app.register(Service)
global_system_app = system_app

View File

@@ -0,0 +1,33 @@
# Define your Pydantic schemas here
from typing import Any, Dict, Optional
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
"""DbgptsHub request model"""
id: Optional[int] = Field(None, description="id")
name: Optional[str] = Field(None, description="Dbgpts name")
type: Optional[str] = Field(None, description="Dbgpts type")
version: Optional[str] = Field(None, description="Dbgpts version")
description: Optional[str] = Field(None, description="Dbgpts description")
author: Optional[str] = Field(None, description="Dbgpts author")
email: Optional[str] = Field(None, description="Dbgpts email")
storage_channel: Optional[str] = Field(None, description="Dbgpts storage channel")
storage_url: Optional[str] = Field(None, description="Dbgpts storage url")
download_param: Optional[str] = Field(None, description="Dbgpts download param")
installed: Optional[int] = Field(None, description="Dbgpts installed")
model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}")
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert the model to a dictionary"""
return model_to_dict(self, **kwargs)
class ServerResponse(ServeRequest):
gmt_created: Optional[str] = Field(None, description="Dbgpts create time")
gmt_modified: Optional[str] = Field(None, description="Dbgpts upload time")

View File

@@ -0,0 +1,22 @@
from dataclasses import dataclass, field
from typing import Optional
from dbgpt.serve.core import BaseServeConfig
APP_NAME = "dbgpts_hub"
SERVE_APP_NAME = "dbgpt_serve_dbgpts_hub"
SERVE_APP_NAME_HUMP = "dbgpt_serve_DbgptsHub"
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.dbgpts_hub."
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
# Database table name
SERVER_APP_TABLE_NAME = "dbgpts_hub"
@dataclass
class ServeConfig(BaseServeConfig):
"""Parameters for the serve command"""
# TODO: add your own parameters here
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)

View File

@@ -0,0 +1 @@
# Define your dependencies here

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve dbgpts_hub`

View File

@@ -0,0 +1,111 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
from datetime import datetime
from typing import Any, Dict, Union
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model, db
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
class ServeEntity(Model):
__tablename__ = SERVER_APP_TABLE_NAME
id = Column(Integer, primary_key=True, comment="Auto increment id")
name = Column(String(255), unique=True, nullable=False, comment="dbgpts name")
description = Column(String(255), nullable=False, comment="dbgpts description")
author = Column(String(255), nullable=True, comment="dbgpts author")
email = Column(String(255), nullable=True, comment="dbgpts author email")
type = Column(String(255), comment="dbgpts type")
version = Column(String(255), comment="dbgpts version")
storage_channel = Column(String(255), comment="dbgpts storage channel")
storage_url = Column(String(255), comment="dbgpts download url")
download_param = Column(String(255), comment="dbgpts download param")
gmt_created = Column(
DateTime,
default=datetime.utcnow,
comment="plugin upload time",
)
gmt_modified = Column(
DateTime,
default=datetime.now,
onupdate=datetime.utcnow,
comment="Record update time",
)
installed = Column(Integer, default=False, comment="plugin already installed count")
UniqueConstraint("name", "type", name="uk_dbgpts")
Index("idx_q_type", "type")
def __repr__(self):
return f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"""The DAO class for Dbgpts"""
def __init__(self, serve_config: ServeConfig):
super().__init__()
self._serve_config = serve_config
def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
"""Convert the request to an entity
Args:
request (Union[ServeRequest, Dict[str, Any]]): The request
Returns:
T: The entity
"""
request_dict = (
request.to_dict() if isinstance(request, ServeRequest) else request
)
entity = ServeEntity(**request_dict)
return entity
def to_request(self, entity: ServeEntity) -> ServeRequest:
"""Convert the entity to a request
Args:
entity (T): The entity
Returns:
REQ: The request
"""
return ServeRequest(
id=entity.id,
name=entity.name,
description=entity.description,
author=entity.author,
email=entity.email,
type=entity.type,
version=entity.version,
storage_channel=entity.storage_channel,
storage_url=entity.storage_url,
download_param=entity.download_param,
installed=entity.installed,
)
def to_response(self, entity: ServeEntity) -> ServerResponse:
"""Convert the entity to a response
Args:
entity (T): The entity
Returns:
RES: The response
"""
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
request = self.to_request(entity)
return ServerResponse(
**request.to_dict(),
gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str,
)

View File

@@ -0,0 +1,63 @@
import logging
from typing import List, Optional, Union
from sqlalchemy import URL
from dbgpt.component import SystemApp
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
logger = logging.getLogger(__name__)
class Serve(BaseServe):
"""Serve component for DB-GPT"""
name = SERVE_APP_NAME
def __init__(
self,
system_app: SystemApp,
api_prefix: Optional[str] = f"/api/v1/serve/dbgpts/hub",
api_tags: Optional[List[str]] = None,
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):
if api_tags is None:
api_tags = [SERVE_APP_NAME_HUMP]
super().__init__(
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
)
self._db_manager: Optional[DatabaseManager] = None
def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
return
self._system_app = system_app
self._system_app.app.include_router(
router, prefix=self._api_prefix, tags=self._api_tags
)
init_endpoints(self._system_app)
self._app_has_initiated = True
def on_init(self):
"""Called when init the application.
You can do some initialization here. You can't get other components here because they may be not initialized yet
"""
# import your own module here to ensure the module is loaded before the application starts
from .models.models import ServeEntity
def before_start(self):
"""Called before the start of the application."""
# TODO: Your code here
self._db_manager = self.create_or_get_db_manager()

View File

@@ -0,0 +1,214 @@
import logging
import re
from typing import List, Optional, Tuple
from dbgpt.agent import PluginStorageType
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.util.dbgpts.repo import _install_default_repos_if_no_repos, list_dbgpts
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
logger = logging.getLogger(__name__)
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""The service class for DbgptsHub"""
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = dao
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
"""Initialize the service
Args:
system_app (SystemApp): The system app
"""
super().init_app(system_app)
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
"""Returns the internal DAO."""
return self._dao
@property
def config(self) -> ServeConfig:
"""Returns the internal ServeConfig."""
return self._serve_config
def update(self, request: ServeRequest) -> ServerResponse:
"""Update a DbgptsHub entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = {"id": request.id}
return self.dao.update(query_request, update_request=request)
def get(self, request: ServeRequest) -> Optional[ServerResponse]:
"""Get a DbgptsHub entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = request
return self.dao.get_one(query_request)
def delete(self, request: ServeRequest) -> None:
"""Delete a DbgptsHub entity
Args:
request (ServeRequest): The request
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = {
# "id": request.id
}
self.dao.delete(query_request)
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
"""Get a list of DbgptsHub entities
Args:
request (ServeRequest): The request
Returns:
List[ServerResponse]: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = request
return self.dao.get_list(query_request)
def get_list_by_page(
self, request: ServeRequest, page: int, page_size: int
) -> PaginationResult[ServerResponse]:
"""Get a list of DbgptsHub entities by page
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
Returns:
List[ServerResponse]: The response
"""
query_request = ServeRequest(
name=request.name,
type=request.type,
version=request.version,
description=request.description,
author=request.author,
storage_channel=request.storage_channel,
storage_url=request.storage_url,
installed=request.installed,
)
return self.dao.get_list_page(query_request, page, page_size)
def refresh_hub_from_git(
self,
github_repo: str = None,
branch_name: str = "main",
authorization: str = None,
):
logger.info("refresh_hub_by_git start!")
_install_default_repos_if_no_repos()
data: List[Tuple[str, str, str, str]] = list_dbgpts()
from dbgpt.util.dbgpts.base import get_repo_path
from dbgpt.util.dbgpts.loader import (
BasePackage,
InstalledPackage,
parse_package_metadata,
)
try:
for repo, package, name, gpts_path in data:
try:
if not name:
logger.info(
f"dbgpts error repo:{repo}, package:{package}, name:{name}, gpts_path:{gpts_path}"
)
continue
old_hub_info = self.get(ServeRequest(name=name, type=package))
base_package: BasePackage = parse_package_metadata(
InstalledPackage(
name=name,
repo=repo,
root=str(gpts_path),
package=package,
)
)
if old_hub_info:
self.dao.update(
query_request=ServeRequest(
name=old_hub_info.name, type=old_hub_info.type
),
update_request=ServeRequest(
version=base_package.version,
description=base_package.description,
),
)
else:
request = ServeRequest()
request.type = package
request.name = name
request.storage_channel = repo
request.storage_url = str(gpts_path)
request.author = self._get_dbgpts_author(base_package.authors)
request.email = self._get_dbgpts_email(base_package.authors)
request.download_param = None
request.installed = 0
request.version = base_package.version
request.description = base_package.description
self.create(request)
except Exception as e:
logger.warning(
f"Load from git failed repo:{repo}, package:{package}, name:{name}, gpts_path:{gpts_path}",
e,
)
except Exception as e:
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
def _get_dbgpts_author(self, authors):
pattern = r"(.+?)<"
names = []
for item in authors:
names.extend(re.findall(pattern, item))
return ",".join(names)
def _get_dbgpts_email(self, authors):
pattern = r"<(.*?)>"
emails: List[str] = []
for item in authors:
emails.extend(re.findall(pattern, item))
return ",".join(emails)

View File

View File

@@ -0,0 +1,124 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import asystem_app, client
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..api.endpoints import init_endpoints, router
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
def client_init_caller(app: FastAPI, system_app: SystemApp):
app.include_router(router)
init_endpoints(system_app)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, asystem_app, has_auth",
[
(
{
"app_caller": client_init_caller,
"client_api_key": "test_token1",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
True,
),
(
{
"app_caller": client_init_caller,
"client_api_key": "error_token",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
False,
),
],
indirect=["client", "asystem_app"],
)
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
response = await client.get("/test_auth")
if has_auth:
assert response.status_code == 200
assert response.json() == {"status": "ok"}
else:
assert response.status_code == 401
assert response.json() == {
"detail": {
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
}
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_health(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
# TODO: add your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,103 @@
# import pytest
#
# from dbgpt.storage.metadata import db
#
# from ..api.schemas import ServeRequest, ServerResponse
# from ..config import ServeConfig
# from ..models.models import ServeDao, ServeEntity
#
#
# @pytest.fixture(autouse=True)
# def setup_and_teardown():
# db.init_db("sqlite:///:memory:")
# db.create_all()
#
# yield
#
#
# @pytest.fixture
# def server_config():
# # TODO : build your server config
# return ServeConfig()
#
#
# @pytest.fixture
# def dao(server_config):
# return ServeDao(server_config)
#
#
# @pytest.fixture
# def default_entity_dict():
# # TODO: build your default entity dict
# return {}
#
#
# def test_table_exist():
# assert ServeEntity.__tablename__ in db.metadata.tables
#
#
# def test_entity_create(default_entity_dict):
# with db.session() as session:
# entity = ServeEntity(**default_entity_dict)
# session.add(entity)
#
#
# def test_entity_unique_key(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_get(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_update(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_delete(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_all():
# # TODO: implement your test case
# pass
#
#
# def test_dao_create(dao, default_entity_dict):
# # TODO: implement your test case
# req = ServeRequest(**default_entity_dict)
# res: ServerResponse = dao.create(req)
# assert res is not None
#
#
# def test_dao_get_one(dao, default_entity_dict):
# # TODO: implement your test case
# req = ServeRequest(**default_entity_dict)
# res: ServerResponse = dao.create(req)
#
#
# def test_get_dao_get_list(dao):
# # TODO: implement your test case
# pass
#
#
# def test_dao_update(dao, default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_dao_delete(dao, default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_dao_get_list_page(dao):
# # TODO: implement your test case
# pass
#
#
# # Add more test cases according to your own logic

View File

@@ -0,0 +1,78 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import system_app
from dbgpt.storage.metadata import db
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity
from ..service.service import Service
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def service(system_app: SystemApp):
instance = Service(system_app)
instance.init_app(system_app)
return instance
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
@pytest.mark.parametrize(
"system_app",
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
indirect=True,
)
def test_config_exists(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
assert service.config is not None
def test_service_create(service: Service, default_entity_dict):
# TODO: implement your test case
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
# ...
pass
def test_service_update(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_delete(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get_list(service: Service):
# TODO: implement your test case
pass
def test_service_get_list_by_page(service: Service):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve dbgpts_my`

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve dbgpts_my`

View File

@@ -0,0 +1,203 @@
import logging
from functools import cache
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result, blocking_func_to_async
from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import ServeRequest, ServerResponse
router = APIRouter()
# Add your API endpoints here
global_system_app: Optional[SystemApp] = None
logger = logging.getLogger(__name__)
def get_service() -> Service:
"""Get the service instance"""
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
Your can pass the token in you request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key}
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
@router.post(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def create(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Create a new DbgptsMy entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.create(request))
@router.put(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def update(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Update a DbgptsMy entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.update(request))
@router.post(
"/query",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Query DbgptsMy entities
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.get(request))
@router.post(
"/query_page",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),
page_size: Optional[int] = Query(default=20, description="page size"),
service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
"""Query DbgptsMy entities
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.get_list_by_page(request, page, page_size))
@router.post("/uninstall", response_model=Result[str])
async def agent_uninstall(
name: str,
type=str,
user: Optional[str] = None,
service: Service = Depends(get_service),
):
logger.info(f"dbgpts uninstall:{name},{user}")
try:
await blocking_func_to_async(
global_system_app,
service.uninstall_gpts,
name=name,
type=type,
user_name=user,
)
return Result.succ(None)
except Exception as e:
logger.error("Plugin Uninstall Error!", e)
return Result.failed(err_code="E0022", msg=f"Plugin Uninstall Error {e}")
def init_endpoints(system_app: SystemApp) -> None:
"""Initialize the endpoints"""
global global_system_app
system_app.register(Service)
global_system_app = system_app

View File

@@ -0,0 +1,31 @@
# Define your Pydantic schemas here
from typing import Any, Dict, Optional
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
"""DbgptsMy request model"""
id: Optional[int] = Field(None, description="id")
user_name: Optional[str] = Field(None, description="My gpts user name")
sys_code: Optional[str] = Field(None, description="My gpts sys code")
name: Optional[str] = Field(None, description="My gpts name")
file_name: Optional[str] = Field(None, description="My gpts file name")
type: Optional[str] = Field(None, description="My gpts type")
version: Optional[str] = Field(None, description="My gpts version")
use_count: Optional[int] = Field(None, description="My gpts use count")
succ_count: Optional[int] = Field(None, description="My gpts succ count")
model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}")
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert the model to a dictionary"""
return model_to_dict(self, **kwargs)
class ServerResponse(ServeRequest):
gmt_created: Optional[str] = Field(None, description="Dbgpts create time")
gmt_modified: Optional[str] = Field(None, description="Dbgpts upload time")

View File

@@ -0,0 +1,22 @@
from dataclasses import dataclass, field
from typing import Optional
from dbgpt.serve.core import BaseServeConfig
APP_NAME = "dbgpts_my"
SERVE_APP_NAME = "dbgpt_serve_dbgpts_my"
SERVE_APP_NAME_HUMP = "dbgpt_serve_DbgptsMy"
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.dbgpts_my."
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
# Database table name
SERVER_APP_TABLE_NAME = "dbgpts_my"
@dataclass
class ServeConfig(BaseServeConfig):
"""Parameters for the serve command"""
# TODO: add your own parameters here
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)

View File

@@ -0,0 +1 @@
# Define your dependencies here

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve dbgpts_my`

View File

@@ -0,0 +1,109 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
from datetime import datetime
from typing import Any, Dict, Union
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model, db
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
class ServeEntity(Model):
__tablename__ = SERVER_APP_TABLE_NAME
id = Column(Integer, primary_key=True, comment="autoincrement id")
name = Column(String(255), unique=True, nullable=False, comment="gpts name")
type = Column(String(255), nullable=False, comment="gpts type")
version = Column(String(255), nullable=False, comment="gpts version")
user_name = Column(String(255), nullable=True, comment="user name")
file_name = Column(String(255), nullable=True, comment="gpts package file name")
use_count = Column(
Integer, nullable=True, default=0, comment="gpts total use count"
)
succ_count = Column(
Integer, nullable=True, default=0, comment="gpts total success count"
)
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.utcnow, comment="gpts install time")
gmt_modified = Column(
DateTime,
default=datetime.now,
onupdate=datetime.utcnow,
comment="Record update time",
)
UniqueConstraint("user_code", "name", name="uk_name")
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"""The DAO class for MyDbgpts"""
def __init__(self, serve_config: ServeConfig):
super().__init__()
self._serve_config = serve_config
def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
"""Convert the request to an entity
Args:
request (Union[MyGptsServeRequest, Dict[str, Any]]): The request
Returns:
T: The entity
"""
request_dict = (
request.to_dict() if isinstance(request, ServeRequest) else request
)
entity = ServeEntity(**request_dict)
return entity
def to_request(self, entity: ServeEntity) -> ServeRequest:
"""Convert the entity to a request
Args:
entity (T): The entity
Returns:
REQ: The request
"""
return ServeRequest(
id=entity.id,
user_name=entity.user_name,
sys_code=entity.sys_code,
name=entity.name,
file_name=entity.file_name,
type=entity.type,
version=entity.version,
use_count=entity.use_count,
succ_count=entity.succ_count,
)
def to_response(self, entity: ServeEntity) -> ServerResponse:
"""Convert the entity to a response
Args:
entity (T): The entity
Returns:
RES: The response
"""
gmt_created_str = (
entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
if entity.gmt_created
else ""
)
gmt_modified_str = (
entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
if entity.gmt_modified
else ""
)
request = self.to_request(entity)
return ServerResponse(
**request.to_dict(),
gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str,
)

View File

@@ -0,0 +1,63 @@
import logging
from typing import List, Optional, Union
from sqlalchemy import URL
from dbgpt.component import SystemApp
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
logger = logging.getLogger(__name__)
class Serve(BaseServe):
"""Serve component for DB-GPT"""
name = SERVE_APP_NAME
def __init__(
self,
system_app: SystemApp,
api_prefix: Optional[str] = f"/api/v1/serve/dbgpts/my",
api_tags: Optional[List[str]] = None,
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):
if api_tags is None:
api_tags = [SERVE_APP_NAME_HUMP]
super().__init__(
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
)
self._db_manager: Optional[DatabaseManager] = None
def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
return
self._system_app = system_app
self._system_app.app.include_router(
router, prefix=self._api_prefix, tags=self._api_tags
)
init_endpoints(self._system_app)
self._app_has_initiated = True
def on_init(self):
"""Called when init the application.
You can do some initialization here. You can't get other components here because they may be not initialized yet
"""
# import your own module here to ensure the module is loaded before the application starts
from .models.models import ServeEntity
def before_start(self):
"""Called before the start of the application."""
# TODO: Your code here
self._db_manager = self.create_or_get_db_manager()

View File

@@ -0,0 +1,191 @@
import logging
from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.util.dbgpts.base import INSTALL_DIR
from dbgpt.util.dbgpts.repo import (
copy_and_install,
inner_copy_and_install,
inner_uninstall,
install,
uninstall,
)
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
logger = logging.getLogger(__name__)
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""The service class for DbgptsMy"""
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = dao
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
"""Initialize the service
Args:
system_app (SystemApp): The system app
"""
super().init_app(system_app)
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
"""Returns the internal DAO."""
return self._dao
@property
def config(self) -> ServeConfig:
"""Returns the internal ServeConfig."""
return self._serve_config
def update(self, request: ServeRequest) -> ServerResponse:
"""Update a DbgptsMy entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
# Build the query request from the request
query_request = {"id": request.id}
return self.dao.update(query_request, update_request=request)
def get(self, request: ServeRequest) -> Optional[ServerResponse]:
"""Get a DbgptsMy entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
# Build the query request from the request
query_request = request
return self.dao.get_one(query_request)
def delete(self, request: ServeRequest) -> None:
"""Delete a DbgptsMy entity
Args:
request (ServeRequest): The request
"""
# TODO: implement your own logic here
# Build the query request from the request
self.dao.delete(request)
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
"""Get a list of DbgptsMy entities
Args:
request (ServeRequest): The request
Returns:
List[ServerResponse]: The response
"""
# Build the query request from the request
query_request = request
return self.dao.get_list(query_request)
def get_list_by_page(
self, request: ServeRequest, page: int, page_size: int
) -> PaginationResult[ServerResponse]:
"""Get a list of DbgptsMy entities by page
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
Returns:
List[ServerResponse]: The response
"""
query_request = request
return self.dao.get_list_page(query_request, page, page_size)
def install_gpts(
self,
name: str,
type: str,
repo: str,
dbgpt_path: str,
user_name: Optional[str] = None,
sys_code: Optional[str] = None,
):
logger.info(f"install_gpts {name}")
# install(name, repo)
try:
from pathlib import Path
inner_copy_and_install(repo, name, Path(dbgpt_path))
except Exception as e:
logger.exception(f"install_gpts failed!{str(e)}")
raise ValueError(f"Install dbgpts [{type}:{name}] Failed! {str(e)}", e)
from dbgpt.util.dbgpts.base import get_repo_path
from dbgpt.util.dbgpts.loader import (
BasePackage,
InstalledPackage,
parse_package_metadata,
)
base_package: BasePackage = parse_package_metadata(
InstalledPackage(
name=name,
repo=repo,
root=dbgpt_path,
package=type,
)
)
dbgpts_entity = self.get(ServeRequest(name=name, type=type))
if not dbgpts_entity:
request = ServeRequest()
request.name = name
request.user_name = user_name
request.sys_code = sys_code
request.type = type
request.file_name = str(INSTALL_DIR / name)
request.version = base_package.version
return self.create(request)
else:
dbgpts_entity.version = base_package.version
return self.update(ServeRequest(**dbgpts_entity.to_dict()))
def uninstall_gpts(
self,
name: str,
type: str,
user_name: Optional[str] = None,
sys_code: Optional[str] = None,
):
logger.info(f"install_gpts {name}")
try:
inner_uninstall(name)
except Exception as e:
logger.warning(f"Uninstall dbgpts [{type}:{name}] Failed! {str(e)}", e)
raise ValueError(f"Uninstall dbgpts [{type}:{name}] Failed! {str(e)}", e)
self.delete(ServeRequest(name=name, type=type))

View File

View File

@@ -0,0 +1,124 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import asystem_app, client
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..api.endpoints import init_endpoints, router
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
def client_init_caller(app: FastAPI, system_app: SystemApp):
app.include_router(router)
init_endpoints(system_app)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, asystem_app, has_auth",
[
(
{
"app_caller": client_init_caller,
"client_api_key": "test_token1",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
True,
),
(
{
"app_caller": client_init_caller,
"client_api_key": "error_token",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
False,
),
],
indirect=["client", "asystem_app"],
)
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
response = await client.get("/test_auth")
if has_auth:
assert response.status_code == 200
assert response.json() == {"status": "ok"}
else:
assert response.status_code == 401
assert response.json() == {
"detail": {
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
}
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_health(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
# TODO: add your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,103 @@
# import pytest
#
# from dbgpt.storage.metadata import db
#
# from ..api.schemas import ServeRequest, ServerResponse
# from ..config import ServeConfig
# from ..models.models import ServeDao, ServeEntity
#
#
# @pytest.fixture(autouse=True)
# def setup_and_teardown():
# db.init_db("sqlite:///:memory:")
# db.create_all()
#
# yield
#
#
# @pytest.fixture
# def server_config():
# # TODO : build your server config
# return ServeConfig()
#
#
# @pytest.fixture
# def dao(server_config):
# return ServeDao(server_config)
#
#
# @pytest.fixture
# def default_entity_dict():
# # TODO: build your default entity dict
# return {}
#
#
# def test_table_exist():
# assert ServeEntity.__tablename__ in db.metadata.tables
#
#
# def test_entity_create(default_entity_dict):
# with db.session() as session:
# entity = ServeEntity(**default_entity_dict)
# session.add(entity)
#
#
# def test_entity_unique_key(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_get(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_update(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_delete(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_all():
# # TODO: implement your test case
# pass
#
#
# def test_dao_create(dao, default_entity_dict):
# # TODO: implement your test case
# req = ServeRequest(**default_entity_dict)
# res: ServerResponse = dao.create(req)
# assert res is not None
#
#
# def test_dao_get_one(dao, default_entity_dict):
# # TODO: implement your test case
# req = ServeRequest(**default_entity_dict)
# res: ServerResponse = dao.create(req)
#
#
# def test_get_dao_get_list(dao):
# # TODO: implement your test case
# pass
#
#
# def test_dao_update(dao, default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_dao_delete(dao, default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_dao_get_list_page(dao):
# # TODO: implement your test case
# pass
#
#
# # Add more test cases according to your own logic

View File

@@ -0,0 +1,78 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import system_app
from dbgpt.storage.metadata import db
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity
from ..service.service import Service
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def service(system_app: SystemApp):
instance = Service(system_app)
instance.init_app(system_app)
return instance
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
@pytest.mark.parametrize(
"system_app",
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
indirect=True,
)
def test_config_exists(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
assert service.config is not None
def test_service_create(service: Service, default_entity_dict):
# TODO: implement your test case
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
# ...
pass
def test_service_update(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_delete(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get_list(service: Service):
# TODO: implement your test case
pass
def test_service_get_list_by_page(service: Service):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -623,7 +623,6 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
async def _wrapper_chat_stream_flow_str(
self, stream_iter: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
async for output in stream_iter:
text = output.text
if text:

View File

@@ -15,7 +15,6 @@ from ..api.schemas import ServeRequest
def _generate_dbgpts_zip(package_name: str, flow: ServeRequest) -> io.BytesIO:
zip_buffer = io.BytesIO()
flow_name = flow.name
flow_label = flow.label