mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 05:01:25 +00:00
(feat): dbgpts module add
This commit is contained in:
@@ -264,6 +264,7 @@ class MultiAgents(BaseComponent, ABC):
|
||||
# init agent memory
|
||||
agent_memory = self.get_or_build_agent_memory(conv_id, gpts_name)
|
||||
|
||||
task = None
|
||||
try:
|
||||
task = asyncio.create_task(
|
||||
multi_agents.agent_team_chat_new(
|
||||
@@ -364,6 +365,7 @@ class MultiAgents(BaseComponent, ABC):
|
||||
current_message.add_user_message(user_query)
|
||||
agent_conv_id = None
|
||||
agent_task = None
|
||||
default_final_message = None
|
||||
try:
|
||||
async for task, chunk, agent_conv_id in multi_agents.agent_chat_v2(
|
||||
conv_uid,
|
||||
@@ -377,6 +379,7 @@ class MultiAgents(BaseComponent, ABC):
|
||||
**ext_info,
|
||||
):
|
||||
agent_task = task
|
||||
default_final_message = chunk
|
||||
yield chunk
|
||||
|
||||
except asyncio.CancelledError:
|
||||
@@ -390,10 +393,13 @@ class MultiAgents(BaseComponent, ABC):
|
||||
raise
|
||||
finally:
|
||||
logger.info(f"save agent chat info!{conv_uid}")
|
||||
if agent_conv_id:
|
||||
if agent_task:
|
||||
final_message = await self.stable_message(agent_conv_id)
|
||||
if final_message:
|
||||
current_message.add_view_message(final_message)
|
||||
else:
|
||||
current_message.add_view_message(default_final_message)
|
||||
|
||||
current_message.end_current_round()
|
||||
current_message.save_to_storage()
|
||||
|
||||
|
2
dbgpt/serve/dbgpts/hub/__init__.py
Normal file
2
dbgpt/serve/dbgpts/hub/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# This is an auto-generated __init__.py file
|
||||
# generated by `dbgpt new serve dbgpts_hub`
|
2
dbgpt/serve/dbgpts/hub/api/__init__.py
Normal file
2
dbgpt/serve/dbgpts/hub/api/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# This is an auto-generated __init__.py file
|
||||
# generated by `dbgpt new serve dbgpts_hub`
|
232
dbgpt/serve/dbgpts/hub/api/endpoints.py
Normal file
232
dbgpt/serve/dbgpts/hub/api/endpoints.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import logging
|
||||
from functools import cache
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import Result, blocking_func_to_async
|
||||
from dbgpt.util import PaginationResult
|
||||
|
||||
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..service.service import Service
|
||||
from .schemas import ServeRequest, ServerResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Add your API endpoints here
|
||||
|
||||
global_system_app: Optional[SystemApp] = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_service() -> Service:
|
||||
"""Get the service instance"""
|
||||
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
|
||||
|
||||
|
||||
get_bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
@cache
|
||||
def _parse_api_keys(api_keys: str) -> List[str]:
|
||||
"""Parse the string api keys to a list
|
||||
|
||||
Args:
|
||||
api_keys (str): The string api keys
|
||||
|
||||
Returns:
|
||||
List[str]: The list of api keys
|
||||
"""
|
||||
if not api_keys:
|
||||
return []
|
||||
return [key.strip() for key in api_keys.split(",")]
|
||||
|
||||
|
||||
async def check_api_key(
|
||||
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Optional[str]:
|
||||
"""Check the api key
|
||||
|
||||
If the api key is not set, allow all.
|
||||
|
||||
Your can pass the token in you request header like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import requests
|
||||
|
||||
client_api_key = "your_api_key"
|
||||
headers = {"Authorization": "Bearer " + client_api_key}
|
||||
res = requests.get("http://test/hello", headers=headers)
|
||||
assert res.status_code == 200
|
||||
|
||||
"""
|
||||
if service.config.api_keys:
|
||||
api_keys = _parse_api_keys(service.config.api_keys)
|
||||
if auth is None or (token := auth.credentials) not in api_keys:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
},
|
||||
)
|
||||
return token
|
||||
else:
|
||||
# api_keys not set; allow all
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
|
||||
async def test_auth():
|
||||
"""Test auth endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def create(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
"""Create a new DbgptsHub entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.create(request))
|
||||
|
||||
|
||||
@router.put(
|
||||
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def update(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
"""Update a DbgptsHub entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.update(request))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/query",
|
||||
response_model=Result[ServerResponse],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
"""Query DbgptsHub entities
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.get(request))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/query_page",
|
||||
response_model=Result[PaginationResult[ServerResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query_page(
|
||||
request: ServeRequest,
|
||||
page: Optional[int] = Query(default=1, description="current page"),
|
||||
page_size: Optional[int] = Query(default=20, description="page size"),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Result[PaginationResult[ServerResponse]]:
|
||||
"""Query DbgptsHub entities
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
page (int): The page number
|
||||
page_size (int): The page size
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
try:
|
||||
return Result.succ(service.get_list_by_page(request, page, page_size))
|
||||
except Exception as e:
|
||||
logger.exception("query_page exception!")
|
||||
return Result.failed(msg=str(e))
|
||||
|
||||
|
||||
@router.post("/source/refresh", response_model=Result[str])
|
||||
async def source_refresh(
|
||||
service: Service = Depends(get_service),
|
||||
):
|
||||
logger.info(f"source_refresh")
|
||||
try:
|
||||
await blocking_func_to_async(
|
||||
global_system_app,
|
||||
service.refresh_hub_from_git,
|
||||
)
|
||||
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
logger.error("Dbgpts hub source refresh Error!", e)
|
||||
return Result.failed(err_code="E0020", msg=f"Dbgpts Hub refresh Error! {e}")
|
||||
|
||||
|
||||
@router.post("/install", response_model=Result[str])
|
||||
async def install(request: ServeRequest):
|
||||
logger.info(f"dbgpts install:{request.name},{request.type}")
|
||||
|
||||
try:
|
||||
from dbgpt.serve.dbgpts.my.config import (
|
||||
SERVE_SERVICE_COMPONENT_NAME as MY_GPTS_SERVICE_COMPONENT,
|
||||
)
|
||||
from dbgpt.serve.dbgpts.my.service.service import Service as MyGptsService
|
||||
|
||||
mygpts_service: MyGptsService = global_system_app.get_component(
|
||||
MY_GPTS_SERVICE_COMPONENT, MyGptsService
|
||||
)
|
||||
|
||||
await blocking_func_to_async(
|
||||
global_system_app,
|
||||
mygpts_service.install_gpts,
|
||||
name=request.name,
|
||||
type=request.type,
|
||||
repo=request.storage_channel,
|
||||
dbgpt_path=request.storage_url,
|
||||
user_name=None,
|
||||
sys_code=None,
|
||||
)
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
logger.error("Plugin Install Error!", e)
|
||||
return Result.failed(err_code="E0021", msg=f"Plugin Install Error {e}")
|
||||
|
||||
|
||||
def init_endpoints(system_app: SystemApp) -> None:
|
||||
"""Initialize the endpoints"""
|
||||
global global_system_app
|
||||
system_app.register(Service)
|
||||
global_system_app = system_app
|
33
dbgpt/serve/dbgpts/hub/api/schemas.py
Normal file
33
dbgpt/serve/dbgpts/hub/api/schemas.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Define your Pydantic schemas here
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
|
||||
|
||||
from ..config import SERVE_APP_NAME_HUMP
|
||||
|
||||
|
||||
class ServeRequest(BaseModel):
|
||||
"""DbgptsHub request model"""
|
||||
|
||||
id: Optional[int] = Field(None, description="id")
|
||||
name: Optional[str] = Field(None, description="Dbgpts name")
|
||||
type: Optional[str] = Field(None, description="Dbgpts type")
|
||||
version: Optional[str] = Field(None, description="Dbgpts version")
|
||||
description: Optional[str] = Field(None, description="Dbgpts description")
|
||||
author: Optional[str] = Field(None, description="Dbgpts author")
|
||||
email: Optional[str] = Field(None, description="Dbgpts email")
|
||||
storage_channel: Optional[str] = Field(None, description="Dbgpts storage channel")
|
||||
storage_url: Optional[str] = Field(None, description="Dbgpts storage url")
|
||||
download_param: Optional[str] = Field(None, description="Dbgpts download param")
|
||||
installed: Optional[int] = Field(None, description="Dbgpts installed")
|
||||
|
||||
model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}")
|
||||
|
||||
def to_dict(self, **kwargs) -> Dict[str, Any]:
|
||||
"""Convert the model to a dictionary"""
|
||||
return model_to_dict(self, **kwargs)
|
||||
|
||||
|
||||
class ServerResponse(ServeRequest):
|
||||
gmt_created: Optional[str] = Field(None, description="Dbgpts create time")
|
||||
gmt_modified: Optional[str] = Field(None, description="Dbgpts upload time")
|
22
dbgpt/serve/dbgpts/hub/config.py
Normal file
22
dbgpt/serve/dbgpts/hub/config.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.serve.core import BaseServeConfig
|
||||
|
||||
APP_NAME = "dbgpts_hub"
|
||||
SERVE_APP_NAME = "dbgpt_serve_dbgpts_hub"
|
||||
SERVE_APP_NAME_HUMP = "dbgpt_serve_DbgptsHub"
|
||||
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.dbgpts_hub."
|
||||
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
|
||||
# Database table name
|
||||
SERVER_APP_TABLE_NAME = "dbgpts_hub"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServeConfig(BaseServeConfig):
|
||||
"""Parameters for the serve command"""
|
||||
|
||||
# TODO: add your own parameters here
|
||||
api_keys: Optional[str] = field(
|
||||
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
|
||||
)
|
1
dbgpt/serve/dbgpts/hub/dependencies.py
Normal file
1
dbgpt/serve/dbgpts/hub/dependencies.py
Normal file
@@ -0,0 +1 @@
|
||||
# Define your dependencies here
|
2
dbgpt/serve/dbgpts/hub/models/__init__.py
Normal file
2
dbgpt/serve/dbgpts/hub/models/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# This is an auto-generated __init__.py file
|
||||
# generated by `dbgpt new serve dbgpts_hub`
|
111
dbgpt/serve/dbgpts/hub/models/models.py
Normal file
111
dbgpt/serve/dbgpts/hub/models/models.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""This is an auto-generated model file
|
||||
You can define your own models and DAOs here
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model, db
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
|
||||
|
||||
|
||||
class ServeEntity(Model):
|
||||
__tablename__ = SERVER_APP_TABLE_NAME
|
||||
id = Column(Integer, primary_key=True, comment="Auto increment id")
|
||||
|
||||
name = Column(String(255), unique=True, nullable=False, comment="dbgpts name")
|
||||
description = Column(String(255), nullable=False, comment="dbgpts description")
|
||||
author = Column(String(255), nullable=True, comment="dbgpts author")
|
||||
email = Column(String(255), nullable=True, comment="dbgpts author email")
|
||||
type = Column(String(255), comment="dbgpts type")
|
||||
version = Column(String(255), comment="dbgpts version")
|
||||
storage_channel = Column(String(255), comment="dbgpts storage channel")
|
||||
storage_url = Column(String(255), comment="dbgpts download url")
|
||||
download_param = Column(String(255), comment="dbgpts download param")
|
||||
gmt_created = Column(
|
||||
DateTime,
|
||||
default=datetime.utcnow,
|
||||
comment="plugin upload time",
|
||||
)
|
||||
gmt_modified = Column(
|
||||
DateTime,
|
||||
default=datetime.now,
|
||||
onupdate=datetime.utcnow,
|
||||
comment="Record update time",
|
||||
)
|
||||
installed = Column(Integer, default=False, comment="plugin already installed count")
|
||||
|
||||
UniqueConstraint("name", "type", name="uk_dbgpts")
|
||||
Index("idx_q_type", "type")
|
||||
|
||||
def __repr__(self):
|
||||
return f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
|
||||
|
||||
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""The DAO class for Dbgpts"""
|
||||
|
||||
def __init__(self, serve_config: ServeConfig):
|
||||
super().__init__()
|
||||
self._serve_config = serve_config
|
||||
|
||||
def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
|
||||
"""Convert the request to an entity
|
||||
|
||||
Args:
|
||||
request (Union[ServeRequest, Dict[str, Any]]): The request
|
||||
|
||||
Returns:
|
||||
T: The entity
|
||||
"""
|
||||
request_dict = (
|
||||
request.to_dict() if isinstance(request, ServeRequest) else request
|
||||
)
|
||||
entity = ServeEntity(**request_dict)
|
||||
return entity
|
||||
|
||||
def to_request(self, entity: ServeEntity) -> ServeRequest:
|
||||
"""Convert the entity to a request
|
||||
|
||||
Args:
|
||||
entity (T): The entity
|
||||
|
||||
Returns:
|
||||
REQ: The request
|
||||
"""
|
||||
|
||||
return ServeRequest(
|
||||
id=entity.id,
|
||||
name=entity.name,
|
||||
description=entity.description,
|
||||
author=entity.author,
|
||||
email=entity.email,
|
||||
type=entity.type,
|
||||
version=entity.version,
|
||||
storage_channel=entity.storage_channel,
|
||||
storage_url=entity.storage_url,
|
||||
download_param=entity.download_param,
|
||||
installed=entity.installed,
|
||||
)
|
||||
|
||||
def to_response(self, entity: ServeEntity) -> ServerResponse:
|
||||
"""Convert the entity to a response
|
||||
|
||||
Args:
|
||||
entity (T): The entity
|
||||
|
||||
Returns:
|
||||
RES: The response
|
||||
"""
|
||||
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
|
||||
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
|
||||
request = self.to_request(entity)
|
||||
|
||||
return ServerResponse(
|
||||
**request.to_dict(),
|
||||
gmt_created=gmt_created_str,
|
||||
gmt_modified=gmt_modified_str,
|
||||
)
|
63
dbgpt/serve/dbgpts/hub/serve.py
Normal file
63
dbgpt/serve/dbgpts/hub/serve.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sqlalchemy import URL
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import BaseServe
|
||||
from dbgpt.storage.metadata import DatabaseManager
|
||||
|
||||
from .api.endpoints import init_endpoints, router
|
||||
from .config import (
|
||||
APP_NAME,
|
||||
SERVE_APP_NAME,
|
||||
SERVE_APP_NAME_HUMP,
|
||||
SERVE_CONFIG_KEY_PREFIX,
|
||||
ServeConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Serve(BaseServe):
|
||||
"""Serve component for DB-GPT"""
|
||||
|
||||
name = SERVE_APP_NAME
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_app: SystemApp,
|
||||
api_prefix: Optional[str] = f"/api/v1/serve/dbgpts/hub",
|
||||
api_tags: Optional[List[str]] = None,
|
||||
db_url_or_db: Union[str, URL, DatabaseManager] = None,
|
||||
try_create_tables: Optional[bool] = False,
|
||||
):
|
||||
if api_tags is None:
|
||||
api_tags = [SERVE_APP_NAME_HUMP]
|
||||
super().__init__(
|
||||
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
|
||||
)
|
||||
self._db_manager: Optional[DatabaseManager] = None
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
if self._app_has_initiated:
|
||||
return
|
||||
self._system_app = system_app
|
||||
self._system_app.app.include_router(
|
||||
router, prefix=self._api_prefix, tags=self._api_tags
|
||||
)
|
||||
init_endpoints(self._system_app)
|
||||
self._app_has_initiated = True
|
||||
|
||||
def on_init(self):
|
||||
"""Called when init the application.
|
||||
|
||||
You can do some initialization here. You can't get other components here because they may be not initialized yet
|
||||
"""
|
||||
# import your own module here to ensure the module is loaded before the application starts
|
||||
from .models.models import ServeEntity
|
||||
|
||||
def before_start(self):
|
||||
"""Called before the start of the application."""
|
||||
# TODO: Your code here
|
||||
self._db_manager = self.create_or_get_db_manager()
|
0
dbgpt/serve/dbgpts/hub/service/__init__.py
Normal file
0
dbgpt/serve/dbgpts/hub/service/__init__.py
Normal file
214
dbgpt/serve/dbgpts/hub/service/service.py
Normal file
214
dbgpt/serve/dbgpts/hub/service/service.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from dbgpt.agent import PluginStorageType
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.serve.core import BaseService
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.util.dbgpts.repo import _install_default_repos_if_no_repos, list_dbgpts
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""The service class for DbgptsHub"""
|
||||
|
||||
name = SERVE_SERVICE_COMPONENT_NAME
|
||||
|
||||
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
|
||||
self._system_app = None
|
||||
self._serve_config: ServeConfig = None
|
||||
self._dao: ServeDao = dao
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp) -> None:
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
system_app (SystemApp): The system app
|
||||
"""
|
||||
super().init_app(system_app)
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
self._dao = self._dao or ServeDao(self._serve_config)
|
||||
self._system_app = system_app
|
||||
|
||||
@property
|
||||
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
|
||||
"""Returns the internal DAO."""
|
||||
return self._dao
|
||||
|
||||
@property
|
||||
def config(self) -> ServeConfig:
|
||||
"""Returns the internal ServeConfig."""
|
||||
return self._serve_config
|
||||
|
||||
def update(self, request: ServeRequest) -> ServerResponse:
|
||||
"""Update a DbgptsHub entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
# TODO: implement your own logic here
|
||||
# Build the query request from the request
|
||||
query_request = {"id": request.id}
|
||||
return self.dao.update(query_request, update_request=request)
|
||||
|
||||
def get(self, request: ServeRequest) -> Optional[ServerResponse]:
|
||||
"""Get a DbgptsHub entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
# TODO: implement your own logic here
|
||||
# Build the query request from the request
|
||||
query_request = request
|
||||
return self.dao.get_one(query_request)
|
||||
|
||||
def delete(self, request: ServeRequest) -> None:
|
||||
"""Delete a DbgptsHub entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
"""
|
||||
|
||||
# TODO: implement your own logic here
|
||||
# Build the query request from the request
|
||||
query_request = {
|
||||
# "id": request.id
|
||||
}
|
||||
self.dao.delete(query_request)
|
||||
|
||||
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
|
||||
"""Get a list of DbgptsHub entities
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
|
||||
Returns:
|
||||
List[ServerResponse]: The response
|
||||
"""
|
||||
# TODO: implement your own logic here
|
||||
# Build the query request from the request
|
||||
query_request = request
|
||||
return self.dao.get_list(query_request)
|
||||
|
||||
def get_list_by_page(
|
||||
self, request: ServeRequest, page: int, page_size: int
|
||||
) -> PaginationResult[ServerResponse]:
|
||||
"""Get a list of DbgptsHub entities by page
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
page (int): The page number
|
||||
page_size (int): The page size
|
||||
|
||||
Returns:
|
||||
List[ServerResponse]: The response
|
||||
"""
|
||||
query_request = ServeRequest(
|
||||
name=request.name,
|
||||
type=request.type,
|
||||
version=request.version,
|
||||
description=request.description,
|
||||
author=request.author,
|
||||
storage_channel=request.storage_channel,
|
||||
storage_url=request.storage_url,
|
||||
installed=request.installed,
|
||||
)
|
||||
|
||||
return self.dao.get_list_page(query_request, page, page_size)
|
||||
|
||||
def refresh_hub_from_git(
|
||||
self,
|
||||
github_repo: str = None,
|
||||
branch_name: str = "main",
|
||||
authorization: str = None,
|
||||
):
|
||||
logger.info("refresh_hub_by_git start!")
|
||||
_install_default_repos_if_no_repos()
|
||||
data: List[Tuple[str, str, str, str]] = list_dbgpts()
|
||||
|
||||
from dbgpt.util.dbgpts.base import get_repo_path
|
||||
from dbgpt.util.dbgpts.loader import (
|
||||
BasePackage,
|
||||
InstalledPackage,
|
||||
parse_package_metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
for repo, package, name, gpts_path in data:
|
||||
try:
|
||||
if not name:
|
||||
logger.info(
|
||||
f"dbgpts error repo:{repo}, package:{package}, name:{name}, gpts_path:{gpts_path}"
|
||||
)
|
||||
continue
|
||||
old_hub_info = self.get(ServeRequest(name=name, type=package))
|
||||
base_package: BasePackage = parse_package_metadata(
|
||||
InstalledPackage(
|
||||
name=name,
|
||||
repo=repo,
|
||||
root=str(gpts_path),
|
||||
package=package,
|
||||
)
|
||||
)
|
||||
if old_hub_info:
|
||||
self.dao.update(
|
||||
query_request=ServeRequest(
|
||||
name=old_hub_info.name, type=old_hub_info.type
|
||||
),
|
||||
update_request=ServeRequest(
|
||||
version=base_package.version,
|
||||
description=base_package.description,
|
||||
),
|
||||
)
|
||||
else:
|
||||
request = ServeRequest()
|
||||
request.type = package
|
||||
request.name = name
|
||||
request.storage_channel = repo
|
||||
request.storage_url = str(gpts_path)
|
||||
request.author = self._get_dbgpts_author(base_package.authors)
|
||||
request.email = self._get_dbgpts_email(base_package.authors)
|
||||
|
||||
request.download_param = None
|
||||
request.installed = 0
|
||||
request.version = base_package.version
|
||||
request.description = base_package.description
|
||||
self.create(request)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Load from git failed repo:{repo}, package:{package}, name:{name}, gpts_path:{gpts_path}",
|
||||
e,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
|
||||
|
||||
def _get_dbgpts_author(self, authors):
|
||||
pattern = r"(.+?)<"
|
||||
names = []
|
||||
for item in authors:
|
||||
names.extend(re.findall(pattern, item))
|
||||
return ",".join(names)
|
||||
|
||||
def _get_dbgpts_email(self, authors):
|
||||
pattern = r"<(.*?)>"
|
||||
emails: List[str] = []
|
||||
for item in authors:
|
||||
emails.extend(re.findall(pattern, item))
|
||||
return ",".join(emails)
|
0
dbgpt/serve/dbgpts/hub/tests/__init__.py
Normal file
0
dbgpt/serve/dbgpts/hub/tests/__init__.py
Normal file
124
dbgpt/serve/dbgpts/hub/tests/test_endpoints.py
Normal file
124
dbgpt/serve/dbgpts/hub/tests/test_endpoints.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import AsyncClient
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core.tests.conftest import asystem_app, client
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.util import PaginationResult
|
||||
|
||||
from ..api.endpoints import init_endpoints, router
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def client_init_caller(app: FastAPI, system_app: SystemApp):
|
||||
app.include_router(router)
|
||||
init_endpoints(system_app)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client, asystem_app, has_auth",
|
||||
[
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "test_token1",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "error_token",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
False,
|
||||
),
|
||||
],
|
||||
indirect=["client", "asystem_app"],
|
||||
)
|
||||
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
|
||||
response = await client.get("/test_auth")
|
||||
if has_auth:
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
else:
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {
|
||||
"detail": {
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_health(client: AsyncClient):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_create(client: AsyncClient):
|
||||
# TODO: add your test case
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_update(client: AsyncClient):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_query(client: AsyncClient):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_query_by_page(client: AsyncClient):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
# Add more test cases according to your own logic
|
103
dbgpt/serve/dbgpts/hub/tests/test_models.py
Normal file
103
dbgpt/serve/dbgpts/hub/tests/test_models.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# import pytest
|
||||
#
|
||||
# from dbgpt.storage.metadata import db
|
||||
#
|
||||
# from ..api.schemas import ServeRequest, ServerResponse
|
||||
# from ..config import ServeConfig
|
||||
# from ..models.models import ServeDao, ServeEntity
|
||||
#
|
||||
#
|
||||
# @pytest.fixture(autouse=True)
|
||||
# def setup_and_teardown():
|
||||
# db.init_db("sqlite:///:memory:")
|
||||
# db.create_all()
|
||||
#
|
||||
# yield
|
||||
#
|
||||
#
|
||||
# @pytest.fixture
|
||||
# def server_config():
|
||||
# # TODO : build your server config
|
||||
# return ServeConfig()
|
||||
#
|
||||
#
|
||||
# @pytest.fixture
|
||||
# def dao(server_config):
|
||||
# return ServeDao(server_config)
|
||||
#
|
||||
#
|
||||
# @pytest.fixture
|
||||
# def default_entity_dict():
|
||||
# # TODO: build your default entity dict
|
||||
# return {}
|
||||
#
|
||||
#
|
||||
# def test_table_exist():
|
||||
# assert ServeEntity.__tablename__ in db.metadata.tables
|
||||
#
|
||||
#
|
||||
# def test_entity_create(default_entity_dict):
|
||||
# with db.session() as session:
|
||||
# entity = ServeEntity(**default_entity_dict)
|
||||
# session.add(entity)
|
||||
#
|
||||
#
|
||||
# def test_entity_unique_key(default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_entity_get(default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_entity_update(default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_entity_delete(default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_entity_all():
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_dao_create(dao, default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# req = ServeRequest(**default_entity_dict)
|
||||
# res: ServerResponse = dao.create(req)
|
||||
# assert res is not None
|
||||
#
|
||||
#
|
||||
# def test_dao_get_one(dao, default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# req = ServeRequest(**default_entity_dict)
|
||||
# res: ServerResponse = dao.create(req)
|
||||
#
|
||||
#
|
||||
# def test_get_dao_get_list(dao):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_dao_update(dao, default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_dao_delete(dao, default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_dao_get_list_page(dao):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# # Add more test cases according to your own logic
|
78
dbgpt/serve/dbgpts/hub/tests/test_service.py
Normal file
78
dbgpt/serve/dbgpts/hub/tests/test_service.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core.tests.conftest import system_app
|
||||
from dbgpt.storage.metadata import db
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..models.models import ServeEntity
|
||||
from ..service.service import Service
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(system_app: SystemApp):
|
||||
instance = Service(system_app)
|
||||
instance.init_app(system_app)
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_entity_dict():
|
||||
# TODO: build your default entity dict
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"system_app",
|
||||
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_config_exists(service: Service):
|
||||
system_app: SystemApp = service._system_app
|
||||
assert system_app.config.get("DEBUG") is True
|
||||
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
|
||||
assert service.config is not None
|
||||
|
||||
|
||||
def test_service_create(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
|
||||
# ...
|
||||
pass
|
||||
|
||||
|
||||
def test_service_update(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_get(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_delete(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_get_list(service: Service):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_get_list_by_page(service: Service):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
# Add more test cases according to your own logic
|
2
dbgpt/serve/dbgpts/my/__init__.py
Normal file
2
dbgpt/serve/dbgpts/my/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# This is an auto-generated __init__.py file
|
||||
# generated by `dbgpt new serve dbgpts_my`
|
2
dbgpt/serve/dbgpts/my/api/__init__.py
Normal file
2
dbgpt/serve/dbgpts/my/api/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# This is an auto-generated __init__.py file
|
||||
# generated by `dbgpt new serve dbgpts_my`
|
203
dbgpt/serve/dbgpts/my/api/endpoints.py
Normal file
203
dbgpt/serve/dbgpts/my/api/endpoints.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import logging
|
||||
from functools import cache
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import Result, blocking_func_to_async
|
||||
from dbgpt.util import PaginationResult
|
||||
|
||||
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..service.service import Service
|
||||
from .schemas import ServeRequest, ServerResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Add your API endpoints here
|
||||
|
||||
global_system_app: Optional[SystemApp] = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_service() -> Service:
|
||||
"""Get the service instance"""
|
||||
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
|
||||
|
||||
|
||||
get_bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
@cache
|
||||
def _parse_api_keys(api_keys: str) -> List[str]:
|
||||
"""Parse the string api keys to a list
|
||||
|
||||
Args:
|
||||
api_keys (str): The string api keys
|
||||
|
||||
Returns:
|
||||
List[str]: The list of api keys
|
||||
"""
|
||||
if not api_keys:
|
||||
return []
|
||||
return [key.strip() for key in api_keys.split(",")]
|
||||
|
||||
|
||||
async def check_api_key(
|
||||
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Optional[str]:
|
||||
"""Check the api key
|
||||
|
||||
If the api key is not set, allow all.
|
||||
|
||||
Your can pass the token in you request header like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import requests
|
||||
|
||||
client_api_key = "your_api_key"
|
||||
headers = {"Authorization": "Bearer " + client_api_key}
|
||||
res = requests.get("http://test/hello", headers=headers)
|
||||
assert res.status_code == 200
|
||||
|
||||
"""
|
||||
if service.config.api_keys:
|
||||
api_keys = _parse_api_keys(service.config.api_keys)
|
||||
if auth is None or (token := auth.credentials) not in api_keys:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
},
|
||||
)
|
||||
return token
|
||||
else:
|
||||
# api_keys not set; allow all
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
|
||||
async def test_auth():
|
||||
"""Test auth endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def create(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
"""Create a new DbgptsMy entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.create(request))
|
||||
|
||||
|
||||
@router.put(
|
||||
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def update(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
"""Update a DbgptsMy entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.update(request))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/query",
|
||||
response_model=Result[ServerResponse],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
"""Query DbgptsMy entities
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.get(request))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/query_page",
|
||||
response_model=Result[PaginationResult[ServerResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query_page(
|
||||
request: ServeRequest,
|
||||
page: Optional[int] = Query(default=1, description="current page"),
|
||||
page_size: Optional[int] = Query(default=20, description="page size"),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Result[PaginationResult[ServerResponse]]:
|
||||
"""Query DbgptsMy entities
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
page (int): The page number
|
||||
page_size (int): The page size
|
||||
service (Service): The service
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
return Result.succ(service.get_list_by_page(request, page, page_size))
|
||||
|
||||
|
||||
@router.post("/uninstall", response_model=Result[str])
|
||||
async def agent_uninstall(
|
||||
name: str,
|
||||
type=str,
|
||||
user: Optional[str] = None,
|
||||
service: Service = Depends(get_service),
|
||||
):
|
||||
logger.info(f"dbgpts uninstall:{name},{user}")
|
||||
try:
|
||||
await blocking_func_to_async(
|
||||
global_system_app,
|
||||
service.uninstall_gpts,
|
||||
name=name,
|
||||
type=type,
|
||||
user_name=user,
|
||||
)
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
logger.error("Plugin Uninstall Error!", e)
|
||||
return Result.failed(err_code="E0022", msg=f"Plugin Uninstall Error {e}")
|
||||
|
||||
|
||||
def init_endpoints(system_app: SystemApp) -> None:
|
||||
"""Initialize the endpoints"""
|
||||
global global_system_app
|
||||
system_app.register(Service)
|
||||
global_system_app = system_app
|
31
dbgpt/serve/dbgpts/my/api/schemas.py
Normal file
31
dbgpt/serve/dbgpts/my/api/schemas.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Define your Pydantic schemas here
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
|
||||
|
||||
from ..config import SERVE_APP_NAME_HUMP
|
||||
|
||||
|
||||
class ServeRequest(BaseModel):
|
||||
"""DbgptsMy request model"""
|
||||
|
||||
id: Optional[int] = Field(None, description="id")
|
||||
user_name: Optional[str] = Field(None, description="My gpts user name")
|
||||
sys_code: Optional[str] = Field(None, description="My gpts sys code")
|
||||
name: Optional[str] = Field(None, description="My gpts name")
|
||||
file_name: Optional[str] = Field(None, description="My gpts file name")
|
||||
type: Optional[str] = Field(None, description="My gpts type")
|
||||
version: Optional[str] = Field(None, description="My gpts version")
|
||||
use_count: Optional[int] = Field(None, description="My gpts use count")
|
||||
succ_count: Optional[int] = Field(None, description="My gpts succ count")
|
||||
|
||||
model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}")
|
||||
|
||||
def to_dict(self, **kwargs) -> Dict[str, Any]:
|
||||
"""Convert the model to a dictionary"""
|
||||
return model_to_dict(self, **kwargs)
|
||||
|
||||
|
||||
class ServerResponse(ServeRequest):
|
||||
gmt_created: Optional[str] = Field(None, description="Dbgpts create time")
|
||||
gmt_modified: Optional[str] = Field(None, description="Dbgpts upload time")
|
22
dbgpt/serve/dbgpts/my/config.py
Normal file
22
dbgpt/serve/dbgpts/my/config.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.serve.core import BaseServeConfig
|
||||
|
||||
APP_NAME = "dbgpts_my"
|
||||
SERVE_APP_NAME = "dbgpt_serve_dbgpts_my"
|
||||
SERVE_APP_NAME_HUMP = "dbgpt_serve_DbgptsMy"
|
||||
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.dbgpts_my."
|
||||
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
|
||||
# Database table name
|
||||
SERVER_APP_TABLE_NAME = "dbgpts_my"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServeConfig(BaseServeConfig):
|
||||
"""Parameters for the serve command"""
|
||||
|
||||
# TODO: add your own parameters here
|
||||
api_keys: Optional[str] = field(
|
||||
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
|
||||
)
|
1
dbgpt/serve/dbgpts/my/dependencies.py
Normal file
1
dbgpt/serve/dbgpts/my/dependencies.py
Normal file
@@ -0,0 +1 @@
|
||||
# Define your dependencies here
|
2
dbgpt/serve/dbgpts/my/models/__init__.py
Normal file
2
dbgpt/serve/dbgpts/my/models/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# This is an auto-generated __init__.py file
|
||||
# generated by `dbgpt new serve dbgpts_my`
|
109
dbgpt/serve/dbgpts/my/models/models.py
Normal file
109
dbgpt/serve/dbgpts/my/models/models.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""This is an auto-generated model file
|
||||
You can define your own models and DAOs here
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model, db
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
|
||||
|
||||
|
||||
class ServeEntity(Model):
|
||||
__tablename__ = SERVER_APP_TABLE_NAME
|
||||
id = Column(Integer, primary_key=True, comment="autoincrement id")
|
||||
name = Column(String(255), unique=True, nullable=False, comment="gpts name")
|
||||
type = Column(String(255), nullable=False, comment="gpts type")
|
||||
version = Column(String(255), nullable=False, comment="gpts version")
|
||||
user_name = Column(String(255), nullable=True, comment="user name")
|
||||
file_name = Column(String(255), nullable=True, comment="gpts package file name")
|
||||
use_count = Column(
|
||||
Integer, nullable=True, default=0, comment="gpts total use count"
|
||||
)
|
||||
succ_count = Column(
|
||||
Integer, nullable=True, default=0, comment="gpts total success count"
|
||||
)
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
gmt_created = Column(DateTime, default=datetime.utcnow, comment="gpts install time")
|
||||
gmt_modified = Column(
|
||||
DateTime,
|
||||
default=datetime.now,
|
||||
onupdate=datetime.utcnow,
|
||||
comment="Record update time",
|
||||
)
|
||||
UniqueConstraint("user_code", "name", name="uk_name")
|
||||
|
||||
|
||||
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""The DAO class for MyDbgpts"""
|
||||
|
||||
def __init__(self, serve_config: ServeConfig):
|
||||
super().__init__()
|
||||
self._serve_config = serve_config
|
||||
|
||||
def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
|
||||
"""Convert the request to an entity
|
||||
|
||||
Args:
|
||||
request (Union[MyGptsServeRequest, Dict[str, Any]]): The request
|
||||
|
||||
Returns:
|
||||
T: The entity
|
||||
"""
|
||||
request_dict = (
|
||||
request.to_dict() if isinstance(request, ServeRequest) else request
|
||||
)
|
||||
entity = ServeEntity(**request_dict)
|
||||
|
||||
return entity
|
||||
|
||||
def to_request(self, entity: ServeEntity) -> ServeRequest:
|
||||
"""Convert the entity to a request
|
||||
|
||||
Args:
|
||||
entity (T): The entity
|
||||
|
||||
Returns:
|
||||
REQ: The request
|
||||
"""
|
||||
return ServeRequest(
|
||||
id=entity.id,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
name=entity.name,
|
||||
file_name=entity.file_name,
|
||||
type=entity.type,
|
||||
version=entity.version,
|
||||
use_count=entity.use_count,
|
||||
succ_count=entity.succ_count,
|
||||
)
|
||||
|
||||
def to_response(self, entity: ServeEntity) -> ServerResponse:
|
||||
"""Convert the entity to a response
|
||||
|
||||
Args:
|
||||
entity (T): The entity
|
||||
|
||||
Returns:
|
||||
RES: The response
|
||||
"""
|
||||
gmt_created_str = (
|
||||
entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if entity.gmt_created
|
||||
else ""
|
||||
)
|
||||
gmt_modified_str = (
|
||||
entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if entity.gmt_modified
|
||||
else ""
|
||||
)
|
||||
request = self.to_request(entity)
|
||||
|
||||
return ServerResponse(
|
||||
**request.to_dict(),
|
||||
gmt_created=gmt_created_str,
|
||||
gmt_modified=gmt_modified_str,
|
||||
)
|
63
dbgpt/serve/dbgpts/my/serve.py
Normal file
63
dbgpt/serve/dbgpts/my/serve.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sqlalchemy import URL
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import BaseServe
|
||||
from dbgpt.storage.metadata import DatabaseManager
|
||||
|
||||
from .api.endpoints import init_endpoints, router
|
||||
from .config import (
|
||||
APP_NAME,
|
||||
SERVE_APP_NAME,
|
||||
SERVE_APP_NAME_HUMP,
|
||||
SERVE_CONFIG_KEY_PREFIX,
|
||||
ServeConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Serve(BaseServe):
|
||||
"""Serve component for DB-GPT"""
|
||||
|
||||
name = SERVE_APP_NAME
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_app: SystemApp,
|
||||
api_prefix: Optional[str] = f"/api/v1/serve/dbgpts/my",
|
||||
api_tags: Optional[List[str]] = None,
|
||||
db_url_or_db: Union[str, URL, DatabaseManager] = None,
|
||||
try_create_tables: Optional[bool] = False,
|
||||
):
|
||||
if api_tags is None:
|
||||
api_tags = [SERVE_APP_NAME_HUMP]
|
||||
super().__init__(
|
||||
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
|
||||
)
|
||||
self._db_manager: Optional[DatabaseManager] = None
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
if self._app_has_initiated:
|
||||
return
|
||||
self._system_app = system_app
|
||||
self._system_app.app.include_router(
|
||||
router, prefix=self._api_prefix, tags=self._api_tags
|
||||
)
|
||||
init_endpoints(self._system_app)
|
||||
self._app_has_initiated = True
|
||||
|
||||
def on_init(self):
|
||||
"""Called when init the application.
|
||||
|
||||
You can do some initialization here. You can't get other components here because they may be not initialized yet
|
||||
"""
|
||||
# import your own module here to ensure the module is loaded before the application starts
|
||||
from .models.models import ServeEntity
|
||||
|
||||
def before_start(self):
|
||||
"""Called before the start of the application."""
|
||||
# TODO: Your code here
|
||||
self._db_manager = self.create_or_get_db_manager()
|
0
dbgpt/serve/dbgpts/my/service/__init__.py
Normal file
0
dbgpt/serve/dbgpts/my/service/__init__.py
Normal file
191
dbgpt/serve/dbgpts/my/service/service.py
Normal file
191
dbgpt/serve/dbgpts/my/service/service.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.serve.core import BaseService
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.util.dbgpts.base import INSTALL_DIR
|
||||
from dbgpt.util.dbgpts.repo import (
|
||||
copy_and_install,
|
||||
inner_copy_and_install,
|
||||
inner_uninstall,
|
||||
install,
|
||||
uninstall,
|
||||
)
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""The service class for DbgptsMy"""
|
||||
|
||||
name = SERVE_SERVICE_COMPONENT_NAME
|
||||
|
||||
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
|
||||
self._system_app = None
|
||||
self._serve_config: ServeConfig = None
|
||||
self._dao: ServeDao = dao
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp) -> None:
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
system_app (SystemApp): The system app
|
||||
"""
|
||||
super().init_app(system_app)
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
self._dao = self._dao or ServeDao(self._serve_config)
|
||||
self._system_app = system_app
|
||||
|
||||
@property
|
||||
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
|
||||
"""Returns the internal DAO."""
|
||||
return self._dao
|
||||
|
||||
@property
|
||||
def config(self) -> ServeConfig:
|
||||
"""Returns the internal ServeConfig."""
|
||||
return self._serve_config
|
||||
|
||||
def update(self, request: ServeRequest) -> ServerResponse:
|
||||
"""Update a DbgptsMy entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
|
||||
# Build the query request from the request
|
||||
query_request = {"id": request.id}
|
||||
return self.dao.update(query_request, update_request=request)
|
||||
|
||||
def get(self, request: ServeRequest) -> Optional[ServerResponse]:
|
||||
"""Get a DbgptsMy entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
# Build the query request from the request
|
||||
query_request = request
|
||||
return self.dao.get_one(query_request)
|
||||
|
||||
def delete(self, request: ServeRequest) -> None:
|
||||
"""Delete a DbgptsMy entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
"""
|
||||
|
||||
# TODO: implement your own logic here
|
||||
# Build the query request from the request
|
||||
|
||||
self.dao.delete(request)
|
||||
|
||||
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
|
||||
"""Get a list of DbgptsMy entities
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
|
||||
Returns:
|
||||
List[ServerResponse]: The response
|
||||
"""
|
||||
# Build the query request from the request
|
||||
query_request = request
|
||||
return self.dao.get_list(query_request)
|
||||
|
||||
def get_list_by_page(
|
||||
self, request: ServeRequest, page: int, page_size: int
|
||||
) -> PaginationResult[ServerResponse]:
|
||||
"""Get a list of DbgptsMy entities by page
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
page (int): The page number
|
||||
page_size (int): The page size
|
||||
|
||||
Returns:
|
||||
List[ServerResponse]: The response
|
||||
"""
|
||||
query_request = request
|
||||
return self.dao.get_list_page(query_request, page, page_size)
|
||||
|
||||
def install_gpts(
|
||||
self,
|
||||
name: str,
|
||||
type: str,
|
||||
repo: str,
|
||||
dbgpt_path: str,
|
||||
user_name: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
):
|
||||
logger.info(f"install_gpts {name}")
|
||||
|
||||
# install(name, repo)
|
||||
try:
|
||||
from pathlib import Path
|
||||
|
||||
inner_copy_and_install(repo, name, Path(dbgpt_path))
|
||||
except Exception as e:
|
||||
logger.exception(f"install_gpts failed!{str(e)}")
|
||||
raise ValueError(f"Install dbgpts [{type}:{name}] Failed! {str(e)}", e)
|
||||
|
||||
from dbgpt.util.dbgpts.base import get_repo_path
|
||||
from dbgpt.util.dbgpts.loader import (
|
||||
BasePackage,
|
||||
InstalledPackage,
|
||||
parse_package_metadata,
|
||||
)
|
||||
|
||||
base_package: BasePackage = parse_package_metadata(
|
||||
InstalledPackage(
|
||||
name=name,
|
||||
repo=repo,
|
||||
root=dbgpt_path,
|
||||
package=type,
|
||||
)
|
||||
)
|
||||
dbgpts_entity = self.get(ServeRequest(name=name, type=type))
|
||||
|
||||
if not dbgpts_entity:
|
||||
request = ServeRequest()
|
||||
request.name = name
|
||||
|
||||
request.user_name = user_name
|
||||
request.sys_code = sys_code
|
||||
request.type = type
|
||||
request.file_name = str(INSTALL_DIR / name)
|
||||
request.version = base_package.version
|
||||
return self.create(request)
|
||||
else:
|
||||
dbgpts_entity.version = base_package.version
|
||||
|
||||
return self.update(ServeRequest(**dbgpts_entity.to_dict()))
|
||||
|
||||
def uninstall_gpts(
|
||||
self,
|
||||
name: str,
|
||||
type: str,
|
||||
user_name: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
):
|
||||
logger.info(f"install_gpts {name}")
|
||||
try:
|
||||
inner_uninstall(name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Uninstall dbgpts [{type}:{name}] Failed! {str(e)}", e)
|
||||
raise ValueError(f"Uninstall dbgpts [{type}:{name}] Failed! {str(e)}", e)
|
||||
self.delete(ServeRequest(name=name, type=type))
|
0
dbgpt/serve/dbgpts/my/tests/__init__.py
Normal file
0
dbgpt/serve/dbgpts/my/tests/__init__.py
Normal file
124
dbgpt/serve/dbgpts/my/tests/test_endpoints.py
Normal file
124
dbgpt/serve/dbgpts/my/tests/test_endpoints.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import AsyncClient
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core.tests.conftest import asystem_app, client
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.util import PaginationResult
|
||||
|
||||
from ..api.endpoints import init_endpoints, router
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def client_init_caller(app: FastAPI, system_app: SystemApp):
|
||||
app.include_router(router)
|
||||
init_endpoints(system_app)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client, asystem_app, has_auth",
|
||||
[
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "test_token1",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "error_token",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
False,
|
||||
),
|
||||
],
|
||||
indirect=["client", "asystem_app"],
|
||||
)
|
||||
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
|
||||
response = await client.get("/test_auth")
|
||||
if has_auth:
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
else:
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {
|
||||
"detail": {
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_health(client: AsyncClient):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_create(client: AsyncClient):
|
||||
# TODO: add your test case
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_update(client: AsyncClient):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_query(client: AsyncClient):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_query_by_page(client: AsyncClient):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
# Add more test cases according to your own logic
|
103
dbgpt/serve/dbgpts/my/tests/test_models.py
Normal file
103
dbgpt/serve/dbgpts/my/tests/test_models.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# import pytest
|
||||
#
|
||||
# from dbgpt.storage.metadata import db
|
||||
#
|
||||
# from ..api.schemas import ServeRequest, ServerResponse
|
||||
# from ..config import ServeConfig
|
||||
# from ..models.models import ServeDao, ServeEntity
|
||||
#
|
||||
#
|
||||
# @pytest.fixture(autouse=True)
|
||||
# def setup_and_teardown():
|
||||
# db.init_db("sqlite:///:memory:")
|
||||
# db.create_all()
|
||||
#
|
||||
# yield
|
||||
#
|
||||
#
|
||||
# @pytest.fixture
|
||||
# def server_config():
|
||||
# # TODO : build your server config
|
||||
# return ServeConfig()
|
||||
#
|
||||
#
|
||||
# @pytest.fixture
|
||||
# def dao(server_config):
|
||||
# return ServeDao(server_config)
|
||||
#
|
||||
#
|
||||
# @pytest.fixture
|
||||
# def default_entity_dict():
|
||||
# # TODO: build your default entity dict
|
||||
# return {}
|
||||
#
|
||||
#
|
||||
# def test_table_exist():
|
||||
# assert ServeEntity.__tablename__ in db.metadata.tables
|
||||
#
|
||||
#
|
||||
# def test_entity_create(default_entity_dict):
|
||||
# with db.session() as session:
|
||||
# entity = ServeEntity(**default_entity_dict)
|
||||
# session.add(entity)
|
||||
#
|
||||
#
|
||||
# def test_entity_unique_key(default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_entity_get(default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_entity_update(default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_entity_delete(default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_entity_all():
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_dao_create(dao, default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# req = ServeRequest(**default_entity_dict)
|
||||
# res: ServerResponse = dao.create(req)
|
||||
# assert res is not None
|
||||
#
|
||||
#
|
||||
# def test_dao_get_one(dao, default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# req = ServeRequest(**default_entity_dict)
|
||||
# res: ServerResponse = dao.create(req)
|
||||
#
|
||||
#
|
||||
# def test_get_dao_get_list(dao):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_dao_update(dao, default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_dao_delete(dao, default_entity_dict):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def test_dao_get_list_page(dao):
|
||||
# # TODO: implement your test case
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# # Add more test cases according to your own logic
|
78
dbgpt/serve/dbgpts/my/tests/test_service.py
Normal file
78
dbgpt/serve/dbgpts/my/tests/test_service.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core.tests.conftest import system_app
|
||||
from dbgpt.storage.metadata import db
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..models.models import ServeEntity
|
||||
from ..service.service import Service
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(system_app: SystemApp):
|
||||
instance = Service(system_app)
|
||||
instance.init_app(system_app)
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_entity_dict():
|
||||
# TODO: build your default entity dict
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"system_app",
|
||||
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_config_exists(service: Service):
|
||||
system_app: SystemApp = service._system_app
|
||||
assert system_app.config.get("DEBUG") is True
|
||||
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
|
||||
assert service.config is not None
|
||||
|
||||
|
||||
def test_service_create(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
|
||||
# ...
|
||||
pass
|
||||
|
||||
|
||||
def test_service_update(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_get(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_delete(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_get_list(service: Service):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_get_list_by_page(service: Service):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
# Add more test cases according to your own logic
|
@@ -623,7 +623,6 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
async def _wrapper_chat_stream_flow_str(
|
||||
self, stream_iter: AsyncIterator[ModelOutput]
|
||||
) -> AsyncIterator[str]:
|
||||
|
||||
async for output in stream_iter:
|
||||
text = output.text
|
||||
if text:
|
||||
|
@@ -15,7 +15,6 @@ from ..api.schemas import ServeRequest
|
||||
|
||||
|
||||
def _generate_dbgpts_zip(package_name: str, flow: ServeRequest) -> io.BytesIO:
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
flow_name = flow.name
|
||||
flow_label = flow.label
|
||||
|
Reference in New Issue
Block a user