feat(feedback): feedback upgrade

This commit is contained in:
yhjun1026
2024-08-15 10:13:41 +08:00
parent 0173ed58d4
commit a1df697cbf
19 changed files with 976 additions and 21 deletions

View File

@@ -36,6 +36,14 @@ ALTER TABLE prompt_manage ADD COLUMN `prompt_code` varchar(255) NULL COMMENT 'P
ALTER TABLE prompt_manage ADD COLUMN `response_schema` text NULL COMMENT 'Prompt response schema';
ALTER TABLE prompt_manage ADD COLUMN `user_code` varchar(128) NULL COMMENT 'User code';
--chat_feed_back
ALTER TABLE chat_feed_back ADD COLUMN `message_id` varchar(255) NULL COMMENT 'Message id';
ALTER TABLE chat_feed_back ADD COLUMN `feedback_type` varchar(50) NULL COMMENT 'Feedback type like or unlike';
ALTER TABLE chat_feed_back ADD COLUMN `reason_types` varchar(255) NULL COMMENT 'Feedback reason categories';
ALTER TABLE chat_feed_back ADD COLUMN `user_code` varchar(128) NULL COMMENT 'User code';
ALTER TABLE chat_feed_back ADD COLUMN `remark` text NULL COMMENT 'Feedback remark';
-- dbgpt.recommend_question definition
CREATE TABLE `recommend_question` (
`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',

View File

@@ -125,6 +125,11 @@ CREATE TABLE IF NOT EXISTS `chat_feed_back`
`question` longtext DEFAULT NULL COMMENT 'User question',
`knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name',
`messages` longtext DEFAULT NULL COMMENT 'The details of user feedback',
`message_id` varchar(255) NULL COMMENT 'Message id',
`feedback_type` varchar(50) NULL COMMENT 'Feedback type like or unlike',
`reason_types` varchar(255) NULL COMMENT 'Feedback reason categories',
`remark` text NULL COMMENT 'Feedback remark',
`user_code` varchar(128) NULL COMMENT 'User code',
`user_name` varchar(128) DEFAULT NULL COMMENT 'User name',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',

View File

@@ -65,3 +65,13 @@ def register_serve_apps(system_app: SystemApp, cfg: Config):
# Register serve app
system_app.register(DatasourceServe)
# ################################ AWEL Flow Serve Register End ########################################
# ################################ Chat Feedback Serve Register End ########################################
from dbgpt.serve.feedback.serve import (
SERVE_CONFIG_KEY_PREFIX as Feedback_SERVE_CONFIG_KEY_PREFIX,
)
from dbgpt.serve.feedback.serve import Serve as FeedbackServe
# Register serve feedback
system_app.register(FeedbackServe)
# ################################ Chat Feedback Register End ########################################

View File

@@ -3,29 +3,10 @@ from datetime import datetime
from sqlalchemy import Column, DateTime, Integer, String, Text
from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody
from dbgpt.serve.feedback.models.models import ServeEntity
from dbgpt.storage.metadata import BaseDao, Model
class ChatFeedBackEntity(Model):
__tablename__ = "chat_feed_back"
id = Column(Integer, primary_key=True)
conv_uid = Column(String(128))
conv_index = Column(Integer)
score = Column(Integer)
ques_type = Column(String(32))
question = Column(Text)
knowledge_space = Column(String(128))
messages = Column(Text)
user_name = Column(String(128))
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
def __repr__(self):
return (
f"ChatFeekBackEntity(id={self.id}, conv_index='{self.conv_index}', conv_index='{self.conv_index}', "
f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', "
f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
)
ChatFeedBackEntity = ServeEntity
class ChatFeedBackDao(BaseDao):

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve feedback`

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve feedback`

View File

@@ -0,0 +1,198 @@
import logging
from functools import cache
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import ConvFeedbackReasonType, ServeRequest, ServerResponse
router = APIRouter()
# Add your API endpoints here
global_system_app: Optional[SystemApp] = None
def get_service() -> Service:
"""Get the service instance"""
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
logger = logging.getLogger(__name__)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
Your can pass the token in you request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key}
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
@router.post(
"/query",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Query Feedback entities
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.get(request))
@router.post(
"/query_page",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),
page_size: Optional[int] = Query(default=20, description="page size"),
service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
"""Query Feedback entities
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.get_list_by_page(request, page, page_size))
@router.post("/add")
async def add_feedback(request: ServeRequest, service: Service = Depends(get_service)):
try:
return Result.succ(service.create_or_update(request))
except Exception as ex:
logger.exception("Create feedback error!")
return Result.failed(err_code="E000X", msg=f"create feedback error: {ex}")
@router.get("/list")
async def list_feedback(
conv_uid: Optional[str] = None,
feedback_type: Optional[str] = None,
service: Service = Depends(get_service),
):
try:
return Result.succ(
service.list_conv_feedbacks(conv_uid=conv_uid, feedback_type=feedback_type)
)
except Exception as ex:
return Result.failed(err_code="E000X", msg=f"query questions error: {ex}")
@router.get("/reasons")
async def feedback_reasons():
reasons = []
for reason_type in ConvFeedbackReasonType:
reasons.append(ConvFeedbackReasonType.to_dict(reason_type))
return Result.succ(reasons)
@router.post("/cancel")
async def cancel_feedback(
request: ServeRequest, service: Service = Depends(get_service)
):
try:
service.cancel_feedback(request)
return Result.succ([])
except Exception as ex:
return Result.failed(err_code="E000X", msg=f"cancel_feedback error: {ex}")
@router.post("/update")
async def update_feedback(
request: ServeRequest, service: Service = Depends(get_service)
):
try:
return Result.succ(service.create_or_update(request))
except Exception as ex:
return Result.failed(err_code="E000X", msg=f"update question error: {ex}")
def init_endpoints(system_app: SystemApp) -> None:
"""Initialize the endpoints"""
global global_system_app
system_app.register(Service)
global_system_app = system_app

View File

@@ -0,0 +1,65 @@
# Define your Pydantic schemas here
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
class ConvFeedbackReasonType(Enum):
WRONG_ANSWER = "答案有误"
WRONG_SOURCE = "来源有误"
OUTDATED_CONTENT = "内容陈旧"
UNREAL_CONTENT = "非真实数据"
ILLEGAL_CONTENT = "含色情/违法/有害信"
OTHERS = "其他"
@classmethod
def to_dict(cls, reason_type):
return {
"reason_type": reason_type.name,
"reason": reason_type.value,
}
@classmethod
def of_type(cls, type_name: str):
for name, member in cls.__members__.items():
if name == type_name:
return member
raise ValueError(f"{type_name} is not a valid ConvFeedbackReasonType")
class ServeRequest(BaseModel):
"""Feedback request model"""
id: Optional[int] = Field(None, description="Primary Key")
gmt_created: Optional[str] = Field(None, description="Creation time")
gmt_modified: Optional[str] = Field(None, description="Modification time")
user_code: Optional[str] = Field(None, description="User ID")
user_name: Optional[str] = Field(None, description="User Name")
conv_uid: Optional[str] = Field(None, description="Conversation ID")
message_id: Optional[str] = Field(
None, description="Message ID, round_index for table chat_history_message"
)
score: Optional[float] = Field(None, description="Rating of answers")
question: Optional[str] = Field(None, description="User question")
ques_type: Optional[str] = Field(None, description="User question type")
knowledge_space: Optional[str] = Field(None, description="Use resource")
feedback_type: Optional[str] = Field(
None, description="Feedback type like or unlike"
)
reason_type: Optional[str] = Field(None, description="Feedback reason category")
remark: Optional[str] = Field(None, description="Remarks")
reason_types: Optional[List[str]] = Field(
default=[], description="Feedback reason categories"
)
reason: Optional[List[Dict]] = Field(
default=[], description="Feedback reason category"
)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dict."""
return model_to_dict(self)
ServerResponse = ServeRequest

View File

@@ -0,0 +1,22 @@
from dataclasses import dataclass, field
from typing import Optional
from dbgpt.serve.core import BaseServeConfig
APP_NAME = "feedback"
SERVE_APP_NAME = "dbgpt_serve_feedback"
SERVE_APP_NAME_HUMP = "dbgpt_serve_Feedback"
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.feedback."
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
# Database table name
SERVER_APP_TABLE_NAME = "dbgpt_serve_feedback"
@dataclass
class ServeConfig(BaseServeConfig):
"""Parameters for the serve command"""
# TODO: add your own parameters here
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)

View File

@@ -0,0 +1 @@
# Define your dependencies here

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve feedback`

View File

@@ -0,0 +1,144 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
from datetime import datetime
from typing import Any, Dict, Union
from sqlalchemy import Column, DateTime, Index, Integer, String, Text
from dbgpt.storage.metadata import BaseDao, Model, db
from ..api.schemas import ConvFeedbackReasonType, ServeRequest, ServerResponse
from ..config import ServeConfig
class ServeEntity(Model):
__tablename__ = "chat_feed_back"
id = Column(Integer, primary_key=True)
conv_uid = Column(String(128))
conv_index = Column(Integer)
score = Column(Integer)
ques_type = Column(String(32))
question = Column(Text)
knowledge_space = Column(String(128))
messages = Column(Text)
remark = Column(Text, nullable=True, comment="feedback remark")
message_id = Column(String(255), nullable=True, comment="Message ID")
feedback_type = Column(
String(31), nullable=True, comment="Feedback type like or unlike"
)
reason_types = Column(
String(255), nullable=True, comment="Feedback reason categories"
)
user_code = Column(String(255), nullable=True, comment="User ID")
user_name = Column(String(128))
gmt_created = Column(DateTime, default=datetime.utcnow, comment="Creation time")
gmt_modified = Column(
DateTime,
default=datetime.utcnow,
onupdate=datetime.utcnow,
comment="Modification time",
)
__table_args__ = (
Index("idx_conv_uid", "conv_uid"),
Index("idx_gmt_create", "gmt_created"),
)
def __repr__(self):
return (
f"ChatFeekBackEntity(id={self.id}, conv_index='{self.conv_index}', conv_index='{self.conv_index}', "
f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', "
f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
)
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"""The DAO class for Feedback"""
def __init__(self, serve_config: ServeConfig):
super().__init__()
self._serve_config = serve_config
def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
"""Convert the request to an entity
Args:
request (Union[ServeRequest, Dict[str, Any]]): The request
Returns:
T: The entity
"""
new_dict = {
"conv_uid": request.conv_uid,
"message_id": request.message_id,
"reason_types": ",".join(request.reason_types),
"remark": request.remark,
"score": request.score,
"ques_type": request.ques_type,
"question": request.question,
"knowledge_space": request.knowledge_space,
"feedback_type": request.feedback_type,
"user_code": request.user_code,
"user_name": request.user_name,
}
entity = ServeEntity(**new_dict)
return entity
def to_request(self, entity: ServeEntity) -> ServeRequest:
"""Convert the entity to a request
Args:
entity (T): The entity
Returns:
REQ: The request
"""
reason_types = []
if entity.reason_types:
reason_types = entity.reason_types.split(",")
reason = []
if len(reason_types) > 0:
reason = [
ConvFeedbackReasonType.to_dict(ConvFeedbackReasonType.of_type(t))
for t in reason_types
]
gmt_created_str = (
entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
if entity.gmt_created
else None
)
gmt_modified_str = (
entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
if entity.gmt_modified
else None
)
return ServeRequest(
**{
"id": entity.id,
"user_code": entity.user_code,
"conv_uid": entity.conv_uid,
"message_id": entity.message_id,
"question": entity.question,
"knowledge_space": entity.knowledge_space,
"feedback_type": entity.feedback_type,
"remark": entity.remark,
"reason_types": reason_types,
"reason": reason,
"gmt_created": gmt_created_str,
"gmt_modified": gmt_modified_str,
}
)
def to_response(self, entity: ServeEntity) -> ServerResponse:
"""Convert the entity to a response
Args:
entity (T): The entity
Returns:
RES: The response
"""
return self.to_request(entity)

View File

@@ -0,0 +1,63 @@
import logging
from typing import List, Optional, Union
from sqlalchemy import URL
from dbgpt.component import SystemApp
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
logger = logging.getLogger(__name__)
class Serve(BaseServe):
"""Serve component for DB-GPT"""
name = SERVE_APP_NAME
def __init__(
self,
system_app: SystemApp,
api_prefix: Optional[str] = f"/api/v1/conv/{APP_NAME}",
api_tags: Optional[List[str]] = None,
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):
if api_tags is None:
api_tags = [SERVE_APP_NAME_HUMP]
super().__init__(
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
)
self._db_manager: Optional[DatabaseManager] = None
def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
return
self._system_app = system_app
self._system_app.app.include_router(
router, prefix=self._api_prefix, tags=self._api_tags
)
init_endpoints(self._system_app)
self._app_has_initiated = True
def on_init(self):
"""Called when init the application.
You can do some initialization here. You can't get other components here because they may be not initialized yet
"""
# import your own module here to ensure the module is loaded before the application starts
from .models.models import ServeEntity
def before_start(self):
"""Called before the start of the application."""
# TODO: Your code here
self._db_manager = self.create_or_get_db_manager()

View File

View File

@@ -0,0 +1,147 @@
from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""The service class for Feedback"""
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = dao
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
"""Initialize the service
Args:
system_app (SystemApp): The system app
"""
super().init_app(system_app)
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
"""Returns the internal DAO."""
return self._dao
@property
def config(self) -> ServeConfig:
"""Returns the internal ServeConfig."""
return self._serve_config
def get(self, request: ServeRequest) -> Optional[ServerResponse]:
"""Get a Feedback entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
# Build the query request from the request
query_request = request
return self.dao.get_one(query_request)
def delete(self, request: ServeRequest) -> None:
"""Delete a Feedback entity
Args:
request (ServeRequest): The request
"""
query_request = {"id": request.id}
self.dao.delete(query_request)
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
"""Get a list of Feedback entities
Args:
request (ServeRequest): The request
Returns:
List[ServerResponse]: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = request
return self.dao.get_list(query_request)
def get_list_by_page(
self, request: ServeRequest, page: int, page_size: int
) -> PaginationResult[ServerResponse]:
"""Get a list of Feedback entities by page
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
Returns:
List[ServerResponse]: The response
"""
query_request = request
return self.dao.get_list_page(query_request, page, page_size)
def list_conv_feedbacks(
self,
conv_uid: Optional[str] = None,
feedback_type: Optional[str] = None,
) -> List[ServerResponse]:
feedbacks = self.dao.get_list(
ServeRequest(conv_uid=conv_uid, feedback_type=feedback_type)
)
return [ServerResponse.from_entity(item) for item in feedbacks]
def create_or_update(self, request: ServeRequest) -> ServerResponse:
"""
First check whether the current user has likes, and if so, check whether it's consistent with the previous likes
If it is inconsistent, delete the previous likes and create a new like;
if it is consistent, an error will be reported and the likes already exist. Please do not like repeatedly
"""
feedbacks = self.dao.get_list(
ServeRequest(
conv_uid=request.conv_uid,
message_id=request.message_id,
user_code=request.user_code,
)
)
if len(feedbacks) > 1:
raise Exception(f"current conversation message has more than one feedback.")
if len(feedbacks) == 1:
fb = feedbacks[0]
if fb.feedback_type == request.feedback_type:
raise Exception(f"Please do not repeat feedback")
self.dao.delete(ServeRequest(id=fb.id))
return self.dao.create(request)
def cancel_feedback(self, request: ServeRequest) -> None:
if not (request.conv_uid and request.message_id):
raise Exception(f"cancel feedback参数缺失异常.")
self.dao.delete(
ServeRequest(
conv_uid=request.conv_uid,
message_id=request.message_id,
)
)
def delete_feedback(self, feedback_id: int) -> None:
self.dao.delete(ServeRequest(id=feedback_id))

View File

View File

@@ -0,0 +1,124 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import asystem_app, client
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..api.endpoints import init_endpoints, router
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
def client_init_caller(app: FastAPI, system_app: SystemApp):
app.include_router(router)
init_endpoints(system_app)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, asystem_app, has_auth",
[
(
{
"app_caller": client_init_caller,
"client_api_key": "test_token1",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
True,
),
(
{
"app_caller": client_init_caller,
"client_api_key": "error_token",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
False,
),
],
indirect=["client", "asystem_app"],
)
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
response = await client.get("/test_auth")
if has_auth:
assert response.status_code == 200
assert response.json() == {"status": "ok"}
else:
assert response.status_code == 401
assert response.json() == {
"detail": {
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
}
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_health(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
# TODO: add your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,103 @@
import pytest
from dbgpt.storage.metadata import db
from ..api.schemas import ServeRequest, ServerResponse
from ..config import ServeConfig
from ..models.models import ServeDao, ServeEntity
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def server_config():
# TODO : build your server config
return ServeConfig()
@pytest.fixture
def dao(server_config):
return ServeDao(server_config)
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
def test_table_exist():
assert ServeEntity.__tablename__ in db.metadata.tables
def test_entity_create(default_entity_dict):
with db.session() as session:
entity = ServeEntity(**default_entity_dict)
session.add(entity)
def test_entity_unique_key(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_get(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_update(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_delete(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_all():
# TODO: implement your test case
pass
def test_dao_create(dao, default_entity_dict):
# TODO: implement your test case
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
assert res is not None
def test_dao_get_one(dao, default_entity_dict):
# TODO: implement your test case
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
def test_get_dao_get_list(dao):
# TODO: implement your test case
pass
def test_dao_update(dao, default_entity_dict):
# TODO: implement your test case
pass
def test_dao_delete(dao, default_entity_dict):
# TODO: implement your test case
pass
def test_dao_get_list_page(dao):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,78 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.serve.core.tests.conftest import system_app
from dbgpt.storage.metadata import db
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity
from ..service.service import Service
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def service(system_app: SystemApp):
instance = Service(system_app)
instance.init_app(system_app)
return instance
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
@pytest.mark.parametrize(
"system_app",
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
indirect=True,
)
def test_config_exists(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
assert service.config is not None
def test_service_create(service: Service, default_entity_dict):
# TODO: implement your test case
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
# ...
pass
def test_service_update(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_delete(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get_list(service: Service):
# TODO: implement your test case
pass
def test_service_get_list_by_page(service: Service):
# TODO: implement your test case
pass
# Add more test cases according to your own logic