feat(core): Add API authentication for serve template (#950)

This commit is contained in:
Fangyin Cheng 2023-12-19 13:41:02 +08:00 committed by GitHub
parent 6739993b94
commit a10d5f57b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 1293 additions and 377 deletions

View File

@ -77,8 +77,6 @@ def mount_routers(app: FastAPI):
"""Lazy import to avoid high time cost"""
from dbgpt.app.knowledge.api import router as knowledge_router
# from dbgpt.app.prompt.api import router as prompt_router
# prompt has been removed to dbgpt.serve.prompt
from dbgpt.app.llm_manage.api import router as llm_manage_api
from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
@ -93,7 +91,6 @@ def mount_routers(app: FastAPI):
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
app.include_router(knowledge_router, tags=["Knowledge"])
# app.include_router(prompt_router, tags=["Prompt"])
def mount_static_files(app: FastAPI):

View File

@ -7,7 +7,6 @@ from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
# from dbgpt.app.prompt.prompt_manage_db import PromptManageEntity
from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
from dbgpt.storage.chat_history.chat_history_db import (

View File

@ -3,7 +3,11 @@ from dbgpt.component import SystemApp
def register_serve_apps(system_app: SystemApp):
"""Register serve apps"""
from dbgpt.serve.prompt.serve import Serve as PromptServe
from dbgpt.serve.prompt.serve import Serve as PromptServe, SERVE_CONFIG_KEY_PREFIX
# Replace old prompt serve
# Set config
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_user", "dbgpt")
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_sys_code", "dbgpt")
# Register serve app
system_app.register(PromptServe, api_prefix="/prompt")

View File

@ -1,46 +0,0 @@
from fastapi import APIRouter
from dbgpt.app.openapi.api_view_model import Result
from dbgpt.app.prompt.service import PromptManageService
from dbgpt.app.prompt.request.request import PromptManageRequest
router = APIRouter()
prompt_manage_service = PromptManageService()
@router.post("/prompt/add")
def prompt_add(request: PromptManageRequest):
print(f"/prompt/add params: {request}")
try:
prompt_manage_service.create_prompt(request)
return Result.succ([])
except Exception as e:
return Result.failed(code="E010X", msg=f"prompt add error {e}")
@router.post("/prompt/list")
def prompt_list(request: PromptManageRequest):
print(f"/prompt/list params: {request}")
try:
return Result.succ(prompt_manage_service.get_prompts(request))
except Exception as e:
return Result.failed(code="E010X", msg=f"prompt list error {e}")
@router.post("/prompt/update")
def prompt_update(request: PromptManageRequest):
print(f"/prompt/update params: {request}")
try:
return Result.succ(prompt_manage_service.update_prompt(request))
except Exception as e:
return Result.failed(code="E010X", msg=f"prompt update error {e}")
@router.post("/prompt/delete")
def prompt_delete(request: PromptManageRequest):
print(f"/prompt/delete params: {request}")
try:
return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name))
except Exception as e:
return Result.failed(code="E010X", msg=f"prompt delete error {e}")

View File

@ -1,89 +0,0 @@
from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt._private.config import Config
from dbgpt.app.prompt.request.request import PromptManageRequest
CFG = Config()
class PromptManageEntity(Model):
__tablename__ = "prompt_manage"
id = Column(Integer, primary_key=True)
chat_scene = Column(String(100))
sub_chat_scene = Column(String(100))
prompt_type = Column(String(100))
prompt_name = Column(String(512))
content = Column(Text)
user_name = Column(String(128))
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
def __repr__(self):
return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class PromptManageDao(BaseDao):
def create_prompt(self, prompt: PromptManageRequest):
session = self.get_raw_session()
prompt_manage = PromptManageEntity(
chat_scene=prompt.chat_scene,
sub_chat_scene=prompt.sub_chat_scene,
prompt_type=prompt.prompt_type,
prompt_name=prompt.prompt_name,
content=prompt.content,
user_name=prompt.user_name,
sys_code=prompt.sys_code,
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
session.add(prompt_manage)
session.commit()
session.close()
def get_prompts(self, query: PromptManageEntity):
session = self.get_raw_session()
prompts = session.query(PromptManageEntity)
if query.chat_scene is not None:
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
if query.sub_chat_scene is not None:
prompts = prompts.filter(
PromptManageEntity.sub_chat_scene == query.sub_chat_scene
)
if query.prompt_type is not None:
prompts = prompts.filter(
PromptManageEntity.prompt_type == query.prompt_type
)
if query.prompt_type == "private" and query.user_name is not None:
prompts = prompts.filter(
PromptManageEntity.user_name == query.user_name
)
if query.prompt_name is not None:
prompts = prompts.filter(
PromptManageEntity.prompt_name == query.prompt_name
)
if query.sys_code is not None:
prompts = prompts.filter(PromptManageEntity.sys_code == query.sys_code)
prompts = prompts.order_by(PromptManageEntity.gmt_created.desc())
result = prompts.all()
session.close()
return result
def update_prompt(self, prompt: PromptManageEntity):
session = self.get_raw_session()
session.merge(prompt)
session.commit()
session.close()
def delete_prompt(self, prompt: PromptManageEntity):
session = self.get_raw_session()
if prompt:
session.delete(prompt)
session.commit()
session.close()

View File

@ -1,44 +0,0 @@
from typing import List
from dbgpt._private.pydantic import BaseModel
from typing import Optional
from dbgpt._private.pydantic import BaseModel
class PromptManageRequest(BaseModel):
"""Model for managing prompts."""
chat_scene: Optional[str] = None
"""
The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.
"""
sub_chat_scene: Optional[str] = None
"""
The sub chat scene.
"""
prompt_type: Optional[str] = None
"""
The prompt type, either common or private.
"""
content: Optional[str] = None
"""
The prompt content.
"""
user_name: Optional[str] = None
"""
The user name.
"""
sys_code: Optional[str] = None
"""
System code
"""
prompt_name: Optional[str] = None
"""
The prompt name.
"""

View File

@ -1,26 +0,0 @@
from typing import List
from dbgpt._private.pydantic import BaseModel
class PromptQueryResponse(BaseModel):
id: int = None
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
chat_scene: str = None
"""sub_chat_scene: sub chat scene"""
sub_chat_scene: str = None
"""prompt_type: common or private"""
prompt_type: str = None
"""content: prompt content"""
content: str = None
"""user_name: user name"""
user_name: str = None
"""prompt_name: prompt name"""
prompt_name: str = None
gmt_created: str = None
gmt_modified: str = None

View File

@ -1,87 +0,0 @@
from datetime import datetime
from dbgpt.app.prompt.request.request import PromptManageRequest
from dbgpt.app.prompt.request.response import PromptQueryResponse
from dbgpt.app.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity
prompt_manage_dao = PromptManageDao()
class PromptManageService:
def __init__(self):
pass
"""create prompt"""
def create_prompt(self, request: PromptManageRequest):
query = PromptManageRequest(
prompt_name=request.prompt_name,
)
err_sys_str = ""
if query.sys_code:
query.sys_code = request.sys_code
err_sys_str = f" and sys_code: {request.sys_code}"
prompt_name = prompt_manage_dao.get_prompts(query)
if len(prompt_name) > 0:
raise Exception(
f"prompt name: {request.prompt_name}{err_sys_str} have already named"
)
prompt_manage_dao.create_prompt(request)
return True
"""get prompts"""
def get_prompts(self, request: PromptManageRequest):
query = PromptManageRequest(
chat_scene=request.chat_scene,
sub_chat_scene=request.sub_chat_scene,
prompt_type=request.prompt_type,
prompt_name=request.prompt_name,
user_name=request.user_name,
sys_code=request.sys_code,
)
responses = []
prompts = prompt_manage_dao.get_prompts(query)
for prompt in prompts:
res = PromptQueryResponse()
res.id = prompt.id
res.chat_scene = prompt.chat_scene
res.sub_chat_scene = prompt.sub_chat_scene
res.prompt_type = prompt.prompt_type
res.content = prompt.content
res.user_name = prompt.user_name
res.prompt_name = prompt.prompt_name
res.gmt_created = prompt.gmt_created
res.gmt_modified = prompt.gmt_modified
responses.append(res)
return responses
"""update prompt"""
def update_prompt(self, request: PromptManageRequest):
query = PromptManageEntity(prompt_name=request.prompt_name)
prompts = prompt_manage_dao.get_prompts(query)
if len(prompts) != 1:
raise Exception(
f"there are no or more than one space called {request.prompt_name}"
)
prompt = prompts[0]
prompt.chat_scene = request.chat_scene
prompt.sub_chat_scene = request.sub_chat_scene
prompt.prompt_type = request.prompt_type
prompt.content = request.content
prompt.user_name = request.user_name
prompt.gmt_modified = datetime.now()
return prompt_manage_dao.update_prompt(prompt)
"""delete prompt"""
def delete_prompt(self, prompt_name: str):
query = PromptManageEntity(prompt_name=prompt_name)
prompts = prompt_manage_dao.get_prompts(query)
if len(prompts) == 0:
raise Exception(f"delete error, no prompt name:{prompt_name} in database ")
# delete prompt
prompt = prompts[0]
return prompt_manage_dao.delete_prompt(prompt)

View File

@ -16,4 +16,6 @@ class BaseServeConfig(BaseParameters):
config_prefix (str): Configuration prefix
"""
config_dict = config.get_all_by_prefix(config_prefix)
# remove prefix
config_dict = {k[len(config_prefix) :]: v for k, v in config_dict.items()}
return cls(**config_dict)

View File

@ -0,0 +1,59 @@
import pytest
import pytest_asyncio
from typing import Dict
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient
from dbgpt.component import SystemApp
from dbgpt.util import AppConfig
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
def create_system_app(param: Dict) -> SystemApp:
app_config = param.get("app_config", {})
if isinstance(app_config, dict):
app_config = AppConfig(configs=app_config)
elif not isinstance(app_config, AppConfig):
raise RuntimeError("app_config must be AppConfig or dict")
return SystemApp(app, app_config)
@pytest_asyncio.fixture
async def asystem_app(request):
param = getattr(request, "param", {})
return create_system_app(param)
@pytest.fixture
def system_app(request):
param = getattr(request, "param", {})
return create_system_app(param)
@pytest_asyncio.fixture
async def client(request, asystem_app: SystemApp):
param = getattr(request, "param", {})
headers = param.get("headers", {})
base_url = param.get("base_url", "http://test")
client_api_key = param.get("client_api_key")
routers = param.get("routers", [])
app_caller = param.get("app_caller")
if "api_keys" in param:
del param["api_keys"]
if client_api_key:
headers["Authorization"] = "Bearer " + client_api_key
async with AsyncClient(app=app, base_url=base_url, headers=headers) as client:
for router in routers:
app.include_router(router)
if app_caller:
app_caller(app, asystem_app)
yield client

View File

@ -1,5 +1,8 @@
from typing import Optional, List
from fastapi import APIRouter, Depends, Query
from functools import cache
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
@ -20,14 +23,79 @@ def get_service() -> Service:
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
Your can pass the token in you request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key }
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
# TODO: Compatible with old API, will be modified in the future
@router.post("/add", response_model=Result[ServerResponse])
@router.post(
"/add", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def create(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@ -42,7 +110,11 @@ async def create(
return Result.succ(service.create(request))
@router.post("/update", response_model=Result[ServerResponse])
@router.post(
"/update",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def update(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@ -57,7 +129,9 @@ async def update(
return Result.succ(service.update(request))
@router.post("/delete", response_model=Result[None])
@router.post(
"/delete", response_model=Result[None], dependencies=[Depends(check_api_key)]
)
async def delete(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[None]:
@ -72,7 +146,11 @@ async def delete(
return Result.succ(service.delete(request))
@router.post("/list", response_model=Result[List[ServerResponse]])
@router.post(
"/list",
response_model=Result[List[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[List[ServerResponse]]:
@ -87,7 +165,11 @@ async def query(
return Result.succ(service.get_list(request))
@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]])
@router.post(
"/query_page",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),

View File

@ -1,73 +1,78 @@
# Define your Pydantic schemas here
from typing import Optional
from dbgpt._private.pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
"""Prompt request model"""
chat_scene: Optional[str] = None
"""
The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.
"""
class Config:
title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
sub_chat_scene: Optional[str] = None
"""
The sub chat scene.
"""
chat_scene: Optional[str] = Field(
None,
description="The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.",
examples=["chat_with_db_execute", "chat_excel", "chat_with_db_qa"],
)
prompt_type: Optional[str] = None
"""
The prompt type, either common or private.
"""
sub_chat_scene: Optional[str] = Field(
None,
description="The sub chat scene.",
examples=["sub_scene_1", "sub_scene_2", "sub_scene_3"],
)
content: Optional[str] = None
"""
The prompt content.
"""
prompt_type: Optional[str] = Field(
None,
description="The prompt type, either common or private.",
examples=["common", "private"],
)
prompt_name: Optional[str] = Field(
None,
description="The prompt name.",
examples=["code_assistant", "joker", "data_analysis_expert"],
)
content: Optional[str] = Field(
None,
description="The prompt content.",
examples=[
"Write a qsort function in python",
"Tell me a joke about AI",
"You are a data analysis expert.",
],
)
user_name: Optional[str] = None
"""
The user name.
"""
user_name: Optional[str] = Field(
None,
description="The user name.",
examples=["zhangsan", "lisi", "wangwu"],
)
sys_code: Optional[str] = None
"""
System code
"""
prompt_name: Optional[str] = None
"""
The prompt name.
"""
sys_code: Optional[str] = Field(
None,
description="The system code.",
examples=["dbgpt", "auth_manager", "data_platform"],
)
class ServerResponse(BaseModel):
class ServerResponse(ServeRequest):
"""Prompt response model"""
id: int = None
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
class Config:
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"
chat_scene: str = None
"""sub_chat_scene: sub chat scene"""
sub_chat_scene: str = None
"""prompt_type: common or private"""
prompt_type: str = None
"""content: prompt content"""
content: str = None
"""user_name: user name"""
user_name: str = None
sys_code: Optional[str] = None
"""
System code
"""
"""prompt_name: prompt name"""
prompt_name: str = None
gmt_created: str = None
gmt_modified: str = None
id: Optional[int] = Field(
None,
description="The prompt id.",
examples=[1, 2, 3],
)
gmt_created: Optional[str] = Field(
None,
description="The prompt created time.",
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
)
gmt_modified: Optional[str] = Field(
None,
description="The prompt modified time.",
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
)

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from dataclasses import dataclass, field
from dbgpt.serve.core import BaseServeConfig
@ -17,3 +18,15 @@ class ServeConfig(BaseServeConfig):
"""Parameters for the serve command"""
# TODO: add your own parameters here
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)
default_user: Optional[str] = field(
default=None,
metadata={"help": "Default user name for prompt"},
)
default_sys_code: Optional[str] = field(
default=None,
metadata={"help": "Default system code for prompt"},
)

View File

@ -2,7 +2,13 @@ from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
from .api.endpoints import router, init_endpoints
from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME
from .config import (
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
APP_NAME,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
class Serve(BaseComponent):

View File

@ -13,10 +13,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp):
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = None
self._dao: ServeDao = dao
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
@ -28,7 +28,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = ServeDao(self._serve_config)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
@property
@ -41,6 +41,22 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""Returns the internal ServeConfig."""
return self._serve_config
def create(self, request: ServeRequest) -> ServerResponse:
"""Create a new Prompt entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
if not request.user_name:
request.user_name = self.config.default_user
if not request.sys_code:
request.sys_code = self.config.default_sys_code
return super().create(request)
def update(self, request: ServeRequest) -> ServerResponse:
"""Update a Prompt entity

View File

View File

@ -0,0 +1,176 @@
import pytest
from httpx import AsyncClient
from fastapi import FastAPI
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..config import SERVE_CONFIG_KEY_PREFIX
from ..api.endpoints import router, init_endpoints
from ..api.schemas import ServeRequest, ServerResponse
from dbgpt.serve.core.tests.conftest import client, asystem_app
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
def client_init_caller(app: FastAPI, system_app: SystemApp):
app.include_router(router)
init_endpoints(system_app)
async def _create_and_validate(
client: AsyncClient, sys_code: str, content: str, expect_id: int = 1, **kwargs
):
req_json = {"sys_code": sys_code, "content": content}
req_json.update(kwargs)
response = await client.post("/add", json=req_json)
assert response.status_code == 200
json_res = response.json()
assert "success" in json_res and json_res["success"]
assert "data" in json_res and json_res["data"]
data = json_res["data"]
res_obj = ServerResponse(**data)
assert res_obj.id == expect_id
assert res_obj.sys_code == sys_code
assert res_obj.content == content
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, asystem_app, has_auth",
[
(
{
"app_caller": client_init_caller,
"client_api_key": "test_token1",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
True,
),
(
{
"app_caller": client_init_caller,
"client_api_key": "error_token",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
False,
),
],
indirect=["client", "asystem_app"],
)
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
response = await client.get("/test_auth")
if has_auth:
assert response.status_code == 200
assert response.json() == {"status": "ok"}
else:
assert response.status_code == 401
assert response.json() == {
"detail": {
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
}
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_auth(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
await _create_and_validate(client, "test", "test")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
await _create_and_validate(client, "test", "test")
response = await client.post("/update", json={"id": 1, "content": "test2"})
assert response.status_code == 200
json_res = response.json()
assert "success" in json_res and json_res["success"]
assert "data" in json_res and json_res["data"]
data = json_res["data"]
res_obj = ServerResponse(**data)
assert res_obj.id == 1
assert res_obj.sys_code == "test"
assert res_obj.content == "test2"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
for i in range(10):
await _create_and_validate(
client, "test", f"test{i}", expect_id=i + 1, prompt_name=f"prompt_name_{i}"
)
response = await client.post("/list", json={"sys_code": "test"})
assert response.status_code == 200
json_res = response.json()
assert "success" in json_res and json_res["success"]
assert "data" in json_res and json_res["data"]
data = json_res["data"]
assert len(data) == 10
res_obj = ServerResponse(**data[0])
assert res_obj.id == 1
assert res_obj.sys_code == "test"
assert res_obj.content == "test0"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
for i in range(10):
await _create_and_validate(
client, "test", f"test{i}", expect_id=i + 1, prompt_name=f"prompt_name_{i}"
)
response = await client.post(
"/query_page", params={"page": 1, "page_size": 5}, json={"sys_code": "test"}
)
assert response.status_code == 200
json_res = response.json()
assert "success" in json_res and json_res["success"]
assert "data" in json_res and json_res["data"]
data = json_res["data"]
page_result: PaginationResult = PaginationResult(**data)
assert page_result.total_count == 10
assert page_result.total_pages == 2
assert page_result.page == 1
assert page_result.page_size == 5
assert len(page_result.items) == 5

View File

@ -0,0 +1,257 @@
from typing import List
import pytest
from dbgpt.storage.metadata import db
from ..config import ServeConfig
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity, ServeDao
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def server_config():
return ServeConfig()
@pytest.fixture
def dao(server_config):
return ServeDao(server_config)
@pytest.fixture
def default_entity_dict():
return {
"chat_scene": "chat_data",
"sub_chat_scene": "excel",
"prompt_type": "common",
"prompt_name": "my_prompt_1",
"content": "Write a qsort function in python.",
"user_name": "zhangsan",
"sys_code": "dbgpt",
}
def test_table_exist():
assert ServeEntity.__tablename__ in db.metadata.tables
def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_entity_unique_key(default_entity_dict):
ServeEntity.create(**default_entity_dict)
with pytest.raises(Exception):
ServeEntity.create(**{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_entity_update(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.update(prompt_name="my_prompt_2")
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_2"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_entity_delete(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.delete()
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity is None
def test_entity_all():
for i in range(10):
ServeEntity.create(
chat_scene="chat_data",
sub_chat_scene="excel",
prompt_type="common",
prompt_name=f"my_prompt_{i}",
content="Write a qsort function in python.",
user_name="zhangsan",
sys_code="dbgpt",
)
entities = ServeEntity.all()
assert len(entities) == 10
for entity in entities:
assert entity.chat_scene == "chat_data"
assert entity.sub_chat_scene == "excel"
assert entity.prompt_type == "common"
assert entity.content == "Write a qsort function in python."
assert entity.user_name == "zhangsan"
assert entity.sys_code == "dbgpt"
assert entity.gmt_created is not None
assert entity.gmt_modified is not None
def test_dao_create(dao, default_entity_dict):
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
assert res is not None
assert res.id == 1
assert res.chat_scene == "chat_data"
assert res.sub_chat_scene == "excel"
assert res.prompt_type == "common"
assert res.prompt_name == "my_prompt_1"
assert res.content == "Write a qsort function in python."
assert res.user_name == "zhangsan"
assert res.sys_code == "dbgpt"
def test_dao_get_one(dao, default_entity_dict):
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
res: ServerResponse = dao.get_one(
{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"}
)
assert res is not None
assert res.id == 1
assert res.chat_scene == "chat_data"
assert res.sub_chat_scene == "excel"
assert res.prompt_type == "common"
assert res.prompt_name == "my_prompt_1"
assert res.content == "Write a qsort function in python."
assert res.user_name == "zhangsan"
assert res.sys_code == "dbgpt"
def test_get_dao_get_list(dao):
for i in range(10):
dao.create(
ServeRequest(
chat_scene="chat_data",
sub_chat_scene="excel",
prompt_type="common",
prompt_name=f"my_prompt_{i}",
content="Write a qsort function in python.",
user_name="zhangsan" if i % 2 == 0 else "lisi",
sys_code="dbgpt",
)
)
res: List[ServerResponse] = dao.get_list({"sys_code": "dbgpt"})
assert len(res) == 10
for i, r in enumerate(res):
assert r.id == i + 1
assert r.chat_scene == "chat_data"
assert r.sub_chat_scene == "excel"
assert r.prompt_type == "common"
assert r.prompt_name == f"my_prompt_{i}"
assert r.content == "Write a qsort function in python."
assert r.user_name == "zhangsan" if i % 2 == 0 else "lisi"
assert r.sys_code == "dbgpt"
half_res: List[ServerResponse] = dao.get_list({"user_name": "zhangsan"})
assert len(half_res) == 5
def test_dao_update(dao, default_entity_dict):
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
res: ServerResponse = dao.update(
{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"},
ServeRequest(prompt_name="my_prompt_2"),
)
assert res is not None
assert res.id == 1
assert res.chat_scene == "chat_data"
assert res.sub_chat_scene == "excel"
assert res.prompt_type == "common"
assert res.prompt_name == "my_prompt_2"
assert res.content == "Write a qsort function in python."
assert res.user_name == "zhangsan"
assert res.sys_code == "dbgpt"
def test_dao_delete(dao, default_entity_dict):
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
dao.delete({"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
res: ServerResponse = dao.get_one(
{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"}
)
assert res is None
def test_dao_get_list_page(dao):
for i in range(20):
dao.create(
ServeRequest(
chat_scene="chat_data",
sub_chat_scene="excel",
prompt_type="common",
prompt_name=f"my_prompt_{i}",
content="Write a qsort function in python.",
user_name="zhangsan" if i % 2 == 0 else "lisi",
sys_code="dbgpt",
)
)
res = dao.get_list_page({"sys_code": "dbgpt"}, page=1, page_size=8)
assert res.total_count == 20
assert res.total_pages == 3
assert res.page == 1
assert res.page_size == 8
assert len(res.items) == 8
for i, r in enumerate(res.items):
assert r.id == i + 1
assert r.chat_scene == "chat_data"
assert r.sub_chat_scene == "excel"
assert r.prompt_type == "common"
assert r.prompt_name == f"my_prompt_{i}"
assert r.content == "Write a qsort function in python."
assert r.user_name == "zhangsan" if i % 2 == 0 else "lisi"
assert r.sys_code == "dbgpt"
res_half = dao.get_list_page({"user_name": "zhangsan"}, page=2, page_size=8)
assert res_half.total_count == 10
assert res_half.total_pages == 2
assert res_half.page == 2
assert res_half.page_size == 8
assert len(res_half.items) == 2
for i, r in enumerate(res_half.items):
assert r.chat_scene == "chat_data"
assert r.sub_chat_scene == "excel"
assert r.prompt_type == "common"
assert r.content == "Write a qsort function in python."
assert r.user_name == "zhangsan"
assert r.sys_code == "dbgpt"

View File

@ -0,0 +1,154 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.serve.core.tests.conftest import system_app
from ..models.models import ServeEntity
from ..api.schemas import ServeRequest, ServerResponse
from ..service.service import Service
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def service(system_app: SystemApp):
instance = Service(system_app)
instance.init_app(system_app)
return instance
@pytest.fixture
def default_entity_dict():
return {
"chat_scene": "chat_data",
"sub_chat_scene": "excel",
"prompt_type": "common",
"prompt_name": "my_prompt_1",
"content": "Write a qsort function in python.",
"user_name": "zhangsan",
"sys_code": "dbgpt",
}
@pytest.mark.parametrize(
"system_app",
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
indirect=True,
)
def test_config_exists(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
assert service.config is not None
@pytest.mark.parametrize(
"system_app",
[
{
"app_config": {
"DEBUG": True,
"dbgpt.serve.prompt.default_user": "dbgpt",
"dbgpt.serve.prompt.default_sys_code": "dbgpt",
}
}
],
indirect=True,
)
def test_config_default_user(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.prompt.default_user") == "dbgpt"
assert service.config is not None
assert service.config.default_user == "dbgpt"
assert service.config.default_sys_code == "dbgpt"
def test_service_create(service: Service, default_entity_dict):
entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_service_update(service: Service, default_entity_dict):
service.create(ServeRequest(**default_entity_dict))
entity: ServerResponse = service.update(ServeRequest(**default_entity_dict))
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_service_get(service: Service, default_entity_dict):
service.create(ServeRequest(**default_entity_dict))
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_service_delete(service: Service, default_entity_dict):
service.create(ServeRequest(**default_entity_dict))
service.delete(ServeRequest(**default_entity_dict))
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
assert entity is None
def test_service_get_list(service: Service):
for i in range(3):
service.create(
ServeRequest(**{"prompt_name": f"prompt_{i}", "sys_code": "dbgpt"})
)
entities: List[ServerResponse] = service.get_list(ServeRequest(sys_code="dbgpt"))
assert len(entities) == 3
for i, entity in enumerate(entities):
assert entity.sys_code == "dbgpt"
assert entity.prompt_name == f"prompt_{i}"
def test_service_get_list_by_page(service: Service):
for i in range(3):
service.create(
ServeRequest(**{"prompt_name": f"prompt_{i}", "sys_code": "dbgpt"})
)
res = service.get_list_by_page(ServeRequest(sys_code="dbgpt"), page=1, page_size=2)
assert res is not None
assert res.total_count == 3
assert res.total_pages == 2
assert len(res.items) == 2
for i, entity in enumerate(res.items):
assert entity.sys_code == "dbgpt"
assert entity.prompt_name == f"prompt_{i}"

View File

@ -1,5 +1,8 @@
from typing import Optional, List
from fastapi import APIRouter, Depends, Query
from functools import cache
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
@ -20,13 +23,78 @@ def get_service() -> Service:
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
Your can pass the token in you request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key }
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.post("/", response_model=Result[ServerResponse])
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
@router.post(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def create(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@ -41,7 +109,9 @@ async def create(
return Result.succ(service.create(request))
@router.put("/", response_model=Result[ServerResponse])
@router.put(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def update(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@ -56,7 +126,11 @@ async def update(
return Result.succ(service.update(request))
@router.post("/query", response_model=Result[ServerResponse])
@router.post(
"/query",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@ -71,7 +145,11 @@ async def query(
return Result.succ(service.get(request))
@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]])
@router.post(
"/query_page",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),

View File

@ -1,5 +1,6 @@
# Define your Pydantic schemas here
from dbgpt._private.pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
@ -7,8 +8,13 @@ class ServeRequest(BaseModel):
# TODO define your own fields here
class Config:
title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
class ServerResponse(BaseModel):
"""{__template_app_name__hump__} response model"""
# TODO define your own fields here
class Config:
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from dataclasses import dataclass, field
from dbgpt.serve.core import BaseServeConfig
@ -17,3 +18,6 @@ class ServeConfig(BaseServeConfig):
"""Parameters for the serve command"""
# TODO: add your own parameters here
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)

View File

@ -2,7 +2,13 @@ from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
from .api.endpoints import router, init_endpoints
from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME
from .config import (
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
APP_NAME,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
class Serve(BaseComponent):

View File

@ -13,10 +13,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp):
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = None
self._dao: ServeDao = dao
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
@ -28,7 +28,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = ServeDao(self._serve_config)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
@property

View File

@ -0,0 +1,124 @@
import pytest
from httpx import AsyncClient
from fastapi import FastAPI
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..config import SERVE_CONFIG_KEY_PREFIX
from ..api.endpoints import router, init_endpoints
from ..api.schemas import ServeRequest, ServerResponse
from dbgpt.serve.core.tests.conftest import client, asystem_app
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
def client_init_caller(app: FastAPI, system_app: SystemApp):
app.include_router(router)
init_endpoints(system_app)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, asystem_app, has_auth",
[
(
{
"app_caller": client_init_caller,
"client_api_key": "test_token1",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
True,
),
(
{
"app_caller": client_init_caller,
"client_api_key": "error_token",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
False,
),
],
indirect=["client", "asystem_app"],
)
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
response = await client.get("/test_auth")
if has_auth:
assert response.status_code == 200
assert response.json() == {"status": "ok"}
else:
assert response.status_code == 401
assert response.json() == {
"detail": {
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
}
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_health(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
# TODO: add your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@ -0,0 +1,109 @@
from typing import List
import pytest
from dbgpt.storage.metadata import db
from ..config import ServeConfig
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity, ServeDao
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def server_config():
# TODO : build your server config
return ServeConfig()
@pytest.fixture
def dao(server_config):
return ServeDao(server_config)
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
def test_table_exist():
assert ServeEntity.__tablename__ in db.metadata.tables
def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
# TODO: implement your test case
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
def test_entity_unique_key(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
# TODO: implement your test case
def test_entity_update(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_delete(default_entity_dict):
# TODO: implement your test case
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.delete()
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity is None
def test_entity_all():
# TODO: implement your test case
pass
def test_dao_create(dao, default_entity_dict):
# TODO: implement your test case
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
assert res is not None
def test_dao_get_one(dao, default_entity_dict):
# TODO: implement your test case
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
def test_get_dao_get_list(dao):
# TODO: implement your test case
pass
def test_dao_update(dao, default_entity_dict):
# TODO: implement your test case
pass
def test_dao_delete(dao, default_entity_dict):
# TODO: implement your test case
pass
def test_dao_get_list_page(dao):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@ -0,0 +1,76 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.serve.core.tests.conftest import system_app
from ..models.models import ServeEntity
from ..api.schemas import ServeRequest, ServerResponse
from ..service.service import Service
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def service(system_app: SystemApp):
instance = Service(system_app)
instance.init_app(system_app)
return instance
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
@pytest.mark.parametrize(
"system_app",
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
indirect=True,
)
def test_config_exists(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
assert service.config is not None
def test_service_create(service: Service, default_entity_dict):
# TODO: implement your test case
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
# ...
pass
def test_service_update(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_delete(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get_list(service: Service):
# TODO: implement your test case
pass
def test_service_get_list_by_page(service: Service):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@ -62,6 +62,7 @@ def replace_template_variables(content: str, app_name: str):
def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
for root, dirs, files in os.walk(src_dir):
dirs[:] = [d for d in dirs if not _should_ignore(d)]
relative_path = os.path.relpath(root, src_dir)
if relative_path == ".":
relative_path = ""
@ -70,6 +71,8 @@ def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
os.makedirs(target_dir, exist_ok=True)
for file in files:
if _should_ignore(file):
continue
try:
with open(os.path.join(root, file), "r") as f:
content = f.read()
@ -81,3 +84,9 @@ def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
except Exception as e:
click.echo(f"Error copying file {file} from {src_dir} to {dst_dir}")
raise e
def _should_ignore(file_or_dir: str):
"""Return True if the given file or directory should be ignored.""" ""
ignore_patterns = [".pyc", "__pycache__"]
return any(pattern in file_or_dir for pattern in ignore_patterns)

View File

@ -153,7 +153,8 @@ class BaseDao(Generic[T, REQ, RES]):
if entry is None:
raise Exception("Invalid request")
for key, value in update_request.dict().items():
setattr(entry, key, value)
if value is not None:
setattr(entry, key, value)
session.merge(entry)
return self.get_one(self.to_request(entry))

View File

@ -104,6 +104,28 @@ def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_
assert user.age == 35
def test_update_user_partial(
db: DatabaseManager, User: Type[BaseModel], user_dao, user_req
):
# Create a user
created_user_response = user_dao.create(user_req)
# Update the user
updated_req = UserRequest(name=user_req.name, password="newpassword")
updated_req.age = None
updated_user = user_dao.update(
query_request={"name": user_req.name}, update_request=updated_req
)
assert updated_user.id == created_user_response.id
assert updated_user.age == user_req.age
# Verify that the user is updated in the database
with db.session() as session:
user = session.query(User).get(created_user_response.id)
assert user.age == user_req.age
assert user.password == "newpassword"
def test_get_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
# Create a user
created_user_response = user_dao.create(user_req)

View File

@ -3,15 +3,18 @@ from typing import Any, Dict, Optional
class AppConfig:
def __init__(self):
self.configs = {}
def __init__(self, configs: Optional[Dict[str, Any]] = None) -> None:
self.configs = configs or {}
def set(self, key: str, value: Any) -> None:
def set(self, key: str, value: Any, overwrite: bool = False) -> None:
"""Set config value by key
Args:
key (str): The key of config
value (Any): The value of config
overwrite (bool, optional): Whether to overwrite the value if key exists. Defaults to False.
"""
if key in self.configs and not overwrite:
raise KeyError(f"Config key {key} already exists")
self.configs[key] = value
def get(self, key, default: Optional[Any] = None) -> Any: