From a1df697cbf172b0445434b6b4a40033ddd4be0a6 Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Thu, 15 Aug 2024 10:13:41 +0800 Subject: [PATCH] feat(feedback): feedback upgrade --- .../upgrade/v0_6_0/upgrade_to_v0.6.0.sql | 8 + assets/schema/upgrade/v0_6_0/v0.6.0.sql | 5 + .../initialization/serve_initialization.py | 10 + .../openapi/api_v1/feedback/feed_back_db.py | 23 +- dbgpt/serve/feedback/__init__.py | 2 + dbgpt/serve/feedback/api/__init__.py | 2 + dbgpt/serve/feedback/api/endpoints.py | 198 ++++++++++++++++++ dbgpt/serve/feedback/api/schemas.py | 65 ++++++ dbgpt/serve/feedback/config.py | 22 ++ dbgpt/serve/feedback/dependencies.py | 1 + dbgpt/serve/feedback/models/__init__.py | 2 + dbgpt/serve/feedback/models/models.py | 144 +++++++++++++ dbgpt/serve/feedback/serve.py | 63 ++++++ dbgpt/serve/feedback/service/__init__.py | 0 dbgpt/serve/feedback/service/service.py | 147 +++++++++++++ dbgpt/serve/feedback/tests/__init__.py | 0 dbgpt/serve/feedback/tests/test_endpoints.py | 124 +++++++++++ dbgpt/serve/feedback/tests/test_models.py | 103 +++++++++ dbgpt/serve/feedback/tests/test_service.py | 78 +++++++ 19 files changed, 976 insertions(+), 21 deletions(-) create mode 100644 dbgpt/serve/feedback/__init__.py create mode 100644 dbgpt/serve/feedback/api/__init__.py create mode 100644 dbgpt/serve/feedback/api/endpoints.py create mode 100644 dbgpt/serve/feedback/api/schemas.py create mode 100644 dbgpt/serve/feedback/config.py create mode 100644 dbgpt/serve/feedback/dependencies.py create mode 100644 dbgpt/serve/feedback/models/__init__.py create mode 100644 dbgpt/serve/feedback/models/models.py create mode 100644 dbgpt/serve/feedback/serve.py create mode 100644 dbgpt/serve/feedback/service/__init__.py create mode 100644 dbgpt/serve/feedback/service/service.py create mode 100644 dbgpt/serve/feedback/tests/__init__.py create mode 100644 dbgpt/serve/feedback/tests/test_endpoints.py create mode 100644 dbgpt/serve/feedback/tests/test_models.py create mode 100644 dbgpt/serve/feedback/tests/test_service.py diff --git a/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql b/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql index 4fc6c963a..4f75fe7ba 100644 --- a/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql +++ b/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql @@ -36,6 +36,14 @@ ALTER TABLE prompt_manage ADD COLUMN `prompt_code` varchar(255) NULL COMMENT 'P ALTER TABLE prompt_manage ADD COLUMN `response_schema` text NULL COMMENT 'Prompt response schema'; ALTER TABLE prompt_manage ADD COLUMN `user_code` varchar(128) NULL COMMENT 'User code'; +--chat_feed_back +ALTER TABLE chat_feed_back ADD COLUMN `message_id` varchar(255) NULL COMMENT 'Message id'; +ALTER TABLE chat_feed_back ADD COLUMN `feedback_type` varchar(50) NULL COMMENT 'Feedback type like or unlike'; +ALTER TABLE chat_feed_back ADD COLUMN `reason_types` varchar(255) NULL COMMENT 'Feedback reason categories'; +ALTER TABLE chat_feed_back ADD COLUMN `user_code` varchar(128) NULL COMMENT 'User code'; +ALTER TABLE chat_feed_back ADD COLUMN `remark` text NULL COMMENT 'Feedback remark'; + + -- dbgpt.recommend_question definition CREATE TABLE `recommend_question` ( `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', diff --git a/assets/schema/upgrade/v0_6_0/v0.6.0.sql b/assets/schema/upgrade/v0_6_0/v0.6.0.sql index 3b15a46f2..51acc8b25 100644 --- a/assets/schema/upgrade/v0_6_0/v0.6.0.sql +++ b/assets/schema/upgrade/v0_6_0/v0.6.0.sql @@ -125,6 +125,11 @@ CREATE TABLE IF NOT EXISTS `chat_feed_back` `question` longtext DEFAULT NULL COMMENT 'User question', `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name', `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback', + `message_id` varchar(255) NULL COMMENT 'Message id', + `feedback_type` varchar(50) NULL COMMENT 'Feedback type like or unlike', + `reason_types` varchar(255) NULL COMMENT 'Feedback reason categories', + `remark` text NULL COMMENT 'Feedback remark', + `user_code` varchar(128) NULL COMMENT 'User code', `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py index 106da8fc9..65620bff9 100644 --- a/dbgpt/app/initialization/serve_initialization.py +++ b/dbgpt/app/initialization/serve_initialization.py @@ -65,3 +65,13 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve app system_app.register(DatasourceServe) # ################################ AWEL Flow Serve Register End ######################################## + + # ################################ Chat Feedback Serve Register End ######################################## + from dbgpt.serve.feedback.serve import ( + SERVE_CONFIG_KEY_PREFIX as Feedback_SERVE_CONFIG_KEY_PREFIX, + ) + from dbgpt.serve.feedback.serve import Serve as FeedbackServe + + # Register serve feedback + system_app.register(FeedbackServe) + # ################################ Chat Feedback Register End ######################################## diff --git a/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py b/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py index 2b358f8c5..6a2cda923 100644 --- a/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py +++ b/dbgpt/app/openapi/api_v1/feedback/feed_back_db.py @@ -3,29 +3,10 @@ from datetime import datetime from sqlalchemy import Column, DateTime, Integer, String, Text from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody +from dbgpt.serve.feedback.models.models import ServeEntity from dbgpt.storage.metadata import BaseDao, Model - -class ChatFeedBackEntity(Model): - __tablename__ = "chat_feed_back" - id = Column(Integer, primary_key=True) - conv_uid = Column(String(128)) - conv_index = Column(Integer) - score = Column(Integer) - ques_type = Column(String(32)) - question = Column(Text) - knowledge_space = Column(String(128)) - messages = Column(Text) - user_name = Column(String(128)) - gmt_created = Column(DateTime) - gmt_modified = Column(DateTime) - - def __repr__(self): - return ( - f"ChatFeekBackEntity(id={self.id}, conv_index='{self.conv_index}', conv_index='{self.conv_index}', " - f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', " - f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" - ) +ChatFeedBackEntity = ServeEntity class ChatFeedBackDao(BaseDao): diff --git a/dbgpt/serve/feedback/__init__.py b/dbgpt/serve/feedback/__init__.py new file mode 100644 index 000000000..8c61b9960 --- /dev/null +++ b/dbgpt/serve/feedback/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve feedback` diff --git a/dbgpt/serve/feedback/api/__init__.py b/dbgpt/serve/feedback/api/__init__.py new file mode 100644 index 000000000..8c61b9960 --- /dev/null +++ b/dbgpt/serve/feedback/api/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve feedback` diff --git a/dbgpt/serve/feedback/api/endpoints.py b/dbgpt/serve/feedback/api/endpoints.py new file mode 100644 index 000000000..7a3189eff --- /dev/null +++ b/dbgpt/serve/feedback/api/endpoints.py @@ -0,0 +1,198 @@ +import logging +from functools import cache +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + +from dbgpt.component import SystemApp +from dbgpt.serve.core import Result +from dbgpt.util import PaginationResult + +from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..service.service import Service +from .schemas import ConvFeedbackReasonType, ServeRequest, ServerResponse + +router = APIRouter() + +# Add your API endpoints here + +global_system_app: Optional[SystemApp] = None + + +def get_service() -> Service: + """Get the service instance""" + return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) + + +get_bearer_token = HTTPBearer(auto_error=False) +logger = logging.getLogger(__name__) + + +@cache +def _parse_api_keys(api_keys: str) -> List[str]: + """Parse the string api keys to a list + + Args: + api_keys (str): The string api keys + + Returns: + List[str]: The list of api keys + """ + if not api_keys: + return [] + return [key.strip() for key in api_keys.split(",")] + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), + service: Service = Depends(get_service), +) -> Optional[str]: + """Check the api key + + If the api key is not set, allow all. + + Your can pass the token in you request header like this: + + .. code-block:: python + + import requests + + client_api_key = "your_api_key" + headers = {"Authorization": "Bearer " + client_api_key} + res = requests.get("http://test/hello", headers=headers) + assert res.status_code == 200 + + """ + if service.config.api_keys: + api_keys = _parse_api_keys(service.config.api_keys) + if auth is None or (token := auth.credentials) not in api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +@router.get("/health") +async def health(): + """Health check endpoint""" + return {"status": "ok"} + + +@router.get("/test_auth", dependencies=[Depends(check_api_key)]) +async def test_auth(): + """Test auth endpoint""" + return {"status": "ok"} + + +@router.post( + "/query", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], +) +async def query( + request: ServeRequest, service: Service = Depends(get_service) +) -> Result[ServerResponse]: + """Query Feedback entities + + Args: + request (ServeRequest): The request + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.get(request)) + + +@router.post( + "/query_page", + response_model=Result[PaginationResult[ServerResponse]], + dependencies=[Depends(check_api_key)], +) +async def query_page( + request: ServeRequest, + page: Optional[int] = Query(default=1, description="current page"), + page_size: Optional[int] = Query(default=20, description="page size"), + service: Service = Depends(get_service), +) -> Result[PaginationResult[ServerResponse]]: + """Query Feedback entities + + Args: + request (ServeRequest): The request + page (int): The page number + page_size (int): The page size + service (Service): The service + Returns: + ServerResponse: The response + """ + return Result.succ(service.get_list_by_page(request, page, page_size)) + + +@router.post("/add") +async def add_feedback(request: ServeRequest, service: Service = Depends(get_service)): + try: + return Result.succ(service.create_or_update(request)) + except Exception as ex: + logger.exception("Create feedback error!") + return Result.failed(err_code="E000X", msg=f"create feedback error: {ex}") + + +@router.get("/list") +async def list_feedback( + conv_uid: Optional[str] = None, + feedback_type: Optional[str] = None, + service: Service = Depends(get_service), +): + try: + return Result.succ( + service.list_conv_feedbacks(conv_uid=conv_uid, feedback_type=feedback_type) + ) + except Exception as ex: + return Result.failed(err_code="E000X", msg=f"query questions error: {ex}") + + +@router.get("/reasons") +async def feedback_reasons(): + reasons = [] + for reason_type in ConvFeedbackReasonType: + reasons.append(ConvFeedbackReasonType.to_dict(reason_type)) + return Result.succ(reasons) + + +@router.post("/cancel") +async def cancel_feedback( + request: ServeRequest, service: Service = Depends(get_service) +): + try: + service.cancel_feedback(request) + return Result.succ([]) + except Exception as ex: + return Result.failed(err_code="E000X", msg=f"cancel_feedback error: {ex}") + + +@router.post("/update") +async def update_feedback( + request: ServeRequest, service: Service = Depends(get_service) +): + try: + return Result.succ(service.create_or_update(request)) + except Exception as ex: + return Result.failed(err_code="E000X", msg=f"update question error: {ex}") + + +def init_endpoints(system_app: SystemApp) -> None: + """Initialize the endpoints""" + global global_system_app + system_app.register(Service) + global_system_app = system_app diff --git a/dbgpt/serve/feedback/api/schemas.py b/dbgpt/serve/feedback/api/schemas.py new file mode 100644 index 000000000..4d66292e2 --- /dev/null +++ b/dbgpt/serve/feedback/api/schemas.py @@ -0,0 +1,65 @@ +# Define your Pydantic schemas here +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict + + +class ConvFeedbackReasonType(Enum): + WRONG_ANSWER = "答案有误" + WRONG_SOURCE = "来源有误" + OUTDATED_CONTENT = "内容陈旧" + UNREAL_CONTENT = "非真实数据" + ILLEGAL_CONTENT = "含色情/违法/有害信" + OTHERS = "其他" + + @classmethod + def to_dict(cls, reason_type): + return { + "reason_type": reason_type.name, + "reason": reason_type.value, + } + + @classmethod + def of_type(cls, type_name: str): + for name, member in cls.__members__.items(): + if name == type_name: + return member + raise ValueError(f"{type_name} is not a valid ConvFeedbackReasonType") + + +class ServeRequest(BaseModel): + """Feedback request model""" + + id: Optional[int] = Field(None, description="Primary Key") + gmt_created: Optional[str] = Field(None, description="Creation time") + gmt_modified: Optional[str] = Field(None, description="Modification time") + user_code: Optional[str] = Field(None, description="User ID") + user_name: Optional[str] = Field(None, description="User Name") + conv_uid: Optional[str] = Field(None, description="Conversation ID") + message_id: Optional[str] = Field( + None, description="Message ID, round_index for table chat_history_message" + ) + score: Optional[float] = Field(None, description="Rating of answers") + question: Optional[str] = Field(None, description="User question") + ques_type: Optional[str] = Field(None, description="User question type") + knowledge_space: Optional[str] = Field(None, description="Use resource") + feedback_type: Optional[str] = Field( + None, description="Feedback type like or unlike" + ) + reason_type: Optional[str] = Field(None, description="Feedback reason category") + remark: Optional[str] = Field(None, description="Remarks") + reason_types: Optional[List[str]] = Field( + default=[], description="Feedback reason categories" + ) + reason: Optional[List[Dict]] = Field( + default=[], description="Feedback reason category" + ) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dict.""" + return model_to_dict(self) + + +ServerResponse = ServeRequest diff --git a/dbgpt/serve/feedback/config.py b/dbgpt/serve/feedback/config.py new file mode 100644 index 000000000..f4e7ba9df --- /dev/null +++ b/dbgpt/serve/feedback/config.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass, field +from typing import Optional + +from dbgpt.serve.core import BaseServeConfig + +APP_NAME = "feedback" +SERVE_APP_NAME = "dbgpt_serve_feedback" +SERVE_APP_NAME_HUMP = "dbgpt_serve_Feedback" +SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.feedback." +SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +# Database table name +SERVER_APP_TABLE_NAME = "dbgpt_serve_feedback" + + +@dataclass +class ServeConfig(BaseServeConfig): + """Parameters for the serve command""" + + # TODO: add your own parameters here + api_keys: Optional[str] = field( + default=None, metadata={"help": "API keys for the endpoint, if None, allow all"} + ) diff --git a/dbgpt/serve/feedback/dependencies.py b/dbgpt/serve/feedback/dependencies.py new file mode 100644 index 000000000..8598ecd97 --- /dev/null +++ b/dbgpt/serve/feedback/dependencies.py @@ -0,0 +1 @@ +# Define your dependencies here diff --git a/dbgpt/serve/feedback/models/__init__.py b/dbgpt/serve/feedback/models/__init__.py new file mode 100644 index 000000000..8c61b9960 --- /dev/null +++ b/dbgpt/serve/feedback/models/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve feedback` diff --git a/dbgpt/serve/feedback/models/models.py b/dbgpt/serve/feedback/models/models.py new file mode 100644 index 000000000..f0e0f06f6 --- /dev/null +++ b/dbgpt/serve/feedback/models/models.py @@ -0,0 +1,144 @@ +"""This is an auto-generated model file +You can define your own models and DAOs here +""" +from datetime import datetime +from typing import Any, Dict, Union + +from sqlalchemy import Column, DateTime, Index, Integer, String, Text + +from dbgpt.storage.metadata import BaseDao, Model, db + +from ..api.schemas import ConvFeedbackReasonType, ServeRequest, ServerResponse +from ..config import ServeConfig + + +class ServeEntity(Model): + __tablename__ = "chat_feed_back" + id = Column(Integer, primary_key=True) + conv_uid = Column(String(128)) + conv_index = Column(Integer) + score = Column(Integer) + ques_type = Column(String(32)) + question = Column(Text) + knowledge_space = Column(String(128)) + messages = Column(Text) + remark = Column(Text, nullable=True, comment="feedback remark") + message_id = Column(String(255), nullable=True, comment="Message ID") + feedback_type = Column( + String(31), nullable=True, comment="Feedback type like or unlike" + ) + reason_types = Column( + String(255), nullable=True, comment="Feedback reason categories" + ) + user_code = Column(String(255), nullable=True, comment="User ID") + + user_name = Column(String(128)) + gmt_created = Column(DateTime, default=datetime.utcnow, comment="Creation time") + gmt_modified = Column( + DateTime, + default=datetime.utcnow, + onupdate=datetime.utcnow, + comment="Modification time", + ) + + __table_args__ = ( + Index("idx_conv_uid", "conv_uid"), + Index("idx_gmt_create", "gmt_created"), + ) + + def __repr__(self): + return ( + f"ChatFeekBackEntity(id={self.id}, conv_index='{self.conv_index}', conv_index='{self.conv_index}', " + f"score='{self.score}', ques_type='{self.ques_type}', question='{self.question}', knowledge_space='{self.knowledge_space}', " + f"messages='{self.messages}', user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + ) + + +class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): + """The DAO class for Feedback""" + + def __init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity: + """Convert the request to an entity + + Args: + request (Union[ServeRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + new_dict = { + "conv_uid": request.conv_uid, + "message_id": request.message_id, + "reason_types": ",".join(request.reason_types), + "remark": request.remark, + "score": request.score, + "ques_type": request.ques_type, + "question": request.question, + "knowledge_space": request.knowledge_space, + "feedback_type": request.feedback_type, + "user_code": request.user_code, + "user_name": request.user_name, + } + entity = ServeEntity(**new_dict) + return entity + + def to_request(self, entity: ServeEntity) -> ServeRequest: + """Convert the entity to a request + + Args: + entity (T): The entity + + Returns: + REQ: The request + """ + reason_types = [] + if entity.reason_types: + reason_types = entity.reason_types.split(",") + reason = [] + if len(reason_types) > 0: + reason = [ + ConvFeedbackReasonType.to_dict(ConvFeedbackReasonType.of_type(t)) + for t in reason_types + ] + gmt_created_str = ( + entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S") + if entity.gmt_created + else None + ) + gmt_modified_str = ( + entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S") + if entity.gmt_modified + else None + ) + return ServeRequest( + **{ + "id": entity.id, + "user_code": entity.user_code, + "conv_uid": entity.conv_uid, + "message_id": entity.message_id, + "question": entity.question, + "knowledge_space": entity.knowledge_space, + "feedback_type": entity.feedback_type, + "remark": entity.remark, + "reason_types": reason_types, + "reason": reason, + "gmt_created": gmt_created_str, + "gmt_modified": gmt_modified_str, + } + ) + + def to_response(self, entity: ServeEntity) -> ServerResponse: + """Convert the entity to a response + + Args: + entity (T): The entity + + Returns: + RES: The response + """ + + return self.to_request(entity) diff --git a/dbgpt/serve/feedback/serve.py b/dbgpt/serve/feedback/serve.py new file mode 100644 index 000000000..bbecb1aa4 --- /dev/null +++ b/dbgpt/serve/feedback/serve.py @@ -0,0 +1,63 @@ +import logging +from typing import List, Optional, Union + +from sqlalchemy import URL + +from dbgpt.component import SystemApp +from dbgpt.serve.core import BaseServe +from dbgpt.storage.metadata import DatabaseManager + +from .api.endpoints import init_endpoints, router +from .config import ( + APP_NAME, + SERVE_APP_NAME, + SERVE_APP_NAME_HUMP, + SERVE_CONFIG_KEY_PREFIX, + ServeConfig, +) + +logger = logging.getLogger(__name__) + + +class Serve(BaseServe): + """Serve component for DB-GPT""" + + name = SERVE_APP_NAME + + def __init__( + self, + system_app: SystemApp, + api_prefix: Optional[str] = f"/api/v1/conv/{APP_NAME}", + api_tags: Optional[List[str]] = None, + db_url_or_db: Union[str, URL, DatabaseManager] = None, + try_create_tables: Optional[bool] = False, + ): + if api_tags is None: + api_tags = [SERVE_APP_NAME_HUMP] + super().__init__( + system_app, api_prefix, api_tags, db_url_or_db, try_create_tables + ) + self._db_manager: Optional[DatabaseManager] = None + + def init_app(self, system_app: SystemApp): + if self._app_has_initiated: + return + self._system_app = system_app + self._system_app.app.include_router( + router, prefix=self._api_prefix, tags=self._api_tags + ) + init_endpoints(self._system_app) + self._app_has_initiated = True + + def on_init(self): + """Called when init the application. + + You can do some initialization here. You can't get other components here because they may be not initialized yet + """ + # import your own module here to ensure the module is loaded before the application starts + from .models.models import ServeEntity + + def before_start(self): + """Called before the start of the application.""" + # TODO: Your code here + self._db_manager = self.create_or_get_db_manager() diff --git a/dbgpt/serve/feedback/service/__init__.py b/dbgpt/serve/feedback/service/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/feedback/service/service.py b/dbgpt/serve/feedback/service/service.py new file mode 100644 index 000000000..d34fbd0da --- /dev/null +++ b/dbgpt/serve/feedback/service/service.py @@ -0,0 +1,147 @@ +from typing import List, Optional + +from dbgpt.component import BaseComponent, SystemApp +from dbgpt.serve.core import BaseService +from dbgpt.storage.metadata import BaseDao +from dbgpt.util.pagination_utils import PaginationResult + +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..models.models import ServeDao, ServeEntity + + +class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): + """The service class for Feedback""" + + name = SERVE_SERVICE_COMPONENT_NAME + + def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None): + self._system_app = None + self._serve_config: ServeConfig = None + self._dao: ServeDao = dao + super().__init__(system_app) + + def init_app(self, system_app: SystemApp) -> None: + """Initialize the service + + Args: + system_app (SystemApp): The system app + """ + super().init_app(system_app) + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + self._dao = self._dao or ServeDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: + """Returns the internal DAO.""" + return self._dao + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + def get(self, request: ServeRequest) -> Optional[ServerResponse]: + """Get a Feedback entity + + Args: + request (ServeRequest): The request + + Returns: + ServerResponse: The response + """ + # Build the query request from the request + query_request = request + return self.dao.get_one(query_request) + + def delete(self, request: ServeRequest) -> None: + """Delete a Feedback entity + + Args: + request (ServeRequest): The request + """ + + query_request = {"id": request.id} + self.dao.delete(query_request) + + def get_list(self, request: ServeRequest) -> List[ServerResponse]: + """Get a list of Feedback entities + + Args: + request (ServeRequest): The request + + Returns: + List[ServerResponse]: The response + """ + # TODO: implement your own logic here + # Build the query request from the request + query_request = request + return self.dao.get_list(query_request) + + def get_list_by_page( + self, request: ServeRequest, page: int, page_size: int + ) -> PaginationResult[ServerResponse]: + """Get a list of Feedback entities by page + + Args: + request (ServeRequest): The request + page (int): The page number + page_size (int): The page size + + Returns: + List[ServerResponse]: The response + """ + query_request = request + return self.dao.get_list_page(query_request, page, page_size) + + def list_conv_feedbacks( + self, + conv_uid: Optional[str] = None, + feedback_type: Optional[str] = None, + ) -> List[ServerResponse]: + feedbacks = self.dao.get_list( + ServeRequest(conv_uid=conv_uid, feedback_type=feedback_type) + ) + return [ServerResponse.from_entity(item) for item in feedbacks] + + def create_or_update(self, request: ServeRequest) -> ServerResponse: + """ + First check whether the current user has likes, and if so, check whether it's consistent with the previous likes + + If it is inconsistent, delete the previous likes and create a new like; + + if it is consistent, an error will be reported and the likes already exist. Please do not like repeatedly + """ + feedbacks = self.dao.get_list( + ServeRequest( + conv_uid=request.conv_uid, + message_id=request.message_id, + user_code=request.user_code, + ) + ) + if len(feedbacks) > 1: + raise Exception(f"current conversation message has more than one feedback.") + if len(feedbacks) == 1: + fb = feedbacks[0] + if fb.feedback_type == request.feedback_type: + raise Exception(f"Please do not repeat feedback") + self.dao.delete(ServeRequest(id=fb.id)) + + return self.dao.create(request) + + def cancel_feedback(self, request: ServeRequest) -> None: + if not (request.conv_uid and request.message_id): + raise Exception(f"cancel feedback参数缺失异常.") + + self.dao.delete( + ServeRequest( + conv_uid=request.conv_uid, + message_id=request.message_id, + ) + ) + + def delete_feedback(self, feedback_id: int) -> None: + self.dao.delete(ServeRequest(id=feedback_id)) diff --git a/dbgpt/serve/feedback/tests/__init__.py b/dbgpt/serve/feedback/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/feedback/tests/test_endpoints.py b/dbgpt/serve/feedback/tests/test_endpoints.py new file mode 100644 index 000000000..ba7b4f0cd --- /dev/null +++ b/dbgpt/serve/feedback/tests/test_endpoints.py @@ -0,0 +1,124 @@ +import pytest +from fastapi import FastAPI +from httpx import AsyncClient + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import asystem_app, client +from dbgpt.storage.metadata import db +from dbgpt.util import PaginationResult + +from ..api.endpoints import init_endpoints, router +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_CONFIG_KEY_PREFIX + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +def client_init_caller(app: FastAPI, system_app: SystemApp): + app.include_router(router) + init_endpoints(system_app) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, asystem_app, has_auth", + [ + ( + { + "app_caller": client_init_caller, + "client_api_key": "test_token1", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + True, + ), + ( + { + "app_caller": client_init_caller, + "client_api_key": "error_token", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + False, + ), + ], + indirect=["client", "asystem_app"], +) +async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool): + response = await client.get("/test_auth") + if has_auth: + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + else: + assert response.status_code == 401 + assert response.json() == { + "detail": { + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + } + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_health(client: AsyncClient): + response = await client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_create(client: AsyncClient): + # TODO: add your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_update(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query_by_page(client: AsyncClient): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/feedback/tests/test_models.py b/dbgpt/serve/feedback/tests/test_models.py new file mode 100644 index 000000000..1d111644d --- /dev/null +++ b/dbgpt/serve/feedback/tests/test_models.py @@ -0,0 +1,103 @@ +import pytest + +from dbgpt.storage.metadata import db + +from ..api.schemas import ServeRequest, ServerResponse +from ..config import ServeConfig +from ..models.models import ServeDao, ServeEntity + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +@pytest.fixture +def server_config(): + # TODO : build your server config + return ServeConfig() + + +@pytest.fixture +def dao(server_config): + return ServeDao(server_config) + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +def test_table_exist(): + assert ServeEntity.__tablename__ in db.metadata.tables + + +def test_entity_create(default_entity_dict): + with db.session() as session: + entity = ServeEntity(**default_entity_dict) + session.add(entity) + + +def test_entity_unique_key(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_get(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_update(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_delete(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_all(): + # TODO: implement your test case + pass + + +def test_dao_create(dao, default_entity_dict): + # TODO: implement your test case + req = ServeRequest(**default_entity_dict) + res: ServerResponse = dao.create(req) + assert res is not None + + +def test_dao_get_one(dao, default_entity_dict): + # TODO: implement your test case + req = ServeRequest(**default_entity_dict) + res: ServerResponse = dao.create(req) + + +def test_get_dao_get_list(dao): + # TODO: implement your test case + pass + + +def test_dao_update(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_delete(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_get_list_page(dao): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/feedback/tests/test_service.py b/dbgpt/serve/feedback/tests/test_service.py new file mode 100644 index 000000000..00177924d --- /dev/null +++ b/dbgpt/serve/feedback/tests/test_service.py @@ -0,0 +1,78 @@ +from typing import List + +import pytest + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import system_app +from dbgpt.storage.metadata import db + +from ..api.schemas import ServeRequest, ServerResponse +from ..models.models import ServeEntity +from ..service.service import Service + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + yield + + +@pytest.fixture +def service(system_app: SystemApp): + instance = Service(system_app) + instance.init_app(system_app) + return instance + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +@pytest.mark.parametrize( + "system_app", + [{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}], + indirect=True, +) +def test_config_exists(service: Service): + system_app: SystemApp = service._system_app + assert system_app.config.get("DEBUG") is True + assert system_app.config.get("dbgpt.serve.test_key") == "hello" + assert service.config is not None + + +def test_service_create(service: Service, default_entity_dict): + # TODO: implement your test case + # eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict)) + # ... + pass + + +def test_service_update(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_delete(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get_list(service: Service): + # TODO: implement your test case + pass + + +def test_service_get_list_by_page(service: Service): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic