feat(core): Add API authentication for serve template (#950)

This commit is contained in:
Fangyin Cheng
2023-12-19 13:41:02 +08:00
committed by GitHub
parent 6739993b94
commit a10d5f57b2
34 changed files with 1293 additions and 377 deletions

View File

@@ -77,8 +77,6 @@ def mount_routers(app: FastAPI):
"""Lazy import to avoid high time cost"""
from dbgpt.app.knowledge.api import router as knowledge_router
# from dbgpt.app.prompt.api import router as prompt_router
# prompt has been removed to dbgpt.serve.prompt
from dbgpt.app.llm_manage.api import router as llm_manage_api
from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
@@ -93,7 +91,6 @@ def mount_routers(app: FastAPI):
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
app.include_router(knowledge_router, tags=["Knowledge"])
# app.include_router(prompt_router, tags=["Prompt"])
def mount_static_files(app: FastAPI):

View File

@@ -7,7 +7,6 @@ from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
# from dbgpt.app.prompt.prompt_manage_db import PromptManageEntity
from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity
from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
from dbgpt.storage.chat_history.chat_history_db import (

View File

@@ -3,7 +3,11 @@ from dbgpt.component import SystemApp
def register_serve_apps(system_app: SystemApp):
"""Register serve apps"""
from dbgpt.serve.prompt.serve import Serve as PromptServe
from dbgpt.serve.prompt.serve import Serve as PromptServe, SERVE_CONFIG_KEY_PREFIX
# Replace old prompt serve
# Set config
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_user", "dbgpt")
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_sys_code", "dbgpt")
# Register serve app
system_app.register(PromptServe, api_prefix="/prompt")

View File

@@ -1,46 +0,0 @@
from fastapi import APIRouter
from dbgpt.app.openapi.api_view_model import Result
from dbgpt.app.prompt.service import PromptManageService
from dbgpt.app.prompt.request.request import PromptManageRequest
router = APIRouter()
prompt_manage_service = PromptManageService()
@router.post("/prompt/add")
def prompt_add(request: PromptManageRequest):
print(f"/prompt/add params: {request}")
try:
prompt_manage_service.create_prompt(request)
return Result.succ([])
except Exception as e:
return Result.failed(code="E010X", msg=f"prompt add error {e}")
@router.post("/prompt/list")
def prompt_list(request: PromptManageRequest):
print(f"/prompt/list params: {request}")
try:
return Result.succ(prompt_manage_service.get_prompts(request))
except Exception as e:
return Result.failed(code="E010X", msg=f"prompt list error {e}")
@router.post("/prompt/update")
def prompt_update(request: PromptManageRequest):
print(f"/prompt/update params: {request}")
try:
return Result.succ(prompt_manage_service.update_prompt(request))
except Exception as e:
return Result.failed(code="E010X", msg=f"prompt update error {e}")
@router.post("/prompt/delete")
def prompt_delete(request: PromptManageRequest):
print(f"/prompt/delete params: {request}")
try:
return Result.succ(prompt_manage_service.delete_prompt(request.prompt_name))
except Exception as e:
return Result.failed(code="E010X", msg=f"prompt delete error {e}")

View File

@@ -1,89 +0,0 @@
from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt._private.config import Config
from dbgpt.app.prompt.request.request import PromptManageRequest
CFG = Config()
class PromptManageEntity(Model):
__tablename__ = "prompt_manage"
id = Column(Integer, primary_key=True)
chat_scene = Column(String(100))
sub_chat_scene = Column(String(100))
prompt_type = Column(String(100))
prompt_name = Column(String(512))
content = Column(Text)
user_name = Column(String(128))
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
def __repr__(self):
return f"PromptManageEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class PromptManageDao(BaseDao):
def create_prompt(self, prompt: PromptManageRequest):
session = self.get_raw_session()
prompt_manage = PromptManageEntity(
chat_scene=prompt.chat_scene,
sub_chat_scene=prompt.sub_chat_scene,
prompt_type=prompt.prompt_type,
prompt_name=prompt.prompt_name,
content=prompt.content,
user_name=prompt.user_name,
sys_code=prompt.sys_code,
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
session.add(prompt_manage)
session.commit()
session.close()
def get_prompts(self, query: PromptManageEntity):
session = self.get_raw_session()
prompts = session.query(PromptManageEntity)
if query.chat_scene is not None:
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
if query.sub_chat_scene is not None:
prompts = prompts.filter(
PromptManageEntity.sub_chat_scene == query.sub_chat_scene
)
if query.prompt_type is not None:
prompts = prompts.filter(
PromptManageEntity.prompt_type == query.prompt_type
)
if query.prompt_type == "private" and query.user_name is not None:
prompts = prompts.filter(
PromptManageEntity.user_name == query.user_name
)
if query.prompt_name is not None:
prompts = prompts.filter(
PromptManageEntity.prompt_name == query.prompt_name
)
if query.sys_code is not None:
prompts = prompts.filter(PromptManageEntity.sys_code == query.sys_code)
prompts = prompts.order_by(PromptManageEntity.gmt_created.desc())
result = prompts.all()
session.close()
return result
def update_prompt(self, prompt: PromptManageEntity):
session = self.get_raw_session()
session.merge(prompt)
session.commit()
session.close()
def delete_prompt(self, prompt: PromptManageEntity):
session = self.get_raw_session()
if prompt:
session.delete(prompt)
session.commit()
session.close()

View File

@@ -1,44 +0,0 @@
from typing import List
from dbgpt._private.pydantic import BaseModel
from typing import Optional
from dbgpt._private.pydantic import BaseModel
class PromptManageRequest(BaseModel):
"""Model for managing prompts."""
chat_scene: Optional[str] = None
"""
The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.
"""
sub_chat_scene: Optional[str] = None
"""
The sub chat scene.
"""
prompt_type: Optional[str] = None
"""
The prompt type, either common or private.
"""
content: Optional[str] = None
"""
The prompt content.
"""
user_name: Optional[str] = None
"""
The user name.
"""
sys_code: Optional[str] = None
"""
System code
"""
prompt_name: Optional[str] = None
"""
The prompt name.
"""

View File

@@ -1,26 +0,0 @@
from typing import List
from dbgpt._private.pydantic import BaseModel
class PromptQueryResponse(BaseModel):
id: int = None
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
chat_scene: str = None
"""sub_chat_scene: sub chat scene"""
sub_chat_scene: str = None
"""prompt_type: common or private"""
prompt_type: str = None
"""content: prompt content"""
content: str = None
"""user_name: user name"""
user_name: str = None
"""prompt_name: prompt name"""
prompt_name: str = None
gmt_created: str = None
gmt_modified: str = None

View File

@@ -1,87 +0,0 @@
from datetime import datetime
from dbgpt.app.prompt.request.request import PromptManageRequest
from dbgpt.app.prompt.request.response import PromptQueryResponse
from dbgpt.app.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity
prompt_manage_dao = PromptManageDao()
class PromptManageService:
def __init__(self):
pass
"""create prompt"""
def create_prompt(self, request: PromptManageRequest):
query = PromptManageRequest(
prompt_name=request.prompt_name,
)
err_sys_str = ""
if query.sys_code:
query.sys_code = request.sys_code
err_sys_str = f" and sys_code: {request.sys_code}"
prompt_name = prompt_manage_dao.get_prompts(query)
if len(prompt_name) > 0:
raise Exception(
f"prompt name: {request.prompt_name}{err_sys_str} have already named"
)
prompt_manage_dao.create_prompt(request)
return True
"""get prompts"""
def get_prompts(self, request: PromptManageRequest):
query = PromptManageRequest(
chat_scene=request.chat_scene,
sub_chat_scene=request.sub_chat_scene,
prompt_type=request.prompt_type,
prompt_name=request.prompt_name,
user_name=request.user_name,
sys_code=request.sys_code,
)
responses = []
prompts = prompt_manage_dao.get_prompts(query)
for prompt in prompts:
res = PromptQueryResponse()
res.id = prompt.id
res.chat_scene = prompt.chat_scene
res.sub_chat_scene = prompt.sub_chat_scene
res.prompt_type = prompt.prompt_type
res.content = prompt.content
res.user_name = prompt.user_name
res.prompt_name = prompt.prompt_name
res.gmt_created = prompt.gmt_created
res.gmt_modified = prompt.gmt_modified
responses.append(res)
return responses
"""update prompt"""
def update_prompt(self, request: PromptManageRequest):
query = PromptManageEntity(prompt_name=request.prompt_name)
prompts = prompt_manage_dao.get_prompts(query)
if len(prompts) != 1:
raise Exception(
f"there are no or more than one space called {request.prompt_name}"
)
prompt = prompts[0]
prompt.chat_scene = request.chat_scene
prompt.sub_chat_scene = request.sub_chat_scene
prompt.prompt_type = request.prompt_type
prompt.content = request.content
prompt.user_name = request.user_name
prompt.gmt_modified = datetime.now()
return prompt_manage_dao.update_prompt(prompt)
"""delete prompt"""
def delete_prompt(self, prompt_name: str):
query = PromptManageEntity(prompt_name=prompt_name)
prompts = prompt_manage_dao.get_prompts(query)
if len(prompts) == 0:
raise Exception(f"delete error, no prompt name:{prompt_name} in database ")
# delete prompt
prompt = prompts[0]
return prompt_manage_dao.delete_prompt(prompt)