diff --git a/dbgpt/serve/dbgpts/hub/__init__.py b/dbgpt/serve/dbgpts/hub/__init__.py new file mode 100644 index 000000000..bbd23d5fc --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve dbgpts_hub` diff --git a/dbgpt/serve/dbgpts/hub/api/__init__.py b/dbgpt/serve/dbgpts/hub/api/__init__.py new file mode 100644 index 000000000..bbd23d5fc --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/api/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve dbgpts_hub` diff --git a/dbgpt/serve/dbgpts/hub/api/endpoints.py b/dbgpt/serve/dbgpts/hub/api/endpoints.py new file mode 100644 index 000000000..ed1ea5c6f --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/api/endpoints.py @@ -0,0 +1,178 @@ +from functools import cache +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + +from dbgpt.component import SystemApp +from dbgpt.serve.core import Result +from dbgpt.util import PaginationResult + +from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..service.service import Service +from .schemas import ServeRequest, ServerResponse + +router = APIRouter() + +# Add your API endpoints here + +global_system_app: Optional[SystemApp] = None + + +def get_service() -> Service: + """Get the service instance""" + return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) + + +get_bearer_token = HTTPBearer(auto_error=False) + + +@cache +def _parse_api_keys(api_keys: str) -> List[str]: + """Parse the string api keys to a list + + Args: + api_keys (str): The string api keys + + Returns: + List[str]: The list of api keys + """ + if not api_keys: + return [] + return [key.strip() for key in api_keys.split(",")] + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), + service: Service = Depends(get_service), +) -> Optional[str]: + """Check the api key + + If the api key is not set, allow all. + + Your can pass the token in you request header like this: + + .. code-block:: python + + import requests + + client_api_key = "your_api_key" + headers = {"Authorization": "Bearer " + client_api_key} + res = requests.get("http://test/hello", headers=headers) + assert res.status_code == 200 + + """ + if service.config.api_keys: + api_keys = _parse_api_keys(service.config.api_keys) + if auth is None or (token := auth.credentials) not in api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +@router.get("/health") +async def health(): + """Health check endpoint""" + return {"status": "ok"} + + +@router.get("/test_auth", dependencies=[Depends(check_api_key)]) +async def test_auth(): + """Test auth endpoint""" + return {"status": "ok"} + + +@router.post( + "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)] +) +async def create( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Create a new DbgptsHub entity + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.create(request)) + + +@router.put( + "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)] +) +async def update( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Update a DbgptsHub entity + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.update(request)) + + +@router.post( + "/query", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], +) +async def query( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Query DbgptsHub entities + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.get(request)) + + +@router.post( + "/query_page", + response_model=Result[PaginationResult[ServerResponse]], + dependencies=[Depends(check_api_key)], +) +async def query_page( + request: ServeRequest, + page: Optional[int] = Query(default=1, description="current page"), + page_size: Optional[int] = Query(default=20, description="page size"), + service: Service = Depends(get_service), +) -> Result[PaginationResult[ServerResponse]]: + """Query DbgptsHub entities + + Args: + request (ServeRequest): The request + page (int): The page number + page_size (int): The page size + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.get_list_by_page(request, page, page_size)) + + +def init_endpoints(system_app: SystemApp) -> None: + """Initialize the endpoints""" + global global_system_app + system_app.register(Service) + global_system_app = system_app diff --git a/dbgpt/serve/dbgpts/hub/api/schemas.py b/dbgpt/serve/dbgpts/hub/api/schemas.py new file mode 100644 index 000000000..4ecf359a7 --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/api/schemas.py @@ -0,0 +1,32 @@ +# Define your Pydantic schemas here +from typing import Any, Dict, Optional + +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict + +from ..config import SERVE_APP_NAME_HUMP + + +class ServeRequest(BaseModel): + """DbgptsHub request model""" + + id: Optional[int] = Field(None, description="id") + name: Optional[str] = Field(..., description="Dbgpts name") + description: Optional[str] = Field(..., description="Dbgpts description") + author: Optional[str] = Field(None, description="Dbgpts author") + email: Optional[str] = Field(None, description="Dbgpts email") + type: Optional[str] = Field(None, description="Dbgpts type") + version: Optional[str] = Field(None, description="Dbgpts version") + storage_channel: Optional[str] = Field(None, description="Dbgpts storage channel") + storage_url: Optional[str] = Field(None, description="Dbgpts storage url") + download_param: Optional[str] = Field(None, description="Dbgpts download param") + installed: Optional[int] = Field(None, description="Dbgpts installed") + gmt_created: Optional[str] = Field(None, description="Dbgpts upload time") + + model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) + + +ServerResponse = ServeRequest diff --git a/dbgpt/serve/dbgpts/hub/config.py b/dbgpt/serve/dbgpts/hub/config.py new file mode 100644 index 000000000..1e19fe3b7 --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/config.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass, field +from typing import Optional + +from dbgpt.serve.core import BaseServeConfig + +APP_NAME = "dbgpts_hub" +SERVE_APP_NAME = "dbgpt_serve_dbgpts_hub" +SERVE_APP_NAME_HUMP = "dbgpt_serve_DbgptsHub" +SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.dbgpts_hub." +SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +# Database table name +SERVER_APP_TABLE_NAME = "dbgpts_hub" + + +@dataclass +class ServeConfig(BaseServeConfig): + """Parameters for the serve command""" + + # TODO: add your own parameters here + api_keys: Optional[str] = field( + default=None, metadata={"help": "API keys for the endpoint, if None, allow all"} + ) diff --git a/dbgpt/serve/dbgpts/hub/dependencies.py b/dbgpt/serve/dbgpts/hub/dependencies.py new file mode 100644 index 000000000..8598ecd97 --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/dependencies.py @@ -0,0 +1 @@ +# Define your dependencies here diff --git a/dbgpt/serve/dbgpts/hub/models/__init__.py b/dbgpt/serve/dbgpts/hub/models/__init__.py new file mode 100644 index 000000000..bbd23d5fc --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/models/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve dbgpts_hub` diff --git a/dbgpt/serve/dbgpts/hub/models/models.py b/dbgpt/serve/dbgpts/hub/models/models.py new file mode 100644 index 000000000..91f7f1e78 --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/models/models.py @@ -0,0 +1,100 @@ +"""This is an auto-generated model file +You can define your own models and DAOs here +""" +from datetime import datetime +from typing import Any, Dict, Union + +from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint + +from dbgpt.storage.metadata import BaseDao, Model, db + +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVER_APP_TABLE_NAME, ServeConfig + + +class ServeEntity(Model): + __tablename__ = SERVER_APP_TABLE_NAME + id = Column(Integer, primary_key=True, comment="Auto increment id") + + name = Column(String(255), unique=True, nullable=False, comment="dbgpts name") + description = Column(String(255), nullable=False, comment="dbgpts description") + author = Column(String(255), nullable=True, comment="dbgpts author") + email = Column(String(255), nullable=True, comment="dbgpts author email") + type = Column(String(255), comment="dbgpts type") + version = Column(String(255), comment="dbgpts version") + storage_channel = Column(String(255), comment="dbgpts storage channel") + storage_url = Column(String(255), comment="dbgpts download url") + download_param = Column(String(255), comment="dbgpts download param") + gmt_created = Column( + DateTime, + default=datetime.utcnow, + comment="plugin upload time", + ) + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + installed = Column(Integer, default=False, comment="plugin already installed count") + + UniqueConstraint("name", "type", name="uk_dbgpts") + Index("idx_q_type", "type") + + def __repr__(self): + return f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + + +class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): + """The DAO class for Dbgpts""" + + def __init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity: + """Convert the request to an entity + + Args: + request (Union[ServeRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + request_dict = ( + request.to_dict() if isinstance(request, ServeRequest) else request + ) + entity = ServeEntity(**request_dict) + return entity + + def to_request(self, entity: ServeEntity) -> ServeRequest: + """Convert the entity to a request + + Args: + entity (T): The entity + + Returns: + REQ: The request + """ + + return ServeRequest( + id=entity.id, + name=entity.name, + description=entity.description, + author=entity.author, + email=entity.email, + type=entity.type, + version=entity.version, + storage_channel=entity.storage_channel, + storage_url=entity.storage_url, + download_param=entity.download_param, + installed=entity.installed, + gmt_created=entity.gmt_created, + ) + + def to_response(self, entity: ServeEntity) -> ServerResponse: + """Convert the entity to a response + + Args: + entity (T): The entity + + Returns: + RES: The response + """ + + return self.to_request(entity) diff --git a/dbgpt/serve/dbgpts/hub/serve.py b/dbgpt/serve/dbgpts/hub/serve.py new file mode 100644 index 000000000..8c6564d3a --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/serve.py @@ -0,0 +1,63 @@ +import logging +from typing import List, Optional, Union + +from sqlalchemy import URL + +from dbgpt.component import SystemApp +from dbgpt.serve.core import BaseServe +from dbgpt.storage.metadata import DatabaseManager + +from .api.endpoints import init_endpoints, router +from .config import ( + APP_NAME, + SERVE_APP_NAME, + SERVE_APP_NAME_HUMP, + SERVE_CONFIG_KEY_PREFIX, + ServeConfig, +) + +logger = logging.getLogger(__name__) + + +class Serve(BaseServe): + """Serve component for DB-GPT""" + + name = SERVE_APP_NAME + + def __init__( + self, + system_app: SystemApp, + api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}", + api_tags: Optional[List[str]] = None, + db_url_or_db: Union[str, URL, DatabaseManager] = None, + try_create_tables: Optional[bool] = False, + ): + if api_tags is None: + api_tags = [SERVE_APP_NAME_HUMP] + super().__init__( + system_app, api_prefix, api_tags, db_url_or_db, try_create_tables + ) + self._db_manager: Optional[DatabaseManager] = None + + def init_app(self, system_app: SystemApp): + if self._app_has_initiated: + return + self._system_app = system_app + self._system_app.app.include_router( + router, prefix=self._api_prefix, tags=self._api_tags + ) + init_endpoints(self._system_app) + self._app_has_initiated = True + + def on_init(self): + """Called when init the application. + + You can do some initialization here. You can't get other components here because they may be not initialized yet + """ + # import your own module here to ensure the module is loaded before the application starts + from .models.models import ServeEntity + + def before_start(self): + """Called before the start of the application.""" + # TODO: Your code here + self._db_manager = self.create_or_get_db_manager() diff --git a/dbgpt/serve/dbgpts/hub/service/__init__.py b/dbgpt/serve/dbgpts/hub/service/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/dbgpts/hub/service/service.py b/dbgpt/serve/dbgpts/hub/service/service.py new file mode 100644 index 000000000..170a4e70f --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/service/service.py @@ -0,0 +1,180 @@ +import logging +from typing import List, Optional, Tuple + +from dbgpt.agent import PluginStorageType +from dbgpt.component import BaseComponent, SystemApp +from dbgpt.serve.core import BaseService +from dbgpt.storage.metadata import BaseDao +from dbgpt.util.dbgpts.repo import _install_default_repos_if_no_repos, list_dbgpts +from dbgpt.util.pagination_utils import PaginationResult + +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..models.models import ServeDao, ServeEntity + +logger = logging.getLogger(__name__) + + +class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): + """The service class for DbgptsHub""" + + name = SERVE_SERVICE_COMPONENT_NAME + + def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None): + self._system_app = None + self._serve_config: ServeConfig = None + self._dao: ServeDao = dao + super().__init__(system_app) + + def init_app(self, system_app: SystemApp) -> None: + """Initialize the service + + Args: + system_app (SystemApp): The system app + """ + super().init_app(system_app) + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + self._dao = self._dao or ServeDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: + """Returns the internal DAO.""" + return self._dao + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + def update(self, request: ServeRequest) -> ServerResponse: + """Update a DbgptsHub entity + + Args: + request (ServeRequest): The request + + Returns: + ServerResponse: The response + """ + # TODO: implement your own logic here + # Build the query request from the request + query_request = { + # "id": request.id + } + return self.dao.update(query_request, update_request=request) + + def get(self, request: ServeRequest) -> Optional[ServerResponse]: + """Get a DbgptsHub entity + + Args: + request (ServeRequest): The request + + Returns: + ServerResponse: The response + """ + # TODO: implement your own logic here + # Build the query request from the request + query_request = request + return self.dao.get_one(query_request) + + def delete(self, request: ServeRequest) -> None: + """Delete a DbgptsHub entity + + Args: + request (ServeRequest): The request + """ + + # TODO: implement your own logic here + # Build the query request from the request + query_request = { + # "id": request.id + } + self.dao.delete(query_request) + + def get_list(self, request: ServeRequest) -> List[ServerResponse]: + """Get a list of DbgptsHub entities + + Args: + request (ServeRequest): The request + + Returns: + List[ServerResponse]: The response + """ + # TODO: implement your own logic here + # Build the query request from the request + query_request = request + return self.dao.get_list(query_request) + + def get_list_by_page( + self, request: ServeRequest, page: int, page_size: int + ) -> PaginationResult[ServerResponse]: + """Get a list of DbgptsHub entities by page + + Args: + request (ServeRequest): The request + page (int): The page number + page_size (int): The page size + + Returns: + List[ServerResponse]: The response + """ + query_request = request + return self.dao.get_list_page(query_request, page, page_size) + + async def refresh_hub_from_git( + self, + github_repo: str = None, + branch_name: str = "main", + authorization: str = None, + ): + logger.info("refresh_hub_by_git start!") + _install_default_repos_if_no_repos() + data: List[Tuple[str, str, str, str]] = await list_dbgpts() + + from dbgpt.util.dbgpts.base import get_repo_path + from dbgpt.util.dbgpts.loader import ( + BasePackage, + InstalledPackage, + parse_package_metadata, + ) + + try: + for repo, package, name, gpts_path in data: + old_hub_info = self.dao.get_one(ServeRequest(name=name, type=package)) + base_package: BasePackage = parse_package_metadata( + InstalledPackage( + name=name, + repo=repo, + root=get_repo_path(repo), + package=package, + ) + ) + if old_hub_info: + self.dao.update( + query_request=ServeRequest( + name=old_hub_info.name, type=old_hub_info.type + ), + update_request=ServeRequest( + version=base_package.version, + description=base_package.description, + ), + ) + else: + request = ServeRequest() + request.type = package + request.name = name + request.storage_channel = repo + request.storage_url = gpts_path + request.author = ",".join(base_package.authors) + request.email = ",".join(base_package.authors) + + request.download_param = None + request.installed = 0 + request.version = base_package.version + request.description = base_package.description + self.dao.create(request) + + except Exception as e: + raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}") diff --git a/dbgpt/serve/dbgpts/hub/tests/__init__.py b/dbgpt/serve/dbgpts/hub/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/dbgpts/hub/tests/test_endpoints.py b/dbgpt/serve/dbgpts/hub/tests/test_endpoints.py new file mode 100644 index 000000000..ba7b4f0cd --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/tests/test_endpoints.py @@ -0,0 +1,124 @@ +import pytest +from fastapi import FastAPI +from httpx import AsyncClient + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import asystem_app, client +from dbgpt.storage.metadata import db +from dbgpt.util import PaginationResult + +from ..api.endpoints import init_endpoints, router +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_CONFIG_KEY_PREFIX + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +def client_init_caller(app: FastAPI, system_app: SystemApp): + app.include_router(router) + init_endpoints(system_app) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, asystem_app, has_auth", + [ + ( + { + "app_caller": client_init_caller, + "client_api_key": "test_token1", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + True, + ), + ( + { + "app_caller": client_init_caller, + "client_api_key": "error_token", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + False, + ), + ], + indirect=["client", "asystem_app"], +) +async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool): + response = await client.get("/test_auth") + if has_auth: + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + else: + assert response.status_code == 401 + assert response.json() == { + "detail": { + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + } + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_health(client: AsyncClient): + response = await client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_create(client: AsyncClient): + # TODO: add your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_update(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query_by_page(client: AsyncClient): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/dbgpts/hub/tests/test_models.py b/dbgpt/serve/dbgpts/hub/tests/test_models.py new file mode 100644 index 000000000..db4b44f8d --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/tests/test_models.py @@ -0,0 +1,103 @@ +# import pytest +# +# from dbgpt.storage.metadata import db +# +# from ..api.schemas import ServeRequest, ServerResponse +# from ..config import ServeConfig +# from ..models.models import ServeDao, ServeEntity +# +# +# @pytest.fixture(autouse=True) +# def setup_and_teardown(): +# db.init_db("sqlite:///:memory:") +# db.create_all() +# +# yield +# +# +# @pytest.fixture +# def server_config(): +# # TODO : build your server config +# return ServeConfig() +# +# +# @pytest.fixture +# def dao(server_config): +# return ServeDao(server_config) +# +# +# @pytest.fixture +# def default_entity_dict(): +# # TODO: build your default entity dict +# return {} +# +# +# def test_table_exist(): +# assert ServeEntity.__tablename__ in db.metadata.tables +# +# +# def test_entity_create(default_entity_dict): +# with db.session() as session: +# entity = ServeEntity(**default_entity_dict) +# session.add(entity) +# +# +# def test_entity_unique_key(default_entity_dict): +# # TODO: implement your test case +# pass +# +# +# def test_entity_get(default_entity_dict): +# # TODO: implement your test case +# pass +# +# +# def test_entity_update(default_entity_dict): +# # TODO: implement your test case +# pass +# +# +# def test_entity_delete(default_entity_dict): +# # TODO: implement your test case +# pass +# +# +# def test_entity_all(): +# # TODO: implement your test case +# pass +# +# +# def test_dao_create(dao, default_entity_dict): +# # TODO: implement your test case +# req = ServeRequest(**default_entity_dict) +# res: ServerResponse = dao.create(req) +# assert res is not None +# +# +# def test_dao_get_one(dao, default_entity_dict): +# # TODO: implement your test case +# req = ServeRequest(**default_entity_dict) +# res: ServerResponse = dao.create(req) +# +# +# def test_get_dao_get_list(dao): +# # TODO: implement your test case +# pass +# +# +# def test_dao_update(dao, default_entity_dict): +# # TODO: implement your test case +# pass +# +# +# def test_dao_delete(dao, default_entity_dict): +# # TODO: implement your test case +# pass +# +# +# def test_dao_get_list_page(dao): +# # TODO: implement your test case +# pass +# +# +# # Add more test cases according to your own logic diff --git a/dbgpt/serve/dbgpts/hub/tests/test_service.py b/dbgpt/serve/dbgpts/hub/tests/test_service.py new file mode 100644 index 000000000..00177924d --- /dev/null +++ b/dbgpt/serve/dbgpts/hub/tests/test_service.py @@ -0,0 +1,78 @@ +from typing import List + +import pytest + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import system_app +from dbgpt.storage.metadata import db + +from ..api.schemas import ServeRequest, ServerResponse +from ..models.models import ServeEntity +from ..service.service import Service + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + yield + + +@pytest.fixture +def service(system_app: SystemApp): + instance = Service(system_app) + instance.init_app(system_app) + return instance + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +@pytest.mark.parametrize( + "system_app", + [{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}], + indirect=True, +) +def test_config_exists(service: Service): + system_app: SystemApp = service._system_app + assert system_app.config.get("DEBUG") is True + assert system_app.config.get("dbgpt.serve.test_key") == "hello" + assert service.config is not None + + +def test_service_create(service: Service, default_entity_dict): + # TODO: implement your test case + # eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict)) + # ... + pass + + +def test_service_update(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_delete(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get_list(service: Service): + # TODO: implement your test case + pass + + +def test_service_get_list_by_page(service: Service): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/dbgpts/my/__init__.py b/dbgpt/serve/dbgpts/my/__init__.py new file mode 100644 index 000000000..2f7dde7ca --- /dev/null +++ b/dbgpt/serve/dbgpts/my/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve dbgpts_my` diff --git a/dbgpt/serve/dbgpts/my/api/__init__.py b/dbgpt/serve/dbgpts/my/api/__init__.py new file mode 100644 index 000000000..2f7dde7ca --- /dev/null +++ b/dbgpt/serve/dbgpts/my/api/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve dbgpts_my` diff --git a/dbgpt/serve/dbgpts/my/api/endpoints.py b/dbgpt/serve/dbgpts/my/api/endpoints.py new file mode 100644 index 000000000..17907ee98 --- /dev/null +++ b/dbgpt/serve/dbgpts/my/api/endpoints.py @@ -0,0 +1,178 @@ +from functools import cache +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + +from dbgpt.component import SystemApp +from dbgpt.serve.core import Result +from dbgpt.util import PaginationResult + +from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..service.service import Service +from .schemas import ServeRequest, ServerResponse + +router = APIRouter() + +# Add your API endpoints here + +global_system_app: Optional[SystemApp] = None + + +def get_service() -> Service: + """Get the service instance""" + return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) + + +get_bearer_token = HTTPBearer(auto_error=False) + + +@cache +def _parse_api_keys(api_keys: str) -> List[str]: + """Parse the string api keys to a list + + Args: + api_keys (str): The string api keys + + Returns: + List[str]: The list of api keys + """ + if not api_keys: + return [] + return [key.strip() for key in api_keys.split(",")] + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), + service: Service = Depends(get_service), +) -> Optional[str]: + """Check the api key + + If the api key is not set, allow all. + + Your can pass the token in you request header like this: + + .. code-block:: python + + import requests + + client_api_key = "your_api_key" + headers = {"Authorization": "Bearer " + client_api_key} + res = requests.get("http://test/hello", headers=headers) + assert res.status_code == 200 + + """ + if service.config.api_keys: + api_keys = _parse_api_keys(service.config.api_keys) + if auth is None or (token := auth.credentials) not in api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +@router.get("/health") +async def health(): + """Health check endpoint""" + return {"status": "ok"} + + +@router.get("/test_auth", dependencies=[Depends(check_api_key)]) +async def test_auth(): + """Test auth endpoint""" + return {"status": "ok"} + + +@router.post( + "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)] +) +async def create( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Create a new DbgptsMy entity + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.create(request)) + + +@router.put( + "/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)] +) +async def update( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Update a DbgptsMy entity + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.update(request)) + + +@router.post( + "/query", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], +) +async def query( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Query DbgptsMy entities + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.get(request)) + + +@router.post( + "/query_page", + response_model=Result[PaginationResult[ServerResponse]], + dependencies=[Depends(check_api_key)], +) +async def query_page( + request: ServeRequest, + page: Optional[int] = Query(default=1, description="current page"), + page_size: Optional[int] = Query(default=20, description="page size"), + service: Service = Depends(get_service), +) -> Result[PaginationResult[ServerResponse]]: + """Query DbgptsMy entities + + Args: + request (ServeRequest): The request + page (int): The page number + page_size (int): The page size + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.get_list_by_page(request, page, page_size)) + + +def init_endpoints(system_app: SystemApp) -> None: + """Initialize the endpoints""" + global global_system_app + system_app.register(Service) + global_system_app = system_app diff --git a/dbgpt/serve/dbgpts/my/api/schemas.py b/dbgpt/serve/dbgpts/my/api/schemas.py new file mode 100644 index 000000000..bf8cc67d4 --- /dev/null +++ b/dbgpt/serve/dbgpts/my/api/schemas.py @@ -0,0 +1,31 @@ +# Define your Pydantic schemas here +from typing import Any, Dict, Optional + +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict + +from ..config import SERVE_APP_NAME_HUMP + + +class ServeRequest(BaseModel): + """DbgptsMy request model""" + + id: Optional[int] = Field(None, description="id") + user_code: Optional[str] = Field(None, description="My gpts user code") + user_name: Optional[str] = Field(None, description="My gpts user name") + sys_code: Optional[str] = Field(None, description="My gpts sys code") + name: str = Field(..., description="My gpts name") + file_name: str = Field(..., description="My gpts file name") + type: Optional[str] = Field(None, description="My gpts type") + version: Optional[str] = Field(None, description="My gpts version") + use_count: Optional[int] = Field(None, description="My gpts use count") + succ_count: Optional[int] = Field(None, description="My gpts succ count") + gmt_created: Optional[str] = Field(None, description="My gpts install time") + + model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) + + +ServerResponse = ServeRequest diff --git a/dbgpt/serve/dbgpts/my/config.py b/dbgpt/serve/dbgpts/my/config.py new file mode 100644 index 000000000..722fe33a0 --- /dev/null +++ b/dbgpt/serve/dbgpts/my/config.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass, field +from typing import Optional + +from dbgpt.serve.core import BaseServeConfig + +APP_NAME = "dbgpts_my" +SERVE_APP_NAME = "dbgpt_serve_dbgpts_my" +SERVE_APP_NAME_HUMP = "dbgpt_serve_DbgptsMy" +SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.dbgpts_my." +SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +# Database table name +SERVER_APP_TABLE_NAME = "dbgpts_my" + + +@dataclass +class ServeConfig(BaseServeConfig): + """Parameters for the serve command""" + + # TODO: add your own parameters here + api_keys: Optional[str] = field( + default=None, metadata={"help": "API keys for the endpoint, if None, allow all"} + ) diff --git a/dbgpt/serve/dbgpts/my/dependencies.py b/dbgpt/serve/dbgpts/my/dependencies.py new file mode 100644 index 000000000..8598ecd97 --- /dev/null +++ b/dbgpt/serve/dbgpts/my/dependencies.py @@ -0,0 +1 @@ +# Define your dependencies here diff --git a/dbgpt/serve/dbgpts/my/models/__init__.py b/dbgpt/serve/dbgpts/my/models/__init__.py new file mode 100644 index 000000000..2f7dde7ca --- /dev/null +++ b/dbgpt/serve/dbgpts/my/models/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve dbgpts_my` diff --git a/dbgpt/serve/dbgpts/my/models/models.py b/dbgpt/serve/dbgpts/my/models/models.py new file mode 100644 index 000000000..c319c377c --- /dev/null +++ b/dbgpt/serve/dbgpts/my/models/models.py @@ -0,0 +1,89 @@ +"""This is an auto-generated model file +You can define your own models and DAOs here +""" +from datetime import datetime +from typing import Any, Dict, Union + +from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint + +from dbgpt.storage.metadata import BaseDao, Model, db + +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVER_APP_TABLE_NAME, ServeConfig + + +class ServeEntity(Model): + __tablename__ = SERVER_APP_TABLE_NAME + id = Column(Integer, primary_key=True, comment="autoincrement id") + name = Column(String(255), unique=True, nullable=False, comment="gpts name") + type = Column(String(255), nullable=False, comment="gpts type") + version = Column(String(255), nullable=False, comment="gpts version") + user_code = Column(String(255), nullable=True, comment="user code") + user_name = Column(String(255), nullable=True, comment="user name") + file_name = Column(String(255), nullable=True, comment="gpts package file name") + use_count = Column( + Integer, nullable=True, default=0, comment="gpts total use count" + ) + succ_count = Column( + Integer, nullable=True, default=0, comment="gpts total success count" + ) + sys_code = Column(String(128), index=True, nullable=True, comment="System code") + gmt_created = Column(DateTime, default=datetime.utcnow, comment="gpts install time") + UniqueConstraint("user_code", "name", name="uk_name") + + +class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): + """The DAO class for MyDbgpts""" + + def __init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity: + """Convert the request to an entity + + Args: + request (Union[MyGptsServeRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + request_dict = ( + request.to_dict() if isinstance(request, ServeRequest) else request + ) + entity = ServeRequest(**request_dict) + return entity + + def to_request(self, entity: ServeEntity) -> ServeRequest: + """Convert the entity to a request + + Args: + entity (T): The entity + + Returns: + REQ: The request + """ + return ServeRequest( + id=entity.id, + user_code=entity.user_code, + user_name=entity.user_name, + sys_code=entity.sys_code, + name=entity.name, + file_name=entity.file_name, + type=entity.type, + version=entity.version, + use_count=entity.use_count, + succ_count=entity.succ_count, + gmt_created=entity.gmt_created, + ) + + def to_response(self, entity: ServeEntity) -> ServerResponse: + """Convert the entity to a response + + Args: + entity (T): The entity + + Returns: + RES: The response + """ + return self.to_request(entity) diff --git a/dbgpt/serve/dbgpts/my/serve.py b/dbgpt/serve/dbgpts/my/serve.py new file mode 100644 index 000000000..8c6564d3a --- /dev/null +++ b/dbgpt/serve/dbgpts/my/serve.py @@ -0,0 +1,63 @@ +import logging +from typing import List, Optional, Union + +from sqlalchemy import URL + +from dbgpt.component import SystemApp +from dbgpt.serve.core import BaseServe +from dbgpt.storage.metadata import DatabaseManager + +from .api.endpoints import init_endpoints, router +from .config import ( + APP_NAME, + SERVE_APP_NAME, + SERVE_APP_NAME_HUMP, + SERVE_CONFIG_KEY_PREFIX, + ServeConfig, +) + +logger = logging.getLogger(__name__) + + +class Serve(BaseServe): + """Serve component for DB-GPT""" + + name = SERVE_APP_NAME + + def __init__( + self, + system_app: SystemApp, + api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}", + api_tags: Optional[List[str]] = None, + db_url_or_db: Union[str, URL, DatabaseManager] = None, + try_create_tables: Optional[bool] = False, + ): + if api_tags is None: + api_tags = [SERVE_APP_NAME_HUMP] + super().__init__( + system_app, api_prefix, api_tags, db_url_or_db, try_create_tables + ) + self._db_manager: Optional[DatabaseManager] = None + + def init_app(self, system_app: SystemApp): + if self._app_has_initiated: + return + self._system_app = system_app + self._system_app.app.include_router( + router, prefix=self._api_prefix, tags=self._api_tags + ) + init_endpoints(self._system_app) + self._app_has_initiated = True + + def on_init(self): + """Called when init the application. + + You can do some initialization here. You can't get other components here because they may be not initialized yet + """ + # import your own module here to ensure the module is loaded before the application starts + from .models.models import ServeEntity + + def before_start(self): + """Called before the start of the application.""" + # TODO: Your code here + self._db_manager = self.create_or_get_db_manager() diff --git a/dbgpt/serve/dbgpts/my/service/__init__.py b/dbgpt/serve/dbgpts/my/service/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/dbgpts/my/service/service.py b/dbgpt/serve/dbgpts/my/service/service.py new file mode 100644 index 000000000..8907ecc01 --- /dev/null +++ b/dbgpt/serve/dbgpts/my/service/service.py @@ -0,0 +1,182 @@ +import logging +from typing import List, Optional + +from dbgpt.component import BaseComponent, SystemApp +from dbgpt.serve.core import BaseService +from dbgpt.storage.metadata import BaseDao +from dbgpt.util.dbgpts.base import INSTALL_DIR +from dbgpt.util.dbgpts.repo import copy_and_install, install, uninstall +from dbgpt.util.pagination_utils import PaginationResult + +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..models.models import ServeDao, ServeEntity + +logger = logging.getLogger(__name__) + + +class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): + """The service class for DbgptsMy""" + + name = SERVE_SERVICE_COMPONENT_NAME + + def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None): + self._system_app = None + self._serve_config: ServeConfig = None + self._dao: ServeDao = dao + super().__init__(system_app) + + def init_app(self, system_app: SystemApp) -> None: + """Initialize the service + + Args: + system_app (SystemApp): The system app + """ + super().init_app(system_app) + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + self._dao = self._dao or ServeDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: + """Returns the internal DAO.""" + return self._dao + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + def update(self, request: ServeRequest) -> ServerResponse: + """Update a DbgptsMy entity + + Args: + request (ServeRequest): The request + + Returns: + ServerResponse: The response + """ + + # Build the query request from the request + query_request = {"id": request.id} + return self.dao.update(query_request, update_request=request) + + def get(self, request: ServeRequest) -> Optional[ServerResponse]: + """Get a DbgptsMy entity + + Args: + request (ServeRequest): The request + + Returns: + ServerResponse: The response + """ + # Build the query request from the request + query_request = request + return self.dao.get_one(query_request) + + def delete(self, request: ServeRequest) -> None: + """Delete a DbgptsMy entity + + Args: + request (ServeRequest): The request + """ + + # TODO: implement your own logic here + # Build the query request from the request + query_request = { + # "id": request.id + } + self.dao.delete(query_request) + + def get_list(self, request: ServeRequest) -> List[ServerResponse]: + """Get a list of DbgptsMy entities + + Args: + request (ServeRequest): The request + + Returns: + List[ServerResponse]: The response + """ + # Build the query request from the request + query_request = request + return self.dao.get_list(query_request) + + def get_list_by_page( + self, request: ServeRequest, page: int, page_size: int + ) -> PaginationResult[ServerResponse]: + """Get a list of DbgptsMy entities by page + + Args: + request (ServeRequest): The request + page (int): The page number + page_size (int): The page size + + Returns: + List[ServerResponse]: The response + """ + query_request = request + return self.dao.get_list_page(query_request, page, page_size) + + async def install_gpts( + self, + name: str, + type: str, + repo: str, + dbgpt_path: str, + user_code: Optional[str] = None, + sys_code: Optional[str] = None, + ): + logger.info(f"install_gpts {name}") + + # install(name, repo) + try: + copy_and_install(repo, name, dbgpt_path) + except Exception as e: + raise ValueError(f"Install dbgpts [{type}:{name}] Failed! {str(e)}", e) + + from dbgpt.util.dbgpts.base import get_repo_path + from dbgpt.util.dbgpts.loader import ( + BasePackage, + InstalledPackage, + parse_package_metadata, + ) + + base_package: BasePackage = parse_package_metadata( + InstalledPackage( + name=name, + repo=repo, + root=get_repo_path(repo), + package=type, + ) + ) + dbgpts_entity = self.dao.get_one(ServeRequest(name=name, type=type)) + + if not dbgpts_entity: + request = ServeRequest() + request.name = name + + request.user_code = user_code + request.sys_code = sys_code + request.type = type + request.file_name = INSTALL_DIR / name + request.version = base_package.version + return self.create(request) + else: + dbgpts_entity.version = base_package.version + return self.update(dbgpts_entity) + + async def uninstall_gpts( + self, + name: str, + type: str, + user_code: Optional[str] = None, + sys_code: Optional[str] = None, + ): + logger.info(f"install_gpts {name}") + try: + uninstall(name) + except Exception as e: + raise ValueError(f"Uninstall dbgpts [{type}:{name}] Failed! {str(e)}", e) + self.delete(ServeRequest(name=name, type=type)) diff --git a/dbgpt/serve/dbgpts/my/tests/__init__.py b/dbgpt/serve/dbgpts/my/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/dbgpts/my/tests/test_endpoints.py b/dbgpt/serve/dbgpts/my/tests/test_endpoints.py new file mode 100644 index 000000000..ba7b4f0cd --- /dev/null +++ b/dbgpt/serve/dbgpts/my/tests/test_endpoints.py @@ -0,0 +1,124 @@ +import pytest +from fastapi import FastAPI +from httpx import AsyncClient + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import asystem_app, client +from dbgpt.storage.metadata import db +from dbgpt.util import PaginationResult + +from ..api.endpoints import init_endpoints, router +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_CONFIG_KEY_PREFIX + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +def client_init_caller(app: FastAPI, system_app: SystemApp): + app.include_router(router) + init_endpoints(system_app) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, asystem_app, has_auth", + [ + ( + { + "app_caller": client_init_caller, + "client_api_key": "test_token1", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + True, + ), + ( + { + "app_caller": client_init_caller, + "client_api_key": "error_token", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + False, + ), + ], + indirect=["client", "asystem_app"], +) +async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool): + response = await client.get("/test_auth") + if has_auth: + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + else: + assert response.status_code == 401 + assert response.json() == { + "detail": { + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + } + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_health(client: AsyncClient): + response = await client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_create(client: AsyncClient): + # TODO: add your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_update(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query_by_page(client: AsyncClient): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/dbgpts/my/tests/test_models.py b/dbgpt/serve/dbgpts/my/tests/test_models.py new file mode 100644 index 000000000..db4b44f8d --- /dev/null +++ b/dbgpt/serve/dbgpts/my/tests/test_models.py @@ -0,0 +1,103 @@ +# import pytest +# +# from dbgpt.storage.metadata import db +# +# from ..api.schemas import ServeRequest, ServerResponse +# from ..config import ServeConfig +# from ..models.models import ServeDao, ServeEntity +# +# +# @pytest.fixture(autouse=True) +# def setup_and_teardown(): +# db.init_db("sqlite:///:memory:") +# db.create_all() +# +# yield +# +# +# @pytest.fixture +# def server_config(): +# # TODO : build your server config +# return ServeConfig() +# +# +# @pytest.fixture +# def dao(server_config): +# return ServeDao(server_config) +# +# +# @pytest.fixture +# def default_entity_dict(): +# # TODO: build your default entity dict +# return {} +# +# +# def test_table_exist(): +# assert ServeEntity.__tablename__ in db.metadata.tables +# +# +# def test_entity_create(default_entity_dict): +# with db.session() as session: +# entity = ServeEntity(**default_entity_dict) +# session.add(entity) +# +# +# def test_entity_unique_key(default_entity_dict): +# # TODO: implement your test case +# pass +# +# +# def test_entity_get(default_entity_dict): +# # TODO: implement your test case +# pass +# +# +# def test_entity_update(default_entity_dict): +# # TODO: implement your test case +# pass +# +# +# def test_entity_delete(default_entity_dict): +# # TODO: implement your test case +# pass +# +# +# def test_entity_all(): +# # TODO: implement your test case +# pass +# +# +# def test_dao_create(dao, default_entity_dict): +# # TODO: implement your test case +# req = ServeRequest(**default_entity_dict) +# res: ServerResponse = dao.create(req) +# assert res is not None +# +# +# def test_dao_get_one(dao, default_entity_dict): +# # TODO: implement your test case +# req = ServeRequest(**default_entity_dict) +# res: ServerResponse = dao.create(req) +# +# +# def test_get_dao_get_list(dao): +# # TODO: implement your test case +# pass +# +# +# def test_dao_update(dao, default_entity_dict): +# # TODO: implement your test case +# pass +# +# +# def test_dao_delete(dao, default_entity_dict): +# # TODO: implement your test case +# pass +# +# +# def test_dao_get_list_page(dao): +# # TODO: implement your test case +# pass +# +# +# # Add more test cases according to your own logic diff --git a/dbgpt/serve/dbgpts/my/tests/test_service.py b/dbgpt/serve/dbgpts/my/tests/test_service.py new file mode 100644 index 000000000..00177924d --- /dev/null +++ b/dbgpt/serve/dbgpts/my/tests/test_service.py @@ -0,0 +1,78 @@ +from typing import List + +import pytest + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import system_app +from dbgpt.storage.metadata import db + +from ..api.schemas import ServeRequest, ServerResponse +from ..models.models import ServeEntity +from ..service.service import Service + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + yield + + +@pytest.fixture +def service(system_app: SystemApp): + instance = Service(system_app) + instance.init_app(system_app) + return instance + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +@pytest.mark.parametrize( + "system_app", + [{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}], + indirect=True, +) +def test_config_exists(service: Service): + system_app: SystemApp = service._system_app + assert system_app.config.get("DEBUG") is True + assert system_app.config.get("dbgpt.serve.test_key") == "hello" + assert service.config is not None + + +def test_service_create(service: Service, default_entity_dict): + # TODO: implement your test case + # eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict)) + # ... + pass + + +def test_service_update(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_delete(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get_list(service: Service): + # TODO: implement your test case + pass + + +def test_service_get_list_by_page(service: Service): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/util/dbgpts/loader.py b/dbgpt/util/dbgpts/loader.py index 4695db7d7..930a7de37 100644 --- a/dbgpt/util/dbgpts/loader.py +++ b/dbgpt/util/dbgpts/loader.py @@ -253,7 +253,7 @@ def _get_from_module(module, predicates: Optional[List[str]] = None): return results -def _parse_package_metadata(package: InstalledPackage) -> BasePackage: +def parse_package_metadata(package: InstalledPackage) -> BasePackage: with open( Path(package.root) / DBGPTS_METADATA_FILE, mode="r+", encoding="utf-8" ) as f: @@ -321,7 +321,7 @@ def _load_package_from_path(path: str): parsed_packages = [] for package in packages: try: - parsed_packages.append(_parse_package_metadata(package)) + parsed_packages.append(parse_package_metadata(package)) except Exception as e: logger.warning(f"Load package failed!{str(e)}", e) @@ -338,7 +338,7 @@ def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPack ] if not packages: raise ValueError(f"Can't find the package {name} or {new_name}") - flow_package = _parse_package_metadata(packages[0]) + flow_package = parse_package_metadata(packages[0]) if flow_package.package_type != "flow": raise ValueError(f"Unsupported package type: {flow_package.package_type}") return cast(FlowPackage, flow_package) diff --git a/dbgpt/util/dbgpts/repo.py b/dbgpt/util/dbgpts/repo.py index 8466b0b5d..d4318ae52 100644 --- a/dbgpt/util/dbgpts/repo.py +++ b/dbgpt/util/dbgpts/repo.py @@ -209,7 +209,7 @@ def install( if not repo_info: cl.error(f"The specified dbgpt '{name}' does not exist.", exit_code=1) repo, dbgpt_path = repo_info - _copy_and_install(repo, name, dbgpt_path) + copy_and_install(repo, name, dbgpt_path) def uninstall(name: str): @@ -227,7 +227,7 @@ def uninstall(name: str): cl.info(f"Uninstalling dbgpt '{name}'...") -def _copy_and_install(repo: str, name: str, package_path: Path): +def copy_and_install(repo: str, name: str, package_path: Path): if not package_path.exists(): cl.error( f"The specified dbgpt '{name}' does not exist in the {repo} tap.", @@ -350,6 +350,42 @@ def list_repo_apps(repo: str | None = None, with_update: bool = True): cl.print(table) +async def list_dbgpts( + spec_repo: str | None = None, with_update: bool = True +) -> List[Tuple[str, str, str, str]]: + """scan dbgpts in repo + + Args: + spec_repo: The name of the repo + + Returns: + Tuple[str, Path] | None: The repo and the path of the dbgpt + """ + repos = _list_repos_details() + if spec_repo: + repos = list(filter(lambda x: x[0] == spec_repo, repos)) + if not repos: + raise ValueError(f"The specified repo '{spec_repo}' does not exist.") + if with_update: + for repo in repos: + update_repo(repo[0]) + data = [] + for repo in repos: + repo_path = Path(repo[1]) + for package in DEFAULT_PACKAGES: + dbgpt_path = repo_path / package + for app in os.listdir(dbgpt_path): + gpts_path = dbgpt_path / app + dbgpt_metadata_path = dbgpt_path / app / DBGPTS_METADATA_FILE + if ( + dbgpt_path.exists() + and dbgpt_path.is_dir() + and dbgpt_metadata_path.exists() + ): + data.append((repo[0], package, app, gpts_path)) + return data + + def list_installed_apps(): """List all installed dbgpts""" packages = _load_package_from_path(INSTALL_DIR)