mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
feat(core): Add API authentication for serve template (#950)
This commit is contained in:
parent
6739993b94
commit
a10d5f57b2
@ -77,8 +77,6 @@ def mount_routers(app: FastAPI):
|
||||
"""Lazy import to avoid high time cost"""
|
||||
from dbgpt.app.knowledge.api import router as knowledge_router
|
||||
|
||||
# from dbgpt.app.prompt.api import router as prompt_router
|
||||
# prompt has been removed to dbgpt.serve.prompt
|
||||
from dbgpt.app.llm_manage.api import router as llm_manage_api
|
||||
|
||||
from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
|
||||
@ -93,7 +91,6 @@ def mount_routers(app: FastAPI):
|
||||
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
|
||||
|
||||
app.include_router(knowledge_router, tags=["Knowledge"])
|
||||
# app.include_router(prompt_router, tags=["Prompt"])
|
||||
|
||||
|
||||
def mount_static_files(app: FastAPI):
|
||||
|
@ -7,7 +7,6 @@ from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
|
||||
from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity
|
||||
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
|
||||
|
||||
# from dbgpt.app.prompt.prompt_manage_db import PromptManageEntity
|
||||
from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
|
||||
from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
|
||||
from dbgpt.storage.chat_history.chat_history_db import (
|
||||
|
@ -3,7 +3,11 @@ from dbgpt.component import SystemApp
|
||||
|
||||
def register_serve_apps(system_app: SystemApp):
|
||||
"""Register serve apps"""
|
||||
from dbgpt.serve.prompt.serve import Serve as PromptServe
|
||||
from dbgpt.serve.prompt.serve import Serve as PromptServe, SERVE_CONFIG_KEY_PREFIX
|
||||
|
||||
# Replace old prompt serve
|
||||
# Set config
|
||||
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_user", "dbgpt")
|
||||
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_sys_code", "dbgpt")
|
||||
# Register serve app
|
||||
system_app.register(PromptServe, api_prefix="/prompt")
|
||||
|
@ -1,46 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from dbgpt.app.openapi.api_view_model import Result
|
||||
from dbgpt.app.prompt.service import PromptManageService
|
||||
from dbgpt.app.prompt.request.request import PromptManageRequest
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
prompt_manage_service = PromptManageService()
|
||||
|
||||
|
||||
@router.post("/prompt/add")
|
||||
def prompt_add(request: PromptManageRequest):
|
||||
print(f"/prompt/add params: {request}")
|
||||
try:
|
||||
prompt_manage_service.create_prompt(request)
|
||||
return Result.succ([])
|
||||
except Exception as e:
|
||||
return Result.failed(code="E010X", msg=f"prompt add error {e}")
|
||||
|
||||
|
||||
@router.post("/prompt/list")
|
||||
def prompt_list(request: PromptManageRequest):
|
||||
print(f"/prompt/list params: {request}")
|
||||
try:
|
||||
return Result.succ(prompt_manage_service.get_prompts(request))
|
||||
except Exception as e:
|
||||
return Result.failed(code="E010X", msg=f"prompt list error {e}")
|
||||
|
||||
|
||||
@router.post("/prompt/update")
|
||||
def prompt_update(request: PromptManageRequest):
|
||||
print(f"/prompt/update params: {request}")
|
||||
try:
|
||||
return Result.succ(prompt_manage_service.update_prompt(request))
|
||||
except Exception as e:
|
||||
return Result.failed(code="E010X", msg=f"prompt update error {e}")
|
||||
|
||||
|
||||
@router.post("/prompt/delete")
|
||||
def prompt_delete(request: PromptManageRequest):
|
||||
print(f"/prompt/delete params: {request}")
|
||||
try:
|
||||
return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name))
|
||||
except Exception as e:
|
||||
return Result.failed(code="E010X", msg=f"prompt delete error {e}")
|
@ -1,89 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
from dbgpt.app.prompt.request.request import PromptManageRequest
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class PromptManageEntity(Model):
|
||||
__tablename__ = "prompt_manage"
|
||||
id = Column(Integer, primary_key=True)
|
||||
chat_scene = Column(String(100))
|
||||
sub_chat_scene = Column(String(100))
|
||||
prompt_type = Column(String(100))
|
||||
prompt_name = Column(String(512))
|
||||
content = Column(Text)
|
||||
user_name = Column(String(128))
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
gmt_created = Column(DateTime)
|
||||
gmt_modified = Column(DateTime)
|
||||
|
||||
def __repr__(self):
|
||||
return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
|
||||
|
||||
class PromptManageDao(BaseDao):
|
||||
def create_prompt(self, prompt: PromptManageRequest):
|
||||
session = self.get_raw_session()
|
||||
prompt_manage = PromptManageEntity(
|
||||
chat_scene=prompt.chat_scene,
|
||||
sub_chat_scene=prompt.sub_chat_scene,
|
||||
prompt_type=prompt.prompt_type,
|
||||
prompt_name=prompt.prompt_name,
|
||||
content=prompt.content,
|
||||
user_name=prompt.user_name,
|
||||
sys_code=prompt.sys_code,
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
session.add(prompt_manage)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def get_prompts(self, query: PromptManageEntity):
|
||||
session = self.get_raw_session()
|
||||
prompts = session.query(PromptManageEntity)
|
||||
if query.chat_scene is not None:
|
||||
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
|
||||
if query.sub_chat_scene is not None:
|
||||
prompts = prompts.filter(
|
||||
PromptManageEntity.sub_chat_scene == query.sub_chat_scene
|
||||
)
|
||||
if query.prompt_type is not None:
|
||||
prompts = prompts.filter(
|
||||
PromptManageEntity.prompt_type == query.prompt_type
|
||||
)
|
||||
if query.prompt_type == "private" and query.user_name is not None:
|
||||
prompts = prompts.filter(
|
||||
PromptManageEntity.user_name == query.user_name
|
||||
)
|
||||
if query.prompt_name is not None:
|
||||
prompts = prompts.filter(
|
||||
PromptManageEntity.prompt_name == query.prompt_name
|
||||
)
|
||||
if query.sys_code is not None:
|
||||
prompts = prompts.filter(PromptManageEntity.sys_code == query.sys_code)
|
||||
|
||||
prompts = prompts.order_by(PromptManageEntity.gmt_created.desc())
|
||||
result = prompts.all()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def update_prompt(self, prompt: PromptManageEntity):
|
||||
session = self.get_raw_session()
|
||||
session.merge(prompt)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def delete_prompt(self, prompt: PromptManageEntity):
|
||||
session = self.get_raw_session()
|
||||
if prompt:
|
||||
session.delete(prompt)
|
||||
session.commit()
|
||||
session.close()
|
@ -1,44 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
|
||||
|
||||
class PromptManageRequest(BaseModel):
|
||||
"""Model for managing prompts."""
|
||||
|
||||
chat_scene: Optional[str] = None
|
||||
"""
|
||||
The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.
|
||||
"""
|
||||
|
||||
sub_chat_scene: Optional[str] = None
|
||||
"""
|
||||
The sub chat scene.
|
||||
"""
|
||||
|
||||
prompt_type: Optional[str] = None
|
||||
"""
|
||||
The prompt type, either common or private.
|
||||
"""
|
||||
|
||||
content: Optional[str] = None
|
||||
"""
|
||||
The prompt content.
|
||||
"""
|
||||
|
||||
user_name: Optional[str] = None
|
||||
"""
|
||||
The user name.
|
||||
"""
|
||||
|
||||
sys_code: Optional[str] = None
|
||||
"""
|
||||
System code
|
||||
"""
|
||||
|
||||
prompt_name: Optional[str] = None
|
||||
"""
|
||||
The prompt name.
|
||||
"""
|
@ -1,26 +0,0 @@
|
||||
from typing import List
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
|
||||
|
||||
class PromptQueryResponse(BaseModel):
|
||||
id: int = None
|
||||
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
|
||||
|
||||
chat_scene: str = None
|
||||
|
||||
"""sub_chat_scene: sub chat scene"""
|
||||
sub_chat_scene: str = None
|
||||
|
||||
"""prompt_type: common or private"""
|
||||
prompt_type: str = None
|
||||
|
||||
"""content: prompt content"""
|
||||
content: str = None
|
||||
|
||||
"""user_name: user name"""
|
||||
user_name: str = None
|
||||
|
||||
"""prompt_name: prompt name"""
|
||||
prompt_name: str = None
|
||||
gmt_created: str = None
|
||||
gmt_modified: str = None
|
@ -1,87 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from dbgpt.app.prompt.request.request import PromptManageRequest
|
||||
from dbgpt.app.prompt.request.response import PromptQueryResponse
|
||||
from dbgpt.app.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity
|
||||
|
||||
prompt_manage_dao = PromptManageDao()
|
||||
|
||||
|
||||
class PromptManageService:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
"""create prompt"""
|
||||
|
||||
def create_prompt(self, request: PromptManageRequest):
|
||||
query = PromptManageRequest(
|
||||
prompt_name=request.prompt_name,
|
||||
)
|
||||
err_sys_str = ""
|
||||
if query.sys_code:
|
||||
query.sys_code = request.sys_code
|
||||
err_sys_str = f" and sys_code: {request.sys_code}"
|
||||
prompt_name = prompt_manage_dao.get_prompts(query)
|
||||
if len(prompt_name) > 0:
|
||||
raise Exception(
|
||||
f"prompt name: {request.prompt_name}{err_sys_str} have already named"
|
||||
)
|
||||
prompt_manage_dao.create_prompt(request)
|
||||
return True
|
||||
|
||||
"""get prompts"""
|
||||
|
||||
def get_prompts(self, request: PromptManageRequest):
|
||||
query = PromptManageRequest(
|
||||
chat_scene=request.chat_scene,
|
||||
sub_chat_scene=request.sub_chat_scene,
|
||||
prompt_type=request.prompt_type,
|
||||
prompt_name=request.prompt_name,
|
||||
user_name=request.user_name,
|
||||
sys_code=request.sys_code,
|
||||
)
|
||||
responses = []
|
||||
prompts = prompt_manage_dao.get_prompts(query)
|
||||
for prompt in prompts:
|
||||
res = PromptQueryResponse()
|
||||
|
||||
res.id = prompt.id
|
||||
res.chat_scene = prompt.chat_scene
|
||||
res.sub_chat_scene = prompt.sub_chat_scene
|
||||
res.prompt_type = prompt.prompt_type
|
||||
res.content = prompt.content
|
||||
res.user_name = prompt.user_name
|
||||
res.prompt_name = prompt.prompt_name
|
||||
res.gmt_created = prompt.gmt_created
|
||||
res.gmt_modified = prompt.gmt_modified
|
||||
responses.append(res)
|
||||
return responses
|
||||
|
||||
"""update prompt"""
|
||||
|
||||
def update_prompt(self, request: PromptManageRequest):
|
||||
query = PromptManageEntity(prompt_name=request.prompt_name)
|
||||
prompts = prompt_manage_dao.get_prompts(query)
|
||||
if len(prompts) != 1:
|
||||
raise Exception(
|
||||
f"there are no or more than one space called {request.prompt_name}"
|
||||
)
|
||||
prompt = prompts[0]
|
||||
prompt.chat_scene = request.chat_scene
|
||||
prompt.sub_chat_scene = request.sub_chat_scene
|
||||
prompt.prompt_type = request.prompt_type
|
||||
prompt.content = request.content
|
||||
prompt.user_name = request.user_name
|
||||
prompt.gmt_modified = datetime.now()
|
||||
return prompt_manage_dao.update_prompt(prompt)
|
||||
|
||||
"""delete prompt"""
|
||||
|
||||
def delete_prompt(self, prompt_name: str):
|
||||
query = PromptManageEntity(prompt_name=prompt_name)
|
||||
prompts = prompt_manage_dao.get_prompts(query)
|
||||
if len(prompts) == 0:
|
||||
raise Exception(f"delete error, no prompt name:{prompt_name} in database ")
|
||||
# delete prompt
|
||||
prompt = prompts[0]
|
||||
return prompt_manage_dao.delete_prompt(prompt)
|
@ -16,4 +16,6 @@ class BaseServeConfig(BaseParameters):
|
||||
config_prefix (str): Configuration prefix
|
||||
"""
|
||||
config_dict = config.get_all_by_prefix(config_prefix)
|
||||
# remove prefix
|
||||
config_dict = {k[len(config_prefix) :]: v for k, v in config_dict.items()}
|
||||
return cls(**config_dict)
|
||||
|
59
dbgpt/serve/core/tests/conftest.py
Normal file
59
dbgpt/serve/core/tests/conftest.py
Normal file
@ -0,0 +1,59 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from typing import Dict
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from httpx import AsyncClient
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.util import AppConfig
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
def create_system_app(param: Dict) -> SystemApp:
|
||||
app_config = param.get("app_config", {})
|
||||
if isinstance(app_config, dict):
|
||||
app_config = AppConfig(configs=app_config)
|
||||
elif not isinstance(app_config, AppConfig):
|
||||
raise RuntimeError("app_config must be AppConfig or dict")
|
||||
return SystemApp(app, app_config)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def asystem_app(request):
|
||||
param = getattr(request, "param", {})
|
||||
return create_system_app(param)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def system_app(request):
|
||||
param = getattr(request, "param", {})
|
||||
return create_system_app(param)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(request, asystem_app: SystemApp):
|
||||
param = getattr(request, "param", {})
|
||||
headers = param.get("headers", {})
|
||||
base_url = param.get("base_url", "http://test")
|
||||
client_api_key = param.get("client_api_key")
|
||||
routers = param.get("routers", [])
|
||||
app_caller = param.get("app_caller")
|
||||
if "api_keys" in param:
|
||||
del param["api_keys"]
|
||||
if client_api_key:
|
||||
headers["Authorization"] = "Bearer " + client_api_key
|
||||
async with AsyncClient(app=app, base_url=base_url, headers=headers) as client:
|
||||
for router in routers:
|
||||
app.include_router(router)
|
||||
if app_caller:
|
||||
app_caller(app, asystem_app)
|
||||
yield client
|
@ -1,5 +1,8 @@
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from functools import cache
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import Result
|
||||
@ -20,14 +23,79 @@ def get_service() -> Service:
|
||||
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
|
||||
|
||||
|
||||
get_bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
@cache
|
||||
def _parse_api_keys(api_keys: str) -> List[str]:
|
||||
"""Parse the string api keys to a list
|
||||
|
||||
Args:
|
||||
api_keys (str): The string api keys
|
||||
|
||||
Returns:
|
||||
List[str]: The list of api keys
|
||||
"""
|
||||
if not api_keys:
|
||||
return []
|
||||
return [key.strip() for key in api_keys.split(",")]
|
||||
|
||||
|
||||
async def check_api_key(
|
||||
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Optional[str]:
|
||||
"""Check the api key
|
||||
|
||||
If the api key is not set, allow all.
|
||||
|
||||
Your can pass the token in you request header like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import requests
|
||||
client_api_key = "your_api_key"
|
||||
headers = {"Authorization": "Bearer " + client_api_key }
|
||||
res = requests.get("http://test/hello", headers=headers)
|
||||
assert res.status_code == 200
|
||||
|
||||
"""
|
||||
if service.config.api_keys:
|
||||
api_keys = _parse_api_keys(service.config.api_keys)
|
||||
if auth is None or (token := auth.credentials) not in api_keys:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
},
|
||||
)
|
||||
return token
|
||||
else:
|
||||
# api_keys not set; allow all
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
|
||||
async def test_auth():
|
||||
"""Test auth endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# TODO: Compatible with old API, will be modified in the future
|
||||
@router.post("/add", response_model=Result[ServerResponse])
|
||||
@router.post(
|
||||
"/add", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def create(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
@ -42,7 +110,11 @@ async def create(
|
||||
return Result.succ(service.create(request))
|
||||
|
||||
|
||||
@router.post("/update", response_model=Result[ServerResponse])
|
||||
@router.post(
|
||||
"/update",
|
||||
response_model=Result[ServerResponse],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def update(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
@ -57,7 +129,9 @@ async def update(
|
||||
return Result.succ(service.update(request))
|
||||
|
||||
|
||||
@router.post("/delete", response_model=Result[None])
|
||||
@router.post(
|
||||
"/delete", response_model=Result[None], dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def delete(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[None]:
|
||||
@ -72,7 +146,11 @@ async def delete(
|
||||
return Result.succ(service.delete(request))
|
||||
|
||||
|
||||
@router.post("/list", response_model=Result[List[ServerResponse]])
|
||||
@router.post(
|
||||
"/list",
|
||||
response_model=Result[List[ServerResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[List[ServerResponse]]:
|
||||
@ -87,7 +165,11 @@ async def query(
|
||||
return Result.succ(service.get_list(request))
|
||||
|
||||
|
||||
@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]])
|
||||
@router.post(
|
||||
"/query_page",
|
||||
response_model=Result[PaginationResult[ServerResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query_page(
|
||||
request: ServeRequest,
|
||||
page: Optional[int] = Query(default=1, description="current page"),
|
||||
|
@ -1,73 +1,78 @@
|
||||
# Define your Pydantic schemas here
|
||||
from typing import Optional
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from ..config import SERVE_APP_NAME_HUMP
|
||||
|
||||
|
||||
class ServeRequest(BaseModel):
|
||||
"""Prompt request model"""
|
||||
|
||||
chat_scene: Optional[str] = None
|
||||
"""
|
||||
The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.
|
||||
"""
|
||||
class Config:
|
||||
title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
|
||||
|
||||
sub_chat_scene: Optional[str] = None
|
||||
"""
|
||||
The sub chat scene.
|
||||
"""
|
||||
chat_scene: Optional[str] = Field(
|
||||
None,
|
||||
description="The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.",
|
||||
examples=["chat_with_db_execute", "chat_excel", "chat_with_db_qa"],
|
||||
)
|
||||
|
||||
prompt_type: Optional[str] = None
|
||||
"""
|
||||
The prompt type, either common or private.
|
||||
"""
|
||||
sub_chat_scene: Optional[str] = Field(
|
||||
None,
|
||||
description="The sub chat scene.",
|
||||
examples=["sub_scene_1", "sub_scene_2", "sub_scene_3"],
|
||||
)
|
||||
|
||||
content: Optional[str] = None
|
||||
"""
|
||||
The prompt content.
|
||||
"""
|
||||
prompt_type: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt type, either common or private.",
|
||||
examples=["common", "private"],
|
||||
)
|
||||
prompt_name: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt name.",
|
||||
examples=["code_assistant", "joker", "data_analysis_expert"],
|
||||
)
|
||||
content: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt content.",
|
||||
examples=[
|
||||
"Write a qsort function in python",
|
||||
"Tell me a joke about AI",
|
||||
"You are a data analysis expert.",
|
||||
],
|
||||
)
|
||||
|
||||
user_name: Optional[str] = None
|
||||
"""
|
||||
The user name.
|
||||
"""
|
||||
user_name: Optional[str] = Field(
|
||||
None,
|
||||
description="The user name.",
|
||||
examples=["zhangsan", "lisi", "wangwu"],
|
||||
)
|
||||
|
||||
sys_code: Optional[str] = None
|
||||
"""
|
||||
System code
|
||||
"""
|
||||
|
||||
prompt_name: Optional[str] = None
|
||||
"""
|
||||
The prompt name.
|
||||
"""
|
||||
sys_code: Optional[str] = Field(
|
||||
None,
|
||||
description="The system code.",
|
||||
examples=["dbgpt", "auth_manager", "data_platform"],
|
||||
)
|
||||
|
||||
|
||||
class ServerResponse(BaseModel):
|
||||
class ServerResponse(ServeRequest):
|
||||
"""Prompt response model"""
|
||||
|
||||
id: int = None
|
||||
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
|
||||
class Config:
|
||||
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"
|
||||
|
||||
chat_scene: str = None
|
||||
|
||||
"""sub_chat_scene: sub chat scene"""
|
||||
sub_chat_scene: str = None
|
||||
|
||||
"""prompt_type: common or private"""
|
||||
prompt_type: str = None
|
||||
|
||||
"""content: prompt content"""
|
||||
content: str = None
|
||||
|
||||
"""user_name: user name"""
|
||||
user_name: str = None
|
||||
|
||||
sys_code: Optional[str] = None
|
||||
"""
|
||||
System code
|
||||
"""
|
||||
|
||||
"""prompt_name: prompt name"""
|
||||
prompt_name: str = None
|
||||
gmt_created: str = None
|
||||
gmt_modified: str = None
|
||||
id: Optional[int] = Field(
|
||||
None,
|
||||
description="The prompt id.",
|
||||
examples=[1, 2, 3],
|
||||
)
|
||||
gmt_created: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt created time.",
|
||||
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
|
||||
)
|
||||
gmt_modified: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt modified time.",
|
||||
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
|
||||
)
|
||||
|
@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from dbgpt.serve.core import BaseServeConfig
|
||||
|
||||
@ -17,3 +18,15 @@ class ServeConfig(BaseServeConfig):
|
||||
"""Parameters for the serve command"""
|
||||
|
||||
# TODO: add your own parameters here
|
||||
api_keys: Optional[str] = field(
|
||||
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
|
||||
)
|
||||
|
||||
default_user: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Default user name for prompt"},
|
||||
)
|
||||
default_sys_code: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Default system code for prompt"},
|
||||
)
|
||||
|
@ -2,7 +2,13 @@ from typing import List, Optional
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
|
||||
from .api.endpoints import router, init_endpoints
|
||||
from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME
|
||||
from .config import (
|
||||
SERVE_APP_NAME,
|
||||
SERVE_APP_NAME_HUMP,
|
||||
APP_NAME,
|
||||
SERVE_CONFIG_KEY_PREFIX,
|
||||
ServeConfig,
|
||||
)
|
||||
|
||||
|
||||
class Serve(BaseComponent):
|
||||
|
@ -13,10 +13,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
|
||||
name = SERVE_SERVICE_COMPONENT_NAME
|
||||
|
||||
def __init__(self, system_app: SystemApp):
|
||||
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
|
||||
self._system_app = None
|
||||
self._serve_config: ServeConfig = None
|
||||
self._dao: ServeDao = None
|
||||
self._dao: ServeDao = dao
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp) -> None:
|
||||
@ -28,7 +28,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
self._dao = ServeDao(self._serve_config)
|
||||
self._dao = self._dao or ServeDao(self._serve_config)
|
||||
self._system_app = system_app
|
||||
|
||||
@property
|
||||
@ -41,6 +41,22 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""Returns the internal ServeConfig."""
|
||||
return self._serve_config
|
||||
|
||||
def create(self, request: ServeRequest) -> ServerResponse:
|
||||
"""Create a new Prompt entity
|
||||
|
||||
Args:
|
||||
request (ServeRequest): The request
|
||||
|
||||
Returns:
|
||||
ServerResponse: The response
|
||||
"""
|
||||
|
||||
if not request.user_name:
|
||||
request.user_name = self.config.default_user
|
||||
if not request.sys_code:
|
||||
request.sys_code = self.config.default_sys_code
|
||||
return super().create(request)
|
||||
|
||||
def update(self, request: ServeRequest) -> ServerResponse:
|
||||
"""Update a Prompt entity
|
||||
|
||||
|
0
dbgpt/serve/prompt/tests/__init__.py
Normal file
0
dbgpt/serve/prompt/tests/__init__.py
Normal file
176
dbgpt/serve/prompt/tests/test_endpoints.py
Normal file
176
dbgpt/serve/prompt/tests/test_endpoints.py
Normal file
@ -0,0 +1,176 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from fastapi import FastAPI
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.util import PaginationResult
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX
|
||||
from ..api.endpoints import router, init_endpoints
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
|
||||
from dbgpt.serve.core.tests.conftest import client, asystem_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def client_init_caller(app: FastAPI, system_app: SystemApp):
|
||||
app.include_router(router)
|
||||
init_endpoints(system_app)
|
||||
|
||||
|
||||
async def _create_and_validate(
|
||||
client: AsyncClient, sys_code: str, content: str, expect_id: int = 1, **kwargs
|
||||
):
|
||||
req_json = {"sys_code": sys_code, "content": content}
|
||||
req_json.update(kwargs)
|
||||
response = await client.post("/add", json=req_json)
|
||||
assert response.status_code == 200
|
||||
json_res = response.json()
|
||||
assert "success" in json_res and json_res["success"]
|
||||
assert "data" in json_res and json_res["data"]
|
||||
data = json_res["data"]
|
||||
res_obj = ServerResponse(**data)
|
||||
assert res_obj.id == expect_id
|
||||
assert res_obj.sys_code == sys_code
|
||||
assert res_obj.content == content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client, asystem_app, has_auth",
|
||||
[
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "test_token1",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "error_token",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
False,
|
||||
),
|
||||
],
|
||||
indirect=["client", "asystem_app"],
|
||||
)
|
||||
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
|
||||
response = await client.get("/test_auth")
|
||||
if has_auth:
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
else:
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {
|
||||
"detail": {
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_auth(client: AsyncClient):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_create(client: AsyncClient):
|
||||
await _create_and_validate(client, "test", "test")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_update(client: AsyncClient):
|
||||
await _create_and_validate(client, "test", "test")
|
||||
|
||||
response = await client.post("/update", json={"id": 1, "content": "test2"})
|
||||
assert response.status_code == 200
|
||||
json_res = response.json()
|
||||
assert "success" in json_res and json_res["success"]
|
||||
assert "data" in json_res and json_res["data"]
|
||||
data = json_res["data"]
|
||||
res_obj = ServerResponse(**data)
|
||||
assert res_obj.id == 1
|
||||
assert res_obj.sys_code == "test"
|
||||
assert res_obj.content == "test2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_query(client: AsyncClient):
|
||||
for i in range(10):
|
||||
await _create_and_validate(
|
||||
client, "test", f"test{i}", expect_id=i + 1, prompt_name=f"prompt_name_{i}"
|
||||
)
|
||||
response = await client.post("/list", json={"sys_code": "test"})
|
||||
assert response.status_code == 200
|
||||
json_res = response.json()
|
||||
assert "success" in json_res and json_res["success"]
|
||||
assert "data" in json_res and json_res["data"]
|
||||
data = json_res["data"]
|
||||
assert len(data) == 10
|
||||
res_obj = ServerResponse(**data[0])
|
||||
assert res_obj.id == 1
|
||||
assert res_obj.sys_code == "test"
|
||||
assert res_obj.content == "test0"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_query_by_page(client: AsyncClient):
|
||||
for i in range(10):
|
||||
await _create_and_validate(
|
||||
client, "test", f"test{i}", expect_id=i + 1, prompt_name=f"prompt_name_{i}"
|
||||
)
|
||||
response = await client.post(
|
||||
"/query_page", params={"page": 1, "page_size": 5}, json={"sys_code": "test"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
json_res = response.json()
|
||||
assert "success" in json_res and json_res["success"]
|
||||
assert "data" in json_res and json_res["data"]
|
||||
data = json_res["data"]
|
||||
page_result: PaginationResult = PaginationResult(**data)
|
||||
assert page_result.total_count == 10
|
||||
assert page_result.total_pages == 2
|
||||
assert page_result.page == 1
|
||||
assert page_result.page_size == 5
|
||||
assert len(page_result.items) == 5
|
257
dbgpt/serve/prompt/tests/test_models.py
Normal file
257
dbgpt/serve/prompt/tests/test_models.py
Normal file
@ -0,0 +1,257 @@
|
||||
from typing import List
|
||||
import pytest
|
||||
from dbgpt.storage.metadata import db
|
||||
from ..config import ServeConfig
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..models.models import ServeEntity, ServeDao
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_config():
|
||||
return ServeConfig()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dao(server_config):
|
||||
return ServeDao(server_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_entity_dict():
|
||||
return {
|
||||
"chat_scene": "chat_data",
|
||||
"sub_chat_scene": "excel",
|
||||
"prompt_type": "common",
|
||||
"prompt_name": "my_prompt_1",
|
||||
"content": "Write a qsort function in python.",
|
||||
"user_name": "zhangsan",
|
||||
"sys_code": "dbgpt",
|
||||
}
|
||||
|
||||
|
||||
def test_table_exist():
|
||||
assert ServeEntity.__tablename__ in db.metadata.tables
|
||||
|
||||
|
||||
def test_entity_create(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_1"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_entity_unique_key(default_entity_dict):
|
||||
ServeEntity.create(**default_entity_dict)
|
||||
with pytest.raises(Exception):
|
||||
ServeEntity.create(**{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
|
||||
|
||||
|
||||
def test_entity_get(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_1"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_entity_update(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
entity.update(prompt_name="my_prompt_2")
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_2"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_entity_delete(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
entity.delete()
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity is None
|
||||
|
||||
|
||||
def test_entity_all():
|
||||
for i in range(10):
|
||||
ServeEntity.create(
|
||||
chat_scene="chat_data",
|
||||
sub_chat_scene="excel",
|
||||
prompt_type="common",
|
||||
prompt_name=f"my_prompt_{i}",
|
||||
content="Write a qsort function in python.",
|
||||
user_name="zhangsan",
|
||||
sys_code="dbgpt",
|
||||
)
|
||||
entities = ServeEntity.all()
|
||||
assert len(entities) == 10
|
||||
for entity in entities:
|
||||
assert entity.chat_scene == "chat_data"
|
||||
assert entity.sub_chat_scene == "excel"
|
||||
assert entity.prompt_type == "common"
|
||||
assert entity.content == "Write a qsort function in python."
|
||||
assert entity.user_name == "zhangsan"
|
||||
assert entity.sys_code == "dbgpt"
|
||||
assert entity.gmt_created is not None
|
||||
assert entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_dao_create(dao, default_entity_dict):
|
||||
req = ServeRequest(**default_entity_dict)
|
||||
res: ServerResponse = dao.create(req)
|
||||
assert res is not None
|
||||
assert res.id == 1
|
||||
assert res.chat_scene == "chat_data"
|
||||
assert res.sub_chat_scene == "excel"
|
||||
assert res.prompt_type == "common"
|
||||
assert res.prompt_name == "my_prompt_1"
|
||||
assert res.content == "Write a qsort function in python."
|
||||
assert res.user_name == "zhangsan"
|
||||
assert res.sys_code == "dbgpt"
|
||||
|
||||
|
||||
def test_dao_get_one(dao, default_entity_dict):
|
||||
req = ServeRequest(**default_entity_dict)
|
||||
res: ServerResponse = dao.create(req)
|
||||
res: ServerResponse = dao.get_one(
|
||||
{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"}
|
||||
)
|
||||
assert res is not None
|
||||
assert res.id == 1
|
||||
assert res.chat_scene == "chat_data"
|
||||
assert res.sub_chat_scene == "excel"
|
||||
assert res.prompt_type == "common"
|
||||
assert res.prompt_name == "my_prompt_1"
|
||||
assert res.content == "Write a qsort function in python."
|
||||
assert res.user_name == "zhangsan"
|
||||
assert res.sys_code == "dbgpt"
|
||||
|
||||
|
||||
def test_get_dao_get_list(dao):
|
||||
for i in range(10):
|
||||
dao.create(
|
||||
ServeRequest(
|
||||
chat_scene="chat_data",
|
||||
sub_chat_scene="excel",
|
||||
prompt_type="common",
|
||||
prompt_name=f"my_prompt_{i}",
|
||||
content="Write a qsort function in python.",
|
||||
user_name="zhangsan" if i % 2 == 0 else "lisi",
|
||||
sys_code="dbgpt",
|
||||
)
|
||||
)
|
||||
res: List[ServerResponse] = dao.get_list({"sys_code": "dbgpt"})
|
||||
assert len(res) == 10
|
||||
for i, r in enumerate(res):
|
||||
assert r.id == i + 1
|
||||
assert r.chat_scene == "chat_data"
|
||||
assert r.sub_chat_scene == "excel"
|
||||
assert r.prompt_type == "common"
|
||||
assert r.prompt_name == f"my_prompt_{i}"
|
||||
assert r.content == "Write a qsort function in python."
|
||||
assert r.user_name == "zhangsan" if i % 2 == 0 else "lisi"
|
||||
assert r.sys_code == "dbgpt"
|
||||
|
||||
half_res: List[ServerResponse] = dao.get_list({"user_name": "zhangsan"})
|
||||
assert len(half_res) == 5
|
||||
|
||||
|
||||
def test_dao_update(dao, default_entity_dict):
|
||||
req = ServeRequest(**default_entity_dict)
|
||||
res: ServerResponse = dao.create(req)
|
||||
res: ServerResponse = dao.update(
|
||||
{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"},
|
||||
ServeRequest(prompt_name="my_prompt_2"),
|
||||
)
|
||||
assert res is not None
|
||||
assert res.id == 1
|
||||
assert res.chat_scene == "chat_data"
|
||||
assert res.sub_chat_scene == "excel"
|
||||
assert res.prompt_type == "common"
|
||||
assert res.prompt_name == "my_prompt_2"
|
||||
assert res.content == "Write a qsort function in python."
|
||||
assert res.user_name == "zhangsan"
|
||||
assert res.sys_code == "dbgpt"
|
||||
|
||||
|
||||
def test_dao_delete(dao, default_entity_dict):
|
||||
req = ServeRequest(**default_entity_dict)
|
||||
res: ServerResponse = dao.create(req)
|
||||
dao.delete({"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
|
||||
res: ServerResponse = dao.get_one(
|
||||
{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"}
|
||||
)
|
||||
assert res is None
|
||||
|
||||
|
||||
def test_dao_get_list_page(dao):
|
||||
for i in range(20):
|
||||
dao.create(
|
||||
ServeRequest(
|
||||
chat_scene="chat_data",
|
||||
sub_chat_scene="excel",
|
||||
prompt_type="common",
|
||||
prompt_name=f"my_prompt_{i}",
|
||||
content="Write a qsort function in python.",
|
||||
user_name="zhangsan" if i % 2 == 0 else "lisi",
|
||||
sys_code="dbgpt",
|
||||
)
|
||||
)
|
||||
res = dao.get_list_page({"sys_code": "dbgpt"}, page=1, page_size=8)
|
||||
assert res.total_count == 20
|
||||
assert res.total_pages == 3
|
||||
assert res.page == 1
|
||||
assert res.page_size == 8
|
||||
assert len(res.items) == 8
|
||||
for i, r in enumerate(res.items):
|
||||
assert r.id == i + 1
|
||||
assert r.chat_scene == "chat_data"
|
||||
assert r.sub_chat_scene == "excel"
|
||||
assert r.prompt_type == "common"
|
||||
assert r.prompt_name == f"my_prompt_{i}"
|
||||
assert r.content == "Write a qsort function in python."
|
||||
assert r.user_name == "zhangsan" if i % 2 == 0 else "lisi"
|
||||
assert r.sys_code == "dbgpt"
|
||||
|
||||
res_half = dao.get_list_page({"user_name": "zhangsan"}, page=2, page_size=8)
|
||||
assert res_half.total_count == 10
|
||||
assert res_half.total_pages == 2
|
||||
assert res_half.page == 2
|
||||
assert res_half.page_size == 8
|
||||
assert len(res_half.items) == 2
|
||||
for i, r in enumerate(res_half.items):
|
||||
assert r.chat_scene == "chat_data"
|
||||
assert r.sub_chat_scene == "excel"
|
||||
assert r.prompt_type == "common"
|
||||
assert r.content == "Write a qsort function in python."
|
||||
assert r.user_name == "zhangsan"
|
||||
assert r.sys_code == "dbgpt"
|
154
dbgpt/serve/prompt/tests/test_service.py
Normal file
154
dbgpt/serve/prompt/tests/test_service.py
Normal file
@ -0,0 +1,154 @@
|
||||
from typing import List
|
||||
import pytest
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.serve.core.tests.conftest import system_app
|
||||
|
||||
from ..models.models import ServeEntity
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..service.service import Service
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(system_app: SystemApp):
|
||||
instance = Service(system_app)
|
||||
instance.init_app(system_app)
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_entity_dict():
|
||||
return {
|
||||
"chat_scene": "chat_data",
|
||||
"sub_chat_scene": "excel",
|
||||
"prompt_type": "common",
|
||||
"prompt_name": "my_prompt_1",
|
||||
"content": "Write a qsort function in python.",
|
||||
"user_name": "zhangsan",
|
||||
"sys_code": "dbgpt",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"system_app",
|
||||
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_config_exists(service: Service):
|
||||
system_app: SystemApp = service._system_app
|
||||
assert system_app.config.get("DEBUG") is True
|
||||
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
|
||||
assert service.config is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"system_app",
|
||||
[
|
||||
{
|
||||
"app_config": {
|
||||
"DEBUG": True,
|
||||
"dbgpt.serve.prompt.default_user": "dbgpt",
|
||||
"dbgpt.serve.prompt.default_sys_code": "dbgpt",
|
||||
}
|
||||
}
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_config_default_user(service: Service):
|
||||
system_app: SystemApp = service._system_app
|
||||
assert system_app.config.get("DEBUG") is True
|
||||
assert system_app.config.get("dbgpt.serve.prompt.default_user") == "dbgpt"
|
||||
assert service.config is not None
|
||||
assert service.config.default_user == "dbgpt"
|
||||
assert service.config.default_sys_code == "dbgpt"
|
||||
|
||||
|
||||
def test_service_create(service: Service, default_entity_dict):
|
||||
entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_1"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_service_update(service: Service, default_entity_dict):
|
||||
service.create(ServeRequest(**default_entity_dict))
|
||||
entity: ServerResponse = service.update(ServeRequest(**default_entity_dict))
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_1"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_service_get(service: Service, default_entity_dict):
|
||||
service.create(ServeRequest(**default_entity_dict))
|
||||
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_1"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_service_delete(service: Service, default_entity_dict):
|
||||
service.create(ServeRequest(**default_entity_dict))
|
||||
service.delete(ServeRequest(**default_entity_dict))
|
||||
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
|
||||
assert entity is None
|
||||
|
||||
|
||||
def test_service_get_list(service: Service):
|
||||
for i in range(3):
|
||||
service.create(
|
||||
ServeRequest(**{"prompt_name": f"prompt_{i}", "sys_code": "dbgpt"})
|
||||
)
|
||||
entities: List[ServerResponse] = service.get_list(ServeRequest(sys_code="dbgpt"))
|
||||
assert len(entities) == 3
|
||||
for i, entity in enumerate(entities):
|
||||
assert entity.sys_code == "dbgpt"
|
||||
assert entity.prompt_name == f"prompt_{i}"
|
||||
|
||||
|
||||
def test_service_get_list_by_page(service: Service):
|
||||
for i in range(3):
|
||||
service.create(
|
||||
ServeRequest(**{"prompt_name": f"prompt_{i}", "sys_code": "dbgpt"})
|
||||
)
|
||||
res = service.get_list_by_page(ServeRequest(sys_code="dbgpt"), page=1, page_size=2)
|
||||
assert res is not None
|
||||
assert res.total_count == 3
|
||||
assert res.total_pages == 2
|
||||
assert len(res.items) == 2
|
||||
for i, entity in enumerate(res.items):
|
||||
assert entity.sys_code == "dbgpt"
|
||||
assert entity.prompt_name == f"prompt_{i}"
|
@ -1,5 +1,8 @@
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from functools import cache
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import Result
|
||||
@ -20,13 +23,78 @@ def get_service() -> Service:
|
||||
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
|
||||
|
||||
|
||||
get_bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
@cache
|
||||
def _parse_api_keys(api_keys: str) -> List[str]:
|
||||
"""Parse the string api keys to a list
|
||||
|
||||
Args:
|
||||
api_keys (str): The string api keys
|
||||
|
||||
Returns:
|
||||
List[str]: The list of api keys
|
||||
"""
|
||||
if not api_keys:
|
||||
return []
|
||||
return [key.strip() for key in api_keys.split(",")]
|
||||
|
||||
|
||||
async def check_api_key(
|
||||
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Optional[str]:
|
||||
"""Check the api key
|
||||
|
||||
If the api key is not set, allow all.
|
||||
|
||||
Your can pass the token in you request header like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import requests
|
||||
client_api_key = "your_api_key"
|
||||
headers = {"Authorization": "Bearer " + client_api_key }
|
||||
res = requests.get("http://test/hello", headers=headers)
|
||||
assert res.status_code == 200
|
||||
|
||||
"""
|
||||
if service.config.api_keys:
|
||||
api_keys = _parse_api_keys(service.config.api_keys)
|
||||
if auth is None or (token := auth.credentials) not in api_keys:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
},
|
||||
)
|
||||
return token
|
||||
else:
|
||||
# api_keys not set; allow all
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/", response_model=Result[ServerResponse])
|
||||
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
|
||||
async def test_auth():
|
||||
"""Test auth endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def create(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
@ -41,7 +109,9 @@ async def create(
|
||||
return Result.succ(service.create(request))
|
||||
|
||||
|
||||
@router.put("/", response_model=Result[ServerResponse])
|
||||
@router.put(
|
||||
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def update(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
@ -56,7 +126,11 @@ async def update(
|
||||
return Result.succ(service.update(request))
|
||||
|
||||
|
||||
@router.post("/query", response_model=Result[ServerResponse])
|
||||
@router.post(
|
||||
"/query",
|
||||
response_model=Result[ServerResponse],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
@ -71,7 +145,11 @@ async def query(
|
||||
return Result.succ(service.get(request))
|
||||
|
||||
|
||||
@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]])
|
||||
@router.post(
|
||||
"/query_page",
|
||||
response_model=Result[PaginationResult[ServerResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query_page(
|
||||
request: ServeRequest,
|
||||
page: Optional[int] = Query(default=1, description="current page"),
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Define your Pydantic schemas here
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from ..config import SERVE_APP_NAME_HUMP
|
||||
|
||||
|
||||
class ServeRequest(BaseModel):
|
||||
@ -7,8 +8,13 @@ class ServeRequest(BaseModel):
|
||||
|
||||
# TODO define your own fields here
|
||||
|
||||
class Config:
|
||||
title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
|
||||
|
||||
|
||||
class ServerResponse(BaseModel):
|
||||
"""{__template_app_name__hump__} response model"""
|
||||
|
||||
# TODO define your own fields here
|
||||
class Config:
|
||||
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"
|
||||
|
@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from dbgpt.serve.core import BaseServeConfig
|
||||
|
||||
@ -17,3 +18,6 @@ class ServeConfig(BaseServeConfig):
|
||||
"""Parameters for the serve command"""
|
||||
|
||||
# TODO: add your own parameters here
|
||||
api_keys: Optional[str] = field(
|
||||
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
|
||||
)
|
||||
|
@ -2,7 +2,13 @@ from typing import List, Optional
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
|
||||
from .api.endpoints import router, init_endpoints
|
||||
from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME
|
||||
from .config import (
|
||||
SERVE_APP_NAME,
|
||||
SERVE_APP_NAME_HUMP,
|
||||
APP_NAME,
|
||||
SERVE_CONFIG_KEY_PREFIX,
|
||||
ServeConfig,
|
||||
)
|
||||
|
||||
|
||||
class Serve(BaseComponent):
|
||||
|
@ -13,10 +13,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
|
||||
name = SERVE_SERVICE_COMPONENT_NAME
|
||||
|
||||
def __init__(self, system_app: SystemApp):
|
||||
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
|
||||
self._system_app = None
|
||||
self._serve_config: ServeConfig = None
|
||||
self._dao: ServeDao = None
|
||||
self._dao: ServeDao = dao
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp) -> None:
|
||||
@ -28,7 +28,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
self._dao = ServeDao(self._serve_config)
|
||||
self._dao = self._dao or ServeDao(self._serve_config)
|
||||
self._system_app = system_app
|
||||
|
||||
@property
|
||||
|
@ -0,0 +1,124 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from fastapi import FastAPI
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.util import PaginationResult
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX
|
||||
from ..api.endpoints import router, init_endpoints
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
|
||||
from dbgpt.serve.core.tests.conftest import client, asystem_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def client_init_caller(app: FastAPI, system_app: SystemApp):
|
||||
app.include_router(router)
|
||||
init_endpoints(system_app)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client, asystem_app, has_auth",
|
||||
[
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "test_token1",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "error_token",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
False,
|
||||
),
|
||||
],
|
||||
indirect=["client", "asystem_app"],
|
||||
)
|
||||
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
|
||||
response = await client.get("/test_auth")
|
||||
if has_auth:
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
else:
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {
|
||||
"detail": {
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_health(client: AsyncClient):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_create(client: AsyncClient):
|
||||
# TODO: add your test case
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_update(client: AsyncClient):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_query(client: AsyncClient):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_query_by_page(client: AsyncClient):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
# Add more test cases according to your own logic
|
@ -0,0 +1,109 @@
|
||||
from typing import List
|
||||
import pytest
|
||||
from dbgpt.storage.metadata import db
|
||||
from ..config import ServeConfig
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..models.models import ServeEntity, ServeDao
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_config():
|
||||
# TODO : build your server config
|
||||
return ServeConfig()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dao(server_config):
|
||||
return ServeDao(server_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_entity_dict():
|
||||
# TODO: build your default entity dict
|
||||
return {}
|
||||
|
||||
|
||||
def test_table_exist():
|
||||
assert ServeEntity.__tablename__ in db.metadata.tables
|
||||
|
||||
|
||||
def test_entity_create(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
# TODO: implement your test case
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
|
||||
|
||||
def test_entity_unique_key(default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_entity_get(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
# TODO: implement your test case
|
||||
|
||||
|
||||
def test_entity_update(default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_entity_delete(default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
entity.delete()
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity is None
|
||||
|
||||
|
||||
def test_entity_all():
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_dao_create(dao, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
req = ServeRequest(**default_entity_dict)
|
||||
res: ServerResponse = dao.create(req)
|
||||
assert res is not None
|
||||
|
||||
|
||||
def test_dao_get_one(dao, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
req = ServeRequest(**default_entity_dict)
|
||||
res: ServerResponse = dao.create(req)
|
||||
|
||||
|
||||
def test_get_dao_get_list(dao):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_dao_update(dao, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_dao_delete(dao, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_dao_get_list_page(dao):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
# Add more test cases according to your own logic
|
@ -0,0 +1,76 @@
|
||||
from typing import List
|
||||
import pytest
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.serve.core.tests.conftest import system_app
|
||||
|
||||
from ..models.models import ServeEntity
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..service.service import Service
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(system_app: SystemApp):
|
||||
instance = Service(system_app)
|
||||
instance.init_app(system_app)
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_entity_dict():
|
||||
# TODO: build your default entity dict
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"system_app",
|
||||
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_config_exists(service: Service):
|
||||
system_app: SystemApp = service._system_app
|
||||
assert system_app.config.get("DEBUG") is True
|
||||
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
|
||||
assert service.config is not None
|
||||
|
||||
|
||||
def test_service_create(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
|
||||
# ...
|
||||
pass
|
||||
|
||||
|
||||
def test_service_update(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_get(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_delete(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_get_list(service: Service):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_get_list_by_page(service: Service):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
# Add more test cases according to your own logic
|
@ -62,6 +62,7 @@ def replace_template_variables(content: str, app_name: str):
|
||||
|
||||
def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
|
||||
for root, dirs, files in os.walk(src_dir):
|
||||
dirs[:] = [d for d in dirs if not _should_ignore(d)]
|
||||
relative_path = os.path.relpath(root, src_dir)
|
||||
if relative_path == ".":
|
||||
relative_path = ""
|
||||
@ -70,6 +71,8 @@ def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
for file in files:
|
||||
if _should_ignore(file):
|
||||
continue
|
||||
try:
|
||||
with open(os.path.join(root, file), "r") as f:
|
||||
content = f.read()
|
||||
@ -81,3 +84,9 @@ def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
|
||||
except Exception as e:
|
||||
click.echo(f"Error copying file {file} from {src_dir} to {dst_dir}")
|
||||
raise e
|
||||
|
||||
|
||||
def _should_ignore(file_or_dir: str):
|
||||
"""Return True if the given file or directory should be ignored.""" ""
|
||||
ignore_patterns = [".pyc", "__pycache__"]
|
||||
return any(pattern in file_or_dir for pattern in ignore_patterns)
|
||||
|
@ -153,7 +153,8 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
if entry is None:
|
||||
raise Exception("Invalid request")
|
||||
for key, value in update_request.dict().items():
|
||||
setattr(entry, key, value)
|
||||
if value is not None:
|
||||
setattr(entry, key, value)
|
||||
session.merge(entry)
|
||||
return self.get_one(self.to_request(entry))
|
||||
|
||||
|
@ -104,6 +104,28 @@ def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_
|
||||
assert user.age == 35
|
||||
|
||||
|
||||
def test_update_user_partial(
|
||||
db: DatabaseManager, User: Type[BaseModel], user_dao, user_req
|
||||
):
|
||||
# Create a user
|
||||
created_user_response = user_dao.create(user_req)
|
||||
|
||||
# Update the user
|
||||
updated_req = UserRequest(name=user_req.name, password="newpassword")
|
||||
updated_req.age = None
|
||||
updated_user = user_dao.update(
|
||||
query_request={"name": user_req.name}, update_request=updated_req
|
||||
)
|
||||
assert updated_user.id == created_user_response.id
|
||||
assert updated_user.age == user_req.age
|
||||
|
||||
# Verify that the user is updated in the database
|
||||
with db.session() as session:
|
||||
user = session.query(User).get(created_user_response.id)
|
||||
assert user.age == user_req.age
|
||||
assert user.password == "newpassword"
|
||||
|
||||
|
||||
def test_get_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
|
||||
# Create a user
|
||||
created_user_response = user_dao.create(user_req)
|
||||
|
@ -3,15 +3,18 @@ from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class AppConfig:
|
||||
def __init__(self):
|
||||
self.configs = {}
|
||||
def __init__(self, configs: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.configs = configs or {}
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
def set(self, key: str, value: Any, overwrite: bool = False) -> None:
|
||||
"""Set config value by key
|
||||
Args:
|
||||
key (str): The key of config
|
||||
value (Any): The value of config
|
||||
overwrite (bool, optional): Whether to overwrite the value if key exists. Defaults to False.
|
||||
"""
|
||||
if key in self.configs and not overwrite:
|
||||
raise KeyError(f"Config key {key} already exists")
|
||||
self.configs[key] = value
|
||||
|
||||
def get(self, key, default: Optional[Any] = None) -> Any:
|
||||
|
Loading…
Reference in New Issue
Block a user