feat(dbgpts): new dgpts serve module

This commit is contained in:
yhjun1026
2024-08-22 18:38:37 +08:00
parent 3a32344380
commit 1133ec170d
32 changed files with 1805 additions and 5 deletions

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve dbgpts_hub`

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve dbgpts_hub`

View File

@@ -0,0 +1,178 @@
from functools import cache
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import ServeRequest, ServerResponse
router = APIRouter()
# Add your API endpoints here
global_system_app: Optional[SystemApp] = None
def get_service() -> Service:
"""Get the service instance"""
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
Your can pass the token in you request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key}
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
@router.post(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def create(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Create a new DbgptsHub entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.create(request))
@router.put(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def update(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Update a DbgptsHub entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.update(request))
@router.post(
"/query",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Query DbgptsHub entities
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.get(request))
@router.post(
"/query_page",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),
page_size: Optional[int] = Query(default=20, description="page size"),
service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
"""Query DbgptsHub entities
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.get_list_by_page(request, page, page_size))
def init_endpoints(system_app: SystemApp) -> None:
"""Initialize the endpoints"""
global global_system_app
system_app.register(Service)
global_system_app = system_app

View File

@@ -0,0 +1,32 @@
# Define your Pydantic schemas here
from typing import Any, Dict, Optional
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
"""DbgptsHub request model"""
id: Optional[int] = Field(None, description="id")
name: Optional[str] = Field(..., description="Dbgpts name")
description: Optional[str] = Field(..., description="Dbgpts description")
author: Optional[str] = Field(None, description="Dbgpts author")
email: Optional[str] = Field(None, description="Dbgpts email")
type: Optional[str] = Field(None, description="Dbgpts type")
version: Optional[str] = Field(None, description="Dbgpts version")
storage_channel: Optional[str] = Field(None, description="Dbgpts storage channel")
storage_url: Optional[str] = Field(None, description="Dbgpts storage url")
download_param: Optional[str] = Field(None, description="Dbgpts download param")
installed: Optional[int] = Field(None, description="Dbgpts installed")
gmt_created: Optional[str] = Field(None, description="Dbgpts upload time")
model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}")
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert the model to a dictionary"""
return model_to_dict(self, **kwargs)
ServerResponse = ServeRequest

View File

@@ -0,0 +1,22 @@
from dataclasses import dataclass, field
from typing import Optional
from dbgpt.serve.core import BaseServeConfig
APP_NAME = "dbgpts_hub"
SERVE_APP_NAME = "dbgpt_serve_dbgpts_hub"
SERVE_APP_NAME_HUMP = "dbgpt_serve_DbgptsHub"
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.dbgpts_hub."
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
# Database table name
SERVER_APP_TABLE_NAME = "dbgpts_hub"
@dataclass
class ServeConfig(BaseServeConfig):
"""Parameters for the serve command"""
# TODO: add your own parameters here
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)

View File

@@ -0,0 +1 @@
# Define your dependencies here

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve dbgpts_hub`

View File

@@ -0,0 +1,100 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
from datetime import datetime
from typing import Any, Dict, Union
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model, db
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
class ServeEntity(Model):
__tablename__ = SERVER_APP_TABLE_NAME
id = Column(Integer, primary_key=True, comment="Auto increment id")
name = Column(String(255), unique=True, nullable=False, comment="dbgpts name")
description = Column(String(255), nullable=False, comment="dbgpts description")
author = Column(String(255), nullable=True, comment="dbgpts author")
email = Column(String(255), nullable=True, comment="dbgpts author email")
type = Column(String(255), comment="dbgpts type")
version = Column(String(255), comment="dbgpts version")
storage_channel = Column(String(255), comment="dbgpts storage channel")
storage_url = Column(String(255), comment="dbgpts download url")
download_param = Column(String(255), comment="dbgpts download param")
gmt_created = Column(
DateTime,
default=datetime.utcnow,
comment="plugin upload time",
)
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
installed = Column(Integer, default=False, comment="plugin already installed count")
UniqueConstraint("name", "type", name="uk_dbgpts")
Index("idx_q_type", "type")
def __repr__(self):
return f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"""The DAO class for Dbgpts"""
def __init__(self, serve_config: ServeConfig):
super().__init__()
self._serve_config = serve_config
def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
"""Convert the request to an entity
Args:
request (Union[ServeRequest, Dict[str, Any]]): The request
Returns:
T: The entity
"""
request_dict = (
request.to_dict() if isinstance(request, ServeRequest) else request
)
entity = ServeEntity(**request_dict)
return entity
def to_request(self, entity: ServeEntity) -> ServeRequest:
"""Convert the entity to a request
Args:
entity (T): The entity
Returns:
REQ: The request
"""
return ServeRequest(
id=entity.id,
name=entity.name,
description=entity.description,
author=entity.author,
email=entity.email,
type=entity.type,
version=entity.version,
storage_channel=entity.storage_channel,
storage_url=entity.storage_url,
download_param=entity.download_param,
installed=entity.installed,
gmt_created=entity.gmt_created,
)
def to_response(self, entity: ServeEntity) -> ServerResponse:
"""Convert the entity to a response
Args:
entity (T): The entity
Returns:
RES: The response
"""
return self.to_request(entity)

View File

@@ -0,0 +1,63 @@
import logging
from typing import List, Optional, Union
from sqlalchemy import URL
from dbgpt.component import SystemApp
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
logger = logging.getLogger(__name__)
class Serve(BaseServe):
"""Serve component for DB-GPT"""
name = SERVE_APP_NAME
def __init__(
self,
system_app: SystemApp,
api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}",
api_tags: Optional[List[str]] = None,
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):
if api_tags is None:
api_tags = [SERVE_APP_NAME_HUMP]
super().__init__(
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
)
self._db_manager: Optional[DatabaseManager] = None
def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
return
self._system_app = system_app
self._system_app.app.include_router(
router, prefix=self._api_prefix, tags=self._api_tags
)
init_endpoints(self._system_app)
self._app_has_initiated = True
def on_init(self):
"""Called when init the application.
You can do some initialization here. You can't get other components here because they may be not initialized yet
"""
# import your own module here to ensure the module is loaded before the application starts
from .models.models import ServeEntity
def before_start(self):
"""Called before the start of the application."""
# TODO: Your code here
self._db_manager = self.create_or_get_db_manager()

View File

@@ -0,0 +1,180 @@
import logging
from typing import List, Optional, Tuple
from dbgpt.agent import PluginStorageType
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.util.dbgpts.repo import _install_default_repos_if_no_repos, list_dbgpts
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
logger = logging.getLogger(__name__)
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""The service class for DbgptsHub"""
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = dao
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
"""Initialize the service
Args:
system_app (SystemApp): The system app
"""
super().init_app(system_app)
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
"""Returns the internal DAO."""
return self._dao
@property
def config(self) -> ServeConfig:
"""Returns the internal ServeConfig."""
return self._serve_config
def update(self, request: ServeRequest) -> ServerResponse:
"""Update a DbgptsHub entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = {
# "id": request.id
}
return self.dao.update(query_request, update_request=request)
def get(self, request: ServeRequest) -> Optional[ServerResponse]:
"""Get a DbgptsHub entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = request
return self.dao.get_one(query_request)
def delete(self, request: ServeRequest) -> None:
"""Delete a DbgptsHub entity
Args:
request (ServeRequest): The request
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = {
# "id": request.id
}
self.dao.delete(query_request)
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
"""Get a list of DbgptsHub entities
Args:
request (ServeRequest): The request
Returns:
List[ServerResponse]: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = request
return self.dao.get_list(query_request)
def get_list_by_page(
self, request: ServeRequest, page: int, page_size: int
) -> PaginationResult[ServerResponse]:
"""Get a list of DbgptsHub entities by page
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
Returns:
List[ServerResponse]: The response
"""
query_request = request
return self.dao.get_list_page(query_request, page, page_size)
async def refresh_hub_from_git(
self,
github_repo: str = None,
branch_name: str = "main",
authorization: str = None,
):
logger.info("refresh_hub_by_git start!")
_install_default_repos_if_no_repos()
data: List[Tuple[str, str, str, str]] = await list_dbgpts()
from dbgpt.util.dbgpts.base import get_repo_path
from dbgpt.util.dbgpts.loader import (
BasePackage,
InstalledPackage,
parse_package_metadata,
)
try:
for repo, package, name, gpts_path in data:
old_hub_info = self.dao.get_one(ServeRequest(name=name, type=package))
base_package: BasePackage = parse_package_metadata(
InstalledPackage(
name=name,
repo=repo,
root=get_repo_path(repo),
package=package,
)
)
if old_hub_info:
self.dao.update(
query_request=ServeRequest(
name=old_hub_info.name, type=old_hub_info.type
),
update_request=ServeRequest(
version=base_package.version,
description=base_package.description,
),
)
else:
request = ServeRequest()
request.type = package
request.name = name
request.storage_channel = repo
request.storage_url = gpts_path
request.author = ",".join(base_package.authors)
request.email = ",".join(base_package.authors)
request.download_param = None
request.installed = 0
request.version = base_package.version
request.description = base_package.description
self.dao.create(request)
except Exception as e:
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")

View File

View File

@@ -0,0 +1,124 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import asystem_app, client
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..api.endpoints import init_endpoints, router
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
def client_init_caller(app: FastAPI, system_app: SystemApp):
app.include_router(router)
init_endpoints(system_app)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, asystem_app, has_auth",
[
(
{
"app_caller": client_init_caller,
"client_api_key": "test_token1",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
True,
),
(
{
"app_caller": client_init_caller,
"client_api_key": "error_token",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
False,
),
],
indirect=["client", "asystem_app"],
)
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
response = await client.get("/test_auth")
if has_auth:
assert response.status_code == 200
assert response.json() == {"status": "ok"}
else:
assert response.status_code == 401
assert response.json() == {
"detail": {
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
}
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_health(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
# TODO: add your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,103 @@
# import pytest
#
# from dbgpt.storage.metadata import db
#
# from ..api.schemas import ServeRequest, ServerResponse
# from ..config import ServeConfig
# from ..models.models import ServeDao, ServeEntity
#
#
# @pytest.fixture(autouse=True)
# def setup_and_teardown():
# db.init_db("sqlite:///:memory:")
# db.create_all()
#
# yield
#
#
# @pytest.fixture
# def server_config():
# # TODO : build your server config
# return ServeConfig()
#
#
# @pytest.fixture
# def dao(server_config):
# return ServeDao(server_config)
#
#
# @pytest.fixture
# def default_entity_dict():
# # TODO: build your default entity dict
# return {}
#
#
# def test_table_exist():
# assert ServeEntity.__tablename__ in db.metadata.tables
#
#
# def test_entity_create(default_entity_dict):
# with db.session() as session:
# entity = ServeEntity(**default_entity_dict)
# session.add(entity)
#
#
# def test_entity_unique_key(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_get(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_update(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_delete(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_all():
# # TODO: implement your test case
# pass
#
#
# def test_dao_create(dao, default_entity_dict):
# # TODO: implement your test case
# req = ServeRequest(**default_entity_dict)
# res: ServerResponse = dao.create(req)
# assert res is not None
#
#
# def test_dao_get_one(dao, default_entity_dict):
# # TODO: implement your test case
# req = ServeRequest(**default_entity_dict)
# res: ServerResponse = dao.create(req)
#
#
# def test_get_dao_get_list(dao):
# # TODO: implement your test case
# pass
#
#
# def test_dao_update(dao, default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_dao_delete(dao, default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_dao_get_list_page(dao):
# # TODO: implement your test case
# pass
#
#
# # Add more test cases according to your own logic

View File

@@ -0,0 +1,78 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import system_app
from dbgpt.storage.metadata import db
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity
from ..service.service import Service
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def service(system_app: SystemApp):
instance = Service(system_app)
instance.init_app(system_app)
return instance
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
@pytest.mark.parametrize(
"system_app",
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
indirect=True,
)
def test_config_exists(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
assert service.config is not None
def test_service_create(service: Service, default_entity_dict):
# TODO: implement your test case
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
# ...
pass
def test_service_update(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_delete(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get_list(service: Service):
# TODO: implement your test case
pass
def test_service_get_list_by_page(service: Service):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve dbgpts_my`

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve dbgpts_my`

View File

@@ -0,0 +1,178 @@
from functools import cache
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import ServeRequest, ServerResponse
router = APIRouter()
# Add your API endpoints here
global_system_app: Optional[SystemApp] = None
def get_service() -> Service:
"""Get the service instance"""
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
Your can pass the token in you request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key}
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
@router.post(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def create(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Create a new DbgptsMy entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.create(request))
@router.put(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def update(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Update a DbgptsMy entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.update(request))
@router.post(
"/query",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Query DbgptsMy entities
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.get(request))
@router.post(
"/query_page",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),
page_size: Optional[int] = Query(default=20, description="page size"),
service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
"""Query DbgptsMy entities
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.get_list_by_page(request, page, page_size))
def init_endpoints(system_app: SystemApp) -> None:
"""Initialize the endpoints"""
global global_system_app
system_app.register(Service)
global_system_app = system_app

View File

@@ -0,0 +1,31 @@
# Define your Pydantic schemas here
from typing import Any, Dict, Optional
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
"""DbgptsMy request model"""
id: Optional[int] = Field(None, description="id")
user_code: Optional[str] = Field(None, description="My gpts user code")
user_name: Optional[str] = Field(None, description="My gpts user name")
sys_code: Optional[str] = Field(None, description="My gpts sys code")
name: str = Field(..., description="My gpts name")
file_name: str = Field(..., description="My gpts file name")
type: Optional[str] = Field(None, description="My gpts type")
version: Optional[str] = Field(None, description="My gpts version")
use_count: Optional[int] = Field(None, description="My gpts use count")
succ_count: Optional[int] = Field(None, description="My gpts succ count")
gmt_created: Optional[str] = Field(None, description="My gpts install time")
model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}")
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert the model to a dictionary"""
return model_to_dict(self, **kwargs)
ServerResponse = ServeRequest

View File

@@ -0,0 +1,22 @@
from dataclasses import dataclass, field
from typing import Optional
from dbgpt.serve.core import BaseServeConfig
APP_NAME = "dbgpts_my"
SERVE_APP_NAME = "dbgpt_serve_dbgpts_my"
SERVE_APP_NAME_HUMP = "dbgpt_serve_DbgptsMy"
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.dbgpts_my."
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
# Database table name
SERVER_APP_TABLE_NAME = "dbgpts_my"
@dataclass
class ServeConfig(BaseServeConfig):
"""Parameters for the serve command"""
# TODO: add your own parameters here
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)

View File

@@ -0,0 +1 @@
# Define your dependencies here

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve dbgpts_my`

View File

@@ -0,0 +1,89 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
from datetime import datetime
from typing import Any, Dict, Union
from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from dbgpt.storage.metadata import BaseDao, Model, db
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
class ServeEntity(Model):
__tablename__ = SERVER_APP_TABLE_NAME
id = Column(Integer, primary_key=True, comment="autoincrement id")
name = Column(String(255), unique=True, nullable=False, comment="gpts name")
type = Column(String(255), nullable=False, comment="gpts type")
version = Column(String(255), nullable=False, comment="gpts version")
user_code = Column(String(255), nullable=True, comment="user code")
user_name = Column(String(255), nullable=True, comment="user name")
file_name = Column(String(255), nullable=True, comment="gpts package file name")
use_count = Column(
Integer, nullable=True, default=0, comment="gpts total use count"
)
succ_count = Column(
Integer, nullable=True, default=0, comment="gpts total success count"
)
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.utcnow, comment="gpts install time")
UniqueConstraint("user_code", "name", name="uk_name")
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"""The DAO class for MyDbgpts"""
def __init__(self, serve_config: ServeConfig):
super().__init__()
self._serve_config = serve_config
def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
"""Convert the request to an entity
Args:
request (Union[MyGptsServeRequest, Dict[str, Any]]): The request
Returns:
T: The entity
"""
request_dict = (
request.to_dict() if isinstance(request, ServeRequest) else request
)
entity = ServeRequest(**request_dict)
return entity
def to_request(self, entity: ServeEntity) -> ServeRequest:
"""Convert the entity to a request
Args:
entity (T): The entity
Returns:
REQ: The request
"""
return ServeRequest(
id=entity.id,
user_code=entity.user_code,
user_name=entity.user_name,
sys_code=entity.sys_code,
name=entity.name,
file_name=entity.file_name,
type=entity.type,
version=entity.version,
use_count=entity.use_count,
succ_count=entity.succ_count,
gmt_created=entity.gmt_created,
)
def to_response(self, entity: ServeEntity) -> ServerResponse:
"""Convert the entity to a response
Args:
entity (T): The entity
Returns:
RES: The response
"""
return self.to_request(entity)

View File

@@ -0,0 +1,63 @@
import logging
from typing import List, Optional, Union
from sqlalchemy import URL
from dbgpt.component import SystemApp
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
logger = logging.getLogger(__name__)
class Serve(BaseServe):
"""Serve component for DB-GPT"""
name = SERVE_APP_NAME
def __init__(
self,
system_app: SystemApp,
api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}",
api_tags: Optional[List[str]] = None,
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):
if api_tags is None:
api_tags = [SERVE_APP_NAME_HUMP]
super().__init__(
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
)
self._db_manager: Optional[DatabaseManager] = None
def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
return
self._system_app = system_app
self._system_app.app.include_router(
router, prefix=self._api_prefix, tags=self._api_tags
)
init_endpoints(self._system_app)
self._app_has_initiated = True
def on_init(self):
"""Called when init the application.
You can do some initialization here. You can't get other components here because they may be not initialized yet
"""
# import your own module here to ensure the module is loaded before the application starts
from .models.models import ServeEntity
def before_start(self):
"""Called before the start of the application."""
# TODO: Your code here
self._db_manager = self.create_or_get_db_manager()

View File

@@ -0,0 +1,182 @@
import logging
from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.util.dbgpts.base import INSTALL_DIR
from dbgpt.util.dbgpts.repo import copy_and_install, install, uninstall
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
logger = logging.getLogger(__name__)
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""The service class for DbgptsMy"""
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = dao
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
"""Initialize the service
Args:
system_app (SystemApp): The system app
"""
super().init_app(system_app)
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
"""Returns the internal DAO."""
return self._dao
@property
def config(self) -> ServeConfig:
"""Returns the internal ServeConfig."""
return self._serve_config
def update(self, request: ServeRequest) -> ServerResponse:
"""Update a DbgptsMy entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
# Build the query request from the request
query_request = {"id": request.id}
return self.dao.update(query_request, update_request=request)
def get(self, request: ServeRequest) -> Optional[ServerResponse]:
"""Get a DbgptsMy entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
# Build the query request from the request
query_request = request
return self.dao.get_one(query_request)
def delete(self, request: ServeRequest) -> None:
"""Delete a DbgptsMy entity
Args:
request (ServeRequest): The request
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = {
# "id": request.id
}
self.dao.delete(query_request)
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
"""Get a list of DbgptsMy entities
Args:
request (ServeRequest): The request
Returns:
List[ServerResponse]: The response
"""
# Build the query request from the request
query_request = request
return self.dao.get_list(query_request)
def get_list_by_page(
self, request: ServeRequest, page: int, page_size: int
) -> PaginationResult[ServerResponse]:
"""Get a list of DbgptsMy entities by page
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
Returns:
List[ServerResponse]: The response
"""
query_request = request
return self.dao.get_list_page(query_request, page, page_size)
async def install_gpts(
self,
name: str,
type: str,
repo: str,
dbgpt_path: str,
user_code: Optional[str] = None,
sys_code: Optional[str] = None,
):
logger.info(f"install_gpts {name}")
# install(name, repo)
try:
copy_and_install(repo, name, dbgpt_path)
except Exception as e:
raise ValueError(f"Install dbgpts [{type}:{name}] Failed! {str(e)}", e)
from dbgpt.util.dbgpts.base import get_repo_path
from dbgpt.util.dbgpts.loader import (
BasePackage,
InstalledPackage,
parse_package_metadata,
)
base_package: BasePackage = parse_package_metadata(
InstalledPackage(
name=name,
repo=repo,
root=get_repo_path(repo),
package=type,
)
)
dbgpts_entity = self.dao.get_one(ServeRequest(name=name, type=type))
if not dbgpts_entity:
request = ServeRequest()
request.name = name
request.user_code = user_code
request.sys_code = sys_code
request.type = type
request.file_name = INSTALL_DIR / name
request.version = base_package.version
return self.create(request)
else:
dbgpts_entity.version = base_package.version
return self.update(dbgpts_entity)
async def uninstall_gpts(
self,
name: str,
type: str,
user_code: Optional[str] = None,
sys_code: Optional[str] = None,
):
logger.info(f"install_gpts {name}")
try:
uninstall(name)
except Exception as e:
raise ValueError(f"Uninstall dbgpts [{type}:{name}] Failed! {str(e)}", e)
self.delete(ServeRequest(name=name, type=type))

View File

View File

@@ -0,0 +1,124 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import asystem_app, client
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..api.endpoints import init_endpoints, router
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
def client_init_caller(app: FastAPI, system_app: SystemApp):
app.include_router(router)
init_endpoints(system_app)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, asystem_app, has_auth",
[
(
{
"app_caller": client_init_caller,
"client_api_key": "test_token1",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
True,
),
(
{
"app_caller": client_init_caller,
"client_api_key": "error_token",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
False,
),
],
indirect=["client", "asystem_app"],
)
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
response = await client.get("/test_auth")
if has_auth:
assert response.status_code == 200
assert response.json() == {"status": "ok"}
else:
assert response.status_code == 401
assert response.json() == {
"detail": {
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
}
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_health(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
# TODO: add your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,103 @@
# import pytest
#
# from dbgpt.storage.metadata import db
#
# from ..api.schemas import ServeRequest, ServerResponse
# from ..config import ServeConfig
# from ..models.models import ServeDao, ServeEntity
#
#
# @pytest.fixture(autouse=True)
# def setup_and_teardown():
# db.init_db("sqlite:///:memory:")
# db.create_all()
#
# yield
#
#
# @pytest.fixture
# def server_config():
# # TODO : build your server config
# return ServeConfig()
#
#
# @pytest.fixture
# def dao(server_config):
# return ServeDao(server_config)
#
#
# @pytest.fixture
# def default_entity_dict():
# # TODO: build your default entity dict
# return {}
#
#
# def test_table_exist():
# assert ServeEntity.__tablename__ in db.metadata.tables
#
#
# def test_entity_create(default_entity_dict):
# with db.session() as session:
# entity = ServeEntity(**default_entity_dict)
# session.add(entity)
#
#
# def test_entity_unique_key(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_get(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_update(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_delete(default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_entity_all():
# # TODO: implement your test case
# pass
#
#
# def test_dao_create(dao, default_entity_dict):
# # TODO: implement your test case
# req = ServeRequest(**default_entity_dict)
# res: ServerResponse = dao.create(req)
# assert res is not None
#
#
# def test_dao_get_one(dao, default_entity_dict):
# # TODO: implement your test case
# req = ServeRequest(**default_entity_dict)
# res: ServerResponse = dao.create(req)
#
#
# def test_get_dao_get_list(dao):
# # TODO: implement your test case
# pass
#
#
# def test_dao_update(dao, default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_dao_delete(dao, default_entity_dict):
# # TODO: implement your test case
# pass
#
#
# def test_dao_get_list_page(dao):
# # TODO: implement your test case
# pass
#
#
# # Add more test cases according to your own logic

View File

@@ -0,0 +1,78 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import system_app
from dbgpt.storage.metadata import db
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity
from ..service.service import Service
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def service(system_app: SystemApp):
instance = Service(system_app)
instance.init_app(system_app)
return instance
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
@pytest.mark.parametrize(
"system_app",
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
indirect=True,
)
def test_config_exists(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
assert service.config is not None
def test_service_create(service: Service, default_entity_dict):
# TODO: implement your test case
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
# ...
pass
def test_service_update(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_delete(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get_list(service: Service):
# TODO: implement your test case
pass
def test_service_get_list_by_page(service: Service):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -253,7 +253,7 @@ def _get_from_module(module, predicates: Optional[List[str]] = None):
return results return results
def _parse_package_metadata(package: InstalledPackage) -> BasePackage: def parse_package_metadata(package: InstalledPackage) -> BasePackage:
with open( with open(
Path(package.root) / DBGPTS_METADATA_FILE, mode="r+", encoding="utf-8" Path(package.root) / DBGPTS_METADATA_FILE, mode="r+", encoding="utf-8"
) as f: ) as f:
@@ -321,7 +321,7 @@ def _load_package_from_path(path: str):
parsed_packages = [] parsed_packages = []
for package in packages: for package in packages:
try: try:
parsed_packages.append(_parse_package_metadata(package)) parsed_packages.append(parse_package_metadata(package))
except Exception as e: except Exception as e:
logger.warning(f"Load package failed!{str(e)}", e) logger.warning(f"Load package failed!{str(e)}", e)
@@ -338,7 +338,7 @@ def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPack
] ]
if not packages: if not packages:
raise ValueError(f"Can't find the package {name} or {new_name}") raise ValueError(f"Can't find the package {name} or {new_name}")
flow_package = _parse_package_metadata(packages[0]) flow_package = parse_package_metadata(packages[0])
if flow_package.package_type != "flow": if flow_package.package_type != "flow":
raise ValueError(f"Unsupported package type: {flow_package.package_type}") raise ValueError(f"Unsupported package type: {flow_package.package_type}")
return cast(FlowPackage, flow_package) return cast(FlowPackage, flow_package)

View File

@@ -209,7 +209,7 @@ def install(
if not repo_info: if not repo_info:
cl.error(f"The specified dbgpt '{name}' does not exist.", exit_code=1) cl.error(f"The specified dbgpt '{name}' does not exist.", exit_code=1)
repo, dbgpt_path = repo_info repo, dbgpt_path = repo_info
_copy_and_install(repo, name, dbgpt_path) copy_and_install(repo, name, dbgpt_path)
def uninstall(name: str): def uninstall(name: str):
@@ -227,7 +227,7 @@ def uninstall(name: str):
cl.info(f"Uninstalling dbgpt '{name}'...") cl.info(f"Uninstalling dbgpt '{name}'...")
def _copy_and_install(repo: str, name: str, package_path: Path): def copy_and_install(repo: str, name: str, package_path: Path):
if not package_path.exists(): if not package_path.exists():
cl.error( cl.error(
f"The specified dbgpt '{name}' does not exist in the {repo} tap.", f"The specified dbgpt '{name}' does not exist in the {repo} tap.",
@@ -350,6 +350,42 @@ def list_repo_apps(repo: str | None = None, with_update: bool = True):
cl.print(table) cl.print(table)
async def list_dbgpts(
spec_repo: str | None = None, with_update: bool = True
) -> List[Tuple[str, str, str, str]]:
"""scan dbgpts in repo
Args:
spec_repo: The name of the repo
Returns:
Tuple[str, Path] | None: The repo and the path of the dbgpt
"""
repos = _list_repos_details()
if spec_repo:
repos = list(filter(lambda x: x[0] == spec_repo, repos))
if not repos:
raise ValueError(f"The specified repo '{spec_repo}' does not exist.")
if with_update:
for repo in repos:
update_repo(repo[0])
data = []
for repo in repos:
repo_path = Path(repo[1])
for package in DEFAULT_PACKAGES:
dbgpt_path = repo_path / package
for app in os.listdir(dbgpt_path):
gpts_path = dbgpt_path / app
dbgpt_metadata_path = dbgpt_path / app / DBGPTS_METADATA_FILE
if (
dbgpt_path.exists()
and dbgpt_path.is_dir()
and dbgpt_metadata_path.exists()
):
data.append((repo[0], package, app, gpts_path))
return data
def list_installed_apps(): def list_installed_apps():
"""List all installed dbgpts""" """List all installed dbgpts"""
packages = _load_package_from_path(INSTALL_DIR) packages = _load_package_from_path(INSTALL_DIR)