From ce2c25eb960b24dcb175107b5127466cdcf75436 Mon Sep 17 00:00:00 2001 From: liushaodong03 Date: Mon, 18 Sep 2023 20:11:45 +0800 Subject: [PATCH] feat: prompt manage support --- assets/schema/prompt_management.sql | 16 +++++ pilot/server/dbgpt_server.py | 2 + pilot/server/prompt/__init__.py | 0 pilot/server/prompt/api.py | 46 +++++++++++++ pilot/server/prompt/prompt_manage_db.py | 91 +++++++++++++++++++++++++ pilot/server/prompt/request/__init__.py | 0 pilot/server/prompt/request/request.py | 24 +++++++ pilot/server/prompt/request/response.py | 26 +++++++ pilot/server/prompt/service.py | 80 ++++++++++++++++++++++ 9 files changed, 285 insertions(+) create mode 100644 assets/schema/prompt_management.sql create mode 100644 pilot/server/prompt/__init__.py create mode 100644 pilot/server/prompt/api.py create mode 100644 pilot/server/prompt/prompt_manage_db.py create mode 100644 pilot/server/prompt/request/__init__.py create mode 100644 pilot/server/prompt/request/request.py create mode 100644 pilot/server/prompt/request/response.py create mode 100644 pilot/server/prompt/service.py diff --git a/assets/schema/prompt_management.sql b/assets/schema/prompt_management.sql new file mode 100644 index 000000000..b2ed6de23 --- /dev/null +++ b/assets/schema/prompt_management.sql @@ -0,0 +1,16 @@ +CREATE DATABASE prompt_management; +use prompt_management; +CREATE TABLE `prompt_manage` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '场景', + `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '子场景', + `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '类型: common or private', + `prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt的名字', + `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'prompt的内容', + `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '用户名', + `gmt_created` datetime DEFAULT NULL, + `gmt_modified` datetime DEFAULT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `prompt_name_uiq` (`prompt_name`), + KEY `gmt_created_idx` (`gmt_created`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='prompt管理表'; \ No newline at end of file diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index d2307f06a..2385fc777 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -23,6 +23,7 @@ from fastapi.openapi.docs import get_swagger_ui_html from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from pilot.server.knowledge.api import router as knowledge_router +from pilot.server.prompt.api import router as prompt_router from pilot.openapi.api_v1.api_v1 import router as api_v1 @@ -74,6 +75,7 @@ app.include_router(api_editor_route_v1, prefix="/api") # app.include_router(api_v1) app.include_router(knowledge_router) +app.include_router(prompt_router) # app.include_router(api_editor_route_v1) diff --git a/pilot/server/prompt/__init__.py b/pilot/server/prompt/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/server/prompt/api.py b/pilot/server/prompt/api.py new file mode 100644 index 000000000..b94546891 --- /dev/null +++ b/pilot/server/prompt/api.py @@ -0,0 +1,46 @@ +from fastapi import APIRouter, File, UploadFile, Form + +from pilot.openapi.api_view_model import Result +from pilot.server.prompt.service import PromptManageService +from pilot.server.prompt.request.request import PromptManageRequest + +router = APIRouter() + +prompt_manage_service = PromptManageService() + + +@router.post("/prompt/add") +def prompt_add(request: PromptManageRequest): + print(f"/space/add params: {request}") + try: + prompt_manage_service.create_prompt(request) + return Result.succ([]) + except Exception as e: + return Result.faild(code="E010X", msg=f"prompt add error {e}") + + +@router.post("/prompt/list") +def prompt_list(request: PromptManageRequest): + print(f"/prompt/list params: {request}") + try: + return Result.succ(prompt_manage_service.get_prompts(request)) + except Exception as e: + return Result.faild(code="E010X", msg=f"prompt list error {e}") + + +@router.post("/prompt/update") +def prompt_update(request: PromptManageRequest): + print(f"/prompt/update params: {request}") + try: + return Result.succ(prompt_manage_service.update_prompt(request)) + except Exception as e: + return Result.faild(code="E010X", msg=f"prompt update error {e}") + + +@router.post("/prompt/delete") +def prompt_delete(request: PromptManageRequest): + print(f"/prompt/delete params: {request}") + try: + return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name)) + except Exception as e: + return Result.faild(code="E010X", msg=f"prompt delete error {e}") diff --git a/pilot/server/prompt/prompt_manage_db.py b/pilot/server/prompt/prompt_manage_db.py new file mode 100644 index 000000000..6a02e6b5c --- /dev/null +++ b/pilot/server/prompt/prompt_manage_db.py @@ -0,0 +1,91 @@ +from datetime import datetime + +from sqlalchemy import Column, Integer, Text, String, DateTime +from sqlalchemy.ext.declarative import declarative_base + +from pilot.configs.config import Config +from pilot.connections.rdbms.base_dao import BaseDao + +from pilot.server.prompt.request.request import PromptManageRequest + +CFG = Config() +Base = declarative_base() + + +class PromptManageEntity(Base): + __tablename__ = "prompt_manage" + id = Column(Integer, primary_key=True) + chat_scene = Column(String(100)) + sub_chat_scene = Column(String(100)) + prompt_type = Column(String(100)) + prompt_name = Column(String(512)) + content = Column(Text) + user_name = Column(String(128)) + gmt_created = Column(DateTime) + gmt_modified = Column(DateTime) + + def __repr__(self): + return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + + +class PromptManageDao(BaseDao): + def __init__(self): + super().__init__( + database="prompt_management", orm_base=Base, create_not_exist_table=True + ) + + def create_prompt(self, prompt: PromptManageRequest): + session = self.Session() + prompt_manage = PromptManageEntity( + chat_scene=prompt.chat_scene, + sub_chat_scene=prompt.sub_chat_scene, + prompt_type=prompt.prompt_type, + prompt_name=prompt.prompt_name, + content=prompt.content, + user_name=prompt.user_name, + gmt_created=datetime.now(), + gmt_modified=datetime.now(), + ) + session.add(prompt_manage) + session.commit() + session.close() + + def get_prompts(self, query: PromptManageEntity): + session = self.Session() + prompts = session.query(PromptManageEntity) + if query.chat_scene is not None: + prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene) + if query.sub_chat_scene is not None: + prompts = prompts.filter( + PromptManageEntity.sub_chat_scene == query.sub_chat_scene + ) + if query.prompt_type is not None: + prompts = prompts.filter( + PromptManageEntity.prompt_type == query.prompt_type + ) + if query.prompt_type == "private" and query.user_name is not None: + prompts = prompts.filter( + PromptManageEntity.user_name == query.user_name + ) + if query.prompt_name is not None: + prompts = prompts.filter( + PromptManageEntity.prompt_name == query.prompt_name + ) + + prompts = prompts.order_by(PromptManageEntity.gmt_created.desc()) + result = prompts.all() + session.close() + return result + + def update_prompt(self, prompt: PromptManageEntity): + session = self.Session() + session.merge(prompt) + session.commit() + session.close() + + def delete_prompt(self, prompt: PromptManageEntity): + session = self.Session() + if prompt: + session.delete(prompt) + session.commit() + session.close() diff --git a/pilot/server/prompt/request/__init__.py b/pilot/server/prompt/request/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/server/prompt/request/request.py b/pilot/server/prompt/request/request.py new file mode 100644 index 000000000..c1b0683ec --- /dev/null +++ b/pilot/server/prompt/request/request.py @@ -0,0 +1,24 @@ +from typing import List + +from pydantic import BaseModel + + +class PromptManageRequest(BaseModel): + """chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa""" + + chat_scene: str = None + + """sub_chat_scene: sub chat scene""" + sub_chat_scene: str = None + + """prompt_type: common or private""" + prompt_type: str = None + + """content: prompt content""" + content: str = None + + """user_name: user name""" + user_name: str = None + + """prompt_name: prompt name""" + prompt_name: str = None diff --git a/pilot/server/prompt/request/response.py b/pilot/server/prompt/request/response.py new file mode 100644 index 000000000..4da05e069 --- /dev/null +++ b/pilot/server/prompt/request/response.py @@ -0,0 +1,26 @@ +from typing import List +from pydantic import BaseModel + + +class PromptQueryResponse(BaseModel): + id: int = None + """chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa""" + + chat_scene: str = None + + """sub_chat_scene: sub chat scene""" + sub_chat_scene: str = None + + """prompt_type: common or private""" + prompt_type: str = None + + """content: prompt content""" + content: str = None + + """user_name: user name""" + user_name: str = None + + """prompt_name: prompt name""" + prompt_name: str = None + gmt_created: str = None + gmt_modified: str = None diff --git a/pilot/server/prompt/service.py b/pilot/server/prompt/service.py new file mode 100644 index 000000000..c108d8b88 --- /dev/null +++ b/pilot/server/prompt/service.py @@ -0,0 +1,80 @@ +from datetime import datetime + +from pilot.server.prompt.request.request import PromptManageRequest +from pilot.server.prompt.request.response import PromptQueryResponse +from pilot.server.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity + +prompt_manage_dao = PromptManageDao() + + +class PromptManageService: + def __init__(self): + pass + + """create prompt""" + + def create_prompt(self, request: PromptManageRequest): + query = PromptManageRequest( + prompt_name=request.prompt_name, + ) + prompt_name = prompt_manage_dao.get_prompts(query) + if len(prompt_name) > 0: + raise Exception(f"prompt name:{request.prompt_name} have already named") + prompt_manage_dao.create_prompt(request) + return True + + """get prompts""" + + def get_prompts(self, request: PromptManageRequest): + query = PromptManageRequest( + chat_scene=request.chat_scene, + sub_chat_scene=request.sub_chat_scene, + prompt_type=request.prompt_type, + prompt_name=request.prompt_name, + user_name=request.user_name, + ) + responses = [] + prompts = prompt_manage_dao.get_prompts(query) + for prompt in prompts: + res = PromptQueryResponse() + + res.id = prompt.id + res.chat_scene = prompt.chat_scene + res.sub_chat_scene = prompt.sub_chat_scene + res.prompt_type = prompt.prompt_type + res.content = prompt.content + res.user_name = prompt.user_name + res.prompt_name = prompt.prompt_name + res.gmt_created = prompt.gmt_created + res.gmt_modified = prompt.gmt_modified + responses.append(res) + return responses + + """update prompt""" + + def update_prompt(self, request: PromptManageRequest): + query = PromptManageEntity(prompt_name=request.prompt_name) + prompts = prompt_manage_dao.get_prompts(query) + if len(prompts) != 1: + raise Exception( + f"there are no or more than one space called {request.prompt_name}" + ) + prompt = prompts[0] + prompt.chat_scene = request.chat_scene + prompt.sub_chat_scene = request.sub_chat_scene + prompt.prompt_type = request.prompt_type + prompt.content = request.content + prompt.user_name = request.user_name + prompt.gmt_modified = datetime.now() + return prompt_manage_dao.update_prompt(prompt) + + """delete prompt""" + + def delete_prompt(self, prompt_name: str): + query = PromptManageEntity(prompt_name=prompt_name) + prompts = prompt_manage_dao.get_prompts(query) + if len(prompts) == 0: + raise Exception(f"delete error, no prompt name:{prompt_name} in database ") + # delete prompt + prompt = prompts[0] + return prompt_manage_dao.delete_prompt(prompt)