mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-14 06:26:18 +00:00
commit
ef2ef0925d
16
assets/schema/prompt_management.sql
Normal file
16
assets/schema/prompt_management.sql
Normal file
@ -0,0 +1,16 @@
|
||||
CREATE DATABASE prompt_management;
|
||||
use prompt_management;
|
||||
CREATE TABLE `prompt_manage` (
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT,
|
||||
`chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '场景',
|
||||
`sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '子场景',
|
||||
`prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '类型: common or private',
|
||||
`prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt的名字',
|
||||
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'prompt的内容',
|
||||
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '用户名',
|
||||
`gmt_created` datetime DEFAULT NULL,
|
||||
`gmt_modified` datetime DEFAULT NULL,
|
||||
PRIMARY KEY (`id`),
|
||||
UNIQUE KEY `prompt_name_uiq` (`prompt_name`),
|
||||
KEY `gmt_created_idx` (`gmt_created`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='prompt管理表';
|
@ -23,6 +23,7 @@ from fastapi.openapi.docs import get_swagger_ui_html
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pilot.server.knowledge.api import router as knowledge_router
|
||||
from pilot.server.prompt.api import router as prompt_router
|
||||
|
||||
|
||||
from pilot.openapi.api_v1.api_v1 import router as api_v1
|
||||
@ -74,6 +75,7 @@ app.include_router(api_editor_route_v1, prefix="/api")
|
||||
|
||||
# app.include_router(api_v1)
|
||||
app.include_router(knowledge_router)
|
||||
app.include_router(prompt_router)
|
||||
# app.include_router(api_editor_route_v1)
|
||||
|
||||
|
||||
|
0
pilot/server/prompt/__init__.py
Normal file
0
pilot/server/prompt/__init__.py
Normal file
46
pilot/server/prompt/api.py
Normal file
46
pilot/server/prompt/api.py
Normal file
@ -0,0 +1,46 @@
|
||||
from fastapi import APIRouter, File, UploadFile, Form
|
||||
|
||||
from pilot.openapi.api_view_model import Result
|
||||
from pilot.server.prompt.service import PromptManageService
|
||||
from pilot.server.prompt.request.request import PromptManageRequest
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
prompt_manage_service = PromptManageService()
|
||||
|
||||
|
||||
@router.post("/prompt/add")
|
||||
def prompt_add(request: PromptManageRequest):
|
||||
print(f"/space/add params: {request}")
|
||||
try:
|
||||
prompt_manage_service.create_prompt(request)
|
||||
return Result.succ([])
|
||||
except Exception as e:
|
||||
return Result.faild(code="E010X", msg=f"prompt add error {e}")
|
||||
|
||||
|
||||
@router.post("/prompt/list")
|
||||
def prompt_list(request: PromptManageRequest):
|
||||
print(f"/prompt/list params: {request}")
|
||||
try:
|
||||
return Result.succ(prompt_manage_service.get_prompts(request))
|
||||
except Exception as e:
|
||||
return Result.faild(code="E010X", msg=f"prompt list error {e}")
|
||||
|
||||
|
||||
@router.post("/prompt/update")
|
||||
def prompt_update(request: PromptManageRequest):
|
||||
print(f"/prompt/update params: {request}")
|
||||
try:
|
||||
return Result.succ(prompt_manage_service.update_prompt(request))
|
||||
except Exception as e:
|
||||
return Result.faild(code="E010X", msg=f"prompt update error {e}")
|
||||
|
||||
|
||||
@router.post("/prompt/delete")
|
||||
def prompt_delete(request: PromptManageRequest):
|
||||
print(f"/prompt/delete params: {request}")
|
||||
try:
|
||||
return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name))
|
||||
except Exception as e:
|
||||
return Result.faild(code="E010X", msg=f"prompt delete error {e}")
|
91
pilot/server/prompt/prompt_manage_db.py
Normal file
91
pilot/server/prompt/prompt_manage_db.py
Normal file
@ -0,0 +1,91 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.connections.rdbms.base_dao import BaseDao
|
||||
|
||||
from pilot.server.prompt.request.request import PromptManageRequest
|
||||
|
||||
CFG = Config()
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class PromptManageEntity(Base):
|
||||
__tablename__ = "prompt_manage"
|
||||
id = Column(Integer, primary_key=True)
|
||||
chat_scene = Column(String(100))
|
||||
sub_chat_scene = Column(String(100))
|
||||
prompt_type = Column(String(100))
|
||||
prompt_name = Column(String(512))
|
||||
content = Column(Text)
|
||||
user_name = Column(String(128))
|
||||
gmt_created = Column(DateTime)
|
||||
gmt_modified = Column(DateTime)
|
||||
|
||||
def __repr__(self):
|
||||
return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
|
||||
|
||||
|
||||
class PromptManageDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database="prompt_management", orm_base=Base, create_not_exist_table=True
|
||||
)
|
||||
|
||||
def create_prompt(self, prompt: PromptManageRequest):
|
||||
session = self.Session()
|
||||
prompt_manage = PromptManageEntity(
|
||||
chat_scene=prompt.chat_scene,
|
||||
sub_chat_scene=prompt.sub_chat_scene,
|
||||
prompt_type=prompt.prompt_type,
|
||||
prompt_name=prompt.prompt_name,
|
||||
content=prompt.content,
|
||||
user_name=prompt.user_name,
|
||||
gmt_created=datetime.now(),
|
||||
gmt_modified=datetime.now(),
|
||||
)
|
||||
session.add(prompt_manage)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def get_prompts(self, query: PromptManageEntity):
|
||||
session = self.Session()
|
||||
prompts = session.query(PromptManageEntity)
|
||||
if query.chat_scene is not None:
|
||||
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
|
||||
if query.sub_chat_scene is not None:
|
||||
prompts = prompts.filter(
|
||||
PromptManageEntity.sub_chat_scene == query.sub_chat_scene
|
||||
)
|
||||
if query.prompt_type is not None:
|
||||
prompts = prompts.filter(
|
||||
PromptManageEntity.prompt_type == query.prompt_type
|
||||
)
|
||||
if query.prompt_type == "private" and query.user_name is not None:
|
||||
prompts = prompts.filter(
|
||||
PromptManageEntity.user_name == query.user_name
|
||||
)
|
||||
if query.prompt_name is not None:
|
||||
prompts = prompts.filter(
|
||||
PromptManageEntity.prompt_name == query.prompt_name
|
||||
)
|
||||
|
||||
prompts = prompts.order_by(PromptManageEntity.gmt_created.desc())
|
||||
result = prompts.all()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def update_prompt(self, prompt: PromptManageEntity):
|
||||
session = self.Session()
|
||||
session.merge(prompt)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def delete_prompt(self, prompt: PromptManageEntity):
|
||||
session = self.Session()
|
||||
if prompt:
|
||||
session.delete(prompt)
|
||||
session.commit()
|
||||
session.close()
|
0
pilot/server/prompt/request/__init__.py
Normal file
0
pilot/server/prompt/request/__init__.py
Normal file
24
pilot/server/prompt/request/request.py
Normal file
24
pilot/server/prompt/request/request.py
Normal file
@ -0,0 +1,24 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PromptManageRequest(BaseModel):
|
||||
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
|
||||
|
||||
chat_scene: str = None
|
||||
|
||||
"""sub_chat_scene: sub chat scene"""
|
||||
sub_chat_scene: str = None
|
||||
|
||||
"""prompt_type: common or private"""
|
||||
prompt_type: str = None
|
||||
|
||||
"""content: prompt content"""
|
||||
content: str = None
|
||||
|
||||
"""user_name: user name"""
|
||||
user_name: str = None
|
||||
|
||||
"""prompt_name: prompt name"""
|
||||
prompt_name: str = None
|
26
pilot/server/prompt/request/response.py
Normal file
26
pilot/server/prompt/request/response.py
Normal file
@ -0,0 +1,26 @@
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PromptQueryResponse(BaseModel):
|
||||
id: int = None
|
||||
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
|
||||
|
||||
chat_scene: str = None
|
||||
|
||||
"""sub_chat_scene: sub chat scene"""
|
||||
sub_chat_scene: str = None
|
||||
|
||||
"""prompt_type: common or private"""
|
||||
prompt_type: str = None
|
||||
|
||||
"""content: prompt content"""
|
||||
content: str = None
|
||||
|
||||
"""user_name: user name"""
|
||||
user_name: str = None
|
||||
|
||||
"""prompt_name: prompt name"""
|
||||
prompt_name: str = None
|
||||
gmt_created: str = None
|
||||
gmt_modified: str = None
|
80
pilot/server/prompt/service.py
Normal file
80
pilot/server/prompt/service.py
Normal file
@ -0,0 +1,80 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pilot.server.prompt.request.request import PromptManageRequest
|
||||
from pilot.server.prompt.request.response import PromptQueryResponse
|
||||
from pilot.server.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity
|
||||
|
||||
prompt_manage_dao = PromptManageDao()
|
||||
|
||||
|
||||
class PromptManageService:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
"""create prompt"""
|
||||
|
||||
def create_prompt(self, request: PromptManageRequest):
|
||||
query = PromptManageRequest(
|
||||
prompt_name=request.prompt_name,
|
||||
)
|
||||
prompt_name = prompt_manage_dao.get_prompts(query)
|
||||
if len(prompt_name) > 0:
|
||||
raise Exception(f"prompt name:{request.prompt_name} have already named")
|
||||
prompt_manage_dao.create_prompt(request)
|
||||
return True
|
||||
|
||||
"""get prompts"""
|
||||
|
||||
def get_prompts(self, request: PromptManageRequest):
|
||||
query = PromptManageRequest(
|
||||
chat_scene=request.chat_scene,
|
||||
sub_chat_scene=request.sub_chat_scene,
|
||||
prompt_type=request.prompt_type,
|
||||
prompt_name=request.prompt_name,
|
||||
user_name=request.user_name,
|
||||
)
|
||||
responses = []
|
||||
prompts = prompt_manage_dao.get_prompts(query)
|
||||
for prompt in prompts:
|
||||
res = PromptQueryResponse()
|
||||
|
||||
res.id = prompt.id
|
||||
res.chat_scene = prompt.chat_scene
|
||||
res.sub_chat_scene = prompt.sub_chat_scene
|
||||
res.prompt_type = prompt.prompt_type
|
||||
res.content = prompt.content
|
||||
res.user_name = prompt.user_name
|
||||
res.prompt_name = prompt.prompt_name
|
||||
res.gmt_created = prompt.gmt_created
|
||||
res.gmt_modified = prompt.gmt_modified
|
||||
responses.append(res)
|
||||
return responses
|
||||
|
||||
"""update prompt"""
|
||||
|
||||
def update_prompt(self, request: PromptManageRequest):
|
||||
query = PromptManageEntity(prompt_name=request.prompt_name)
|
||||
prompts = prompt_manage_dao.get_prompts(query)
|
||||
if len(prompts) != 1:
|
||||
raise Exception(
|
||||
f"there are no or more than one space called {request.prompt_name}"
|
||||
)
|
||||
prompt = prompts[0]
|
||||
prompt.chat_scene = request.chat_scene
|
||||
prompt.sub_chat_scene = request.sub_chat_scene
|
||||
prompt.prompt_type = request.prompt_type
|
||||
prompt.content = request.content
|
||||
prompt.user_name = request.user_name
|
||||
prompt.gmt_modified = datetime.now()
|
||||
return prompt_manage_dao.update_prompt(prompt)
|
||||
|
||||
"""delete prompt"""
|
||||
|
||||
def delete_prompt(self, prompt_name: str):
|
||||
query = PromptManageEntity(prompt_name=prompt_name)
|
||||
prompts = prompt_manage_dao.get_prompts(query)
|
||||
if len(prompts) == 0:
|
||||
raise Exception(f"delete error, no prompt name:{prompt_name} in database ")
|
||||
# delete prompt
|
||||
prompt = prompts[0]
|
||||
return prompt_manage_dao.delete_prompt(prompt)
|
Loading…
Reference in New Issue
Block a user