feat(prompt): prompt manage support (#599)

support prompt management
This commit is contained in:
FangYin Cheng 2023-09-19 00:24:40 +08:00 committed by GitHub
commit ef2ef0925d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 285 additions and 0 deletions

View File

@ -0,0 +1,16 @@
CREATE DATABASE prompt_management;
use prompt_management;
CREATE TABLE `prompt_manage` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '场景',
`sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '子场景',
`prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '类型: common or private',
`prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt的名字',
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'prompt的内容',
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '用户名',
`gmt_created` datetime DEFAULT NULL,
`gmt_modified` datetime DEFAULT NULL,
PRIMARY KEY (`id`),
UNIQUE KEY `prompt_name_uiq` (`prompt_name`),
KEY `gmt_created_idx` (`gmt_created`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='prompt管理表';

View File

@ -23,6 +23,7 @@ from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from pilot.server.knowledge.api import router as knowledge_router
from pilot.server.prompt.api import router as prompt_router
from pilot.openapi.api_v1.api_v1 import router as api_v1
@ -74,6 +75,7 @@ app.include_router(api_editor_route_v1, prefix="/api")
# app.include_router(api_v1)
app.include_router(knowledge_router)
app.include_router(prompt_router)
# app.include_router(api_editor_route_v1)

View File

View File

@ -0,0 +1,46 @@
from fastapi import APIRouter, File, UploadFile, Form
from pilot.openapi.api_view_model import Result
from pilot.server.prompt.service import PromptManageService
from pilot.server.prompt.request.request import PromptManageRequest
router = APIRouter()
prompt_manage_service = PromptManageService()
@router.post("/prompt/add")
def prompt_add(request: PromptManageRequest):
print(f"/space/add params: {request}")
try:
prompt_manage_service.create_prompt(request)
return Result.succ([])
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt add error {e}")
@router.post("/prompt/list")
def prompt_list(request: PromptManageRequest):
print(f"/prompt/list params: {request}")
try:
return Result.succ(prompt_manage_service.get_prompts(request))
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt list error {e}")
@router.post("/prompt/update")
def prompt_update(request: PromptManageRequest):
print(f"/prompt/update params: {request}")
try:
return Result.succ(prompt_manage_service.update_prompt(request))
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt update error {e}")
@router.post("/prompt/delete")
def prompt_delete(request: PromptManageRequest):
print(f"/prompt/delete params: {request}")
try:
return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name))
except Exception as e:
return Result.faild(code="E010X", msg=f"prompt delete error {e}")

View File

@ -0,0 +1,91 @@
from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
from pilot.server.prompt.request.request import PromptManageRequest
CFG = Config()
Base = declarative_base()
class PromptManageEntity(Base):
__tablename__ = "prompt_manage"
id = Column(Integer, primary_key=True)
chat_scene = Column(String(100))
sub_chat_scene = Column(String(100))
prompt_type = Column(String(100))
prompt_name = Column(String(512))
content = Column(Text)
user_name = Column(String(128))
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
def __repr__(self):
return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class PromptManageDao(BaseDao):
def __init__(self):
super().__init__(
database="prompt_management", orm_base=Base, create_not_exist_table=True
)
def create_prompt(self, prompt: PromptManageRequest):
session = self.Session()
prompt_manage = PromptManageEntity(
chat_scene=prompt.chat_scene,
sub_chat_scene=prompt.sub_chat_scene,
prompt_type=prompt.prompt_type,
prompt_name=prompt.prompt_name,
content=prompt.content,
user_name=prompt.user_name,
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
session.add(prompt_manage)
session.commit()
session.close()
def get_prompts(self, query: PromptManageEntity):
session = self.Session()
prompts = session.query(PromptManageEntity)
if query.chat_scene is not None:
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
if query.sub_chat_scene is not None:
prompts = prompts.filter(
PromptManageEntity.sub_chat_scene == query.sub_chat_scene
)
if query.prompt_type is not None:
prompts = prompts.filter(
PromptManageEntity.prompt_type == query.prompt_type
)
if query.prompt_type == "private" and query.user_name is not None:
prompts = prompts.filter(
PromptManageEntity.user_name == query.user_name
)
if query.prompt_name is not None:
prompts = prompts.filter(
PromptManageEntity.prompt_name == query.prompt_name
)
prompts = prompts.order_by(PromptManageEntity.gmt_created.desc())
result = prompts.all()
session.close()
return result
def update_prompt(self, prompt: PromptManageEntity):
session = self.Session()
session.merge(prompt)
session.commit()
session.close()
def delete_prompt(self, prompt: PromptManageEntity):
session = self.Session()
if prompt:
session.delete(prompt)
session.commit()
session.close()

View File

View File

@ -0,0 +1,24 @@
from typing import List
from pydantic import BaseModel
class PromptManageRequest(BaseModel):
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
chat_scene: str = None
"""sub_chat_scene: sub chat scene"""
sub_chat_scene: str = None
"""prompt_type: common or private"""
prompt_type: str = None
"""content: prompt content"""
content: str = None
"""user_name: user name"""
user_name: str = None
"""prompt_name: prompt name"""
prompt_name: str = None

View File

@ -0,0 +1,26 @@
from typing import List
from pydantic import BaseModel
class PromptQueryResponse(BaseModel):
id: int = None
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
chat_scene: str = None
"""sub_chat_scene: sub chat scene"""
sub_chat_scene: str = None
"""prompt_type: common or private"""
prompt_type: str = None
"""content: prompt content"""
content: str = None
"""user_name: user name"""
user_name: str = None
"""prompt_name: prompt name"""
prompt_name: str = None
gmt_created: str = None
gmt_modified: str = None

View File

@ -0,0 +1,80 @@
from datetime import datetime
from pilot.server.prompt.request.request import PromptManageRequest
from pilot.server.prompt.request.response import PromptQueryResponse
from pilot.server.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity
prompt_manage_dao = PromptManageDao()
class PromptManageService:
def __init__(self):
pass
"""create prompt"""
def create_prompt(self, request: PromptManageRequest):
query = PromptManageRequest(
prompt_name=request.prompt_name,
)
prompt_name = prompt_manage_dao.get_prompts(query)
if len(prompt_name) > 0:
raise Exception(f"prompt name:{request.prompt_name} have already named")
prompt_manage_dao.create_prompt(request)
return True
"""get prompts"""
def get_prompts(self, request: PromptManageRequest):
query = PromptManageRequest(
chat_scene=request.chat_scene,
sub_chat_scene=request.sub_chat_scene,
prompt_type=request.prompt_type,
prompt_name=request.prompt_name,
user_name=request.user_name,
)
responses = []
prompts = prompt_manage_dao.get_prompts(query)
for prompt in prompts:
res = PromptQueryResponse()
res.id = prompt.id
res.chat_scene = prompt.chat_scene
res.sub_chat_scene = prompt.sub_chat_scene
res.prompt_type = prompt.prompt_type
res.content = prompt.content
res.user_name = prompt.user_name
res.prompt_name = prompt.prompt_name
res.gmt_created = prompt.gmt_created
res.gmt_modified = prompt.gmt_modified
responses.append(res)
return responses
"""update prompt"""
def update_prompt(self, request: PromptManageRequest):
query = PromptManageEntity(prompt_name=request.prompt_name)
prompts = prompt_manage_dao.get_prompts(query)
if len(prompts) != 1:
raise Exception(
f"there are no or more than one space called {request.prompt_name}"
)
prompt = prompts[0]
prompt.chat_scene = request.chat_scene
prompt.sub_chat_scene = request.sub_chat_scene
prompt.prompt_type = request.prompt_type
prompt.content = request.content
prompt.user_name = request.user_name
prompt.gmt_modified = datetime.now()
return prompt_manage_dao.update_prompt(prompt)
"""delete prompt"""
def delete_prompt(self, prompt_name: str):
query = PromptManageEntity(prompt_name=prompt_name)
prompts = prompt_manage_dao.get_prompts(query)
if len(prompts) == 0:
raise Exception(f"delete error, no prompt name:{prompt_name} in database ")
# delete prompt
prompt = prompts[0]
return prompt_manage_dao.delete_prompt(prompt)