refactor: Refactor storage and new serve template (#947)

This commit is contained in:
Fangyin Cheng
2023-12-18 19:30:40 +08:00
committed by GitHub
parent 22d95b444b
commit 511a43b849
63 changed files with 1891 additions and 229 deletions

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve prompt`

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve prompt`

View File

@@ -0,0 +1,114 @@
from typing import Optional, List
from fastapi import APIRouter, Depends, Query
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
from dbgpt.util import PaginationResult
from .schemas import ServeRequest, ServerResponse
from ..service.service import Service
from ..config import APP_NAME, SERVE_APP_NAME, ServeConfig, SERVE_SERVICE_COMPONENT_NAME
router = APIRouter()
# Add your API endpoints here
global_system_app: Optional[SystemApp] = None
def get_service() -> Service:
"""Get the service instance"""
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
# TODO: Compatible with old API, will be modified in the future
@router.post("/add", response_model=Result[ServerResponse])
async def create(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Create a new Prompt entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.create(request))
@router.post("/update", response_model=Result[ServerResponse])
async def update(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Update a Prompt entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.update(request))
@router.post("/delete", response_model=Result[None])
async def delete(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[None]:
"""Delete a Prompt entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.delete(request))
@router.post("/list", response_model=Result[List[ServerResponse]])
async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[List[ServerResponse]]:
"""Query Prompt entities
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
List[ServerResponse]: The response
"""
return Result.succ(service.get_list(request))
@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]])
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),
page_size: Optional[int] = Query(default=20, description="page size"),
service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
"""Query Prompt entities
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.get_list_by_page(request, page, page_size))
def init_endpoints(system_app: SystemApp) -> None:
"""Initialize the endpoints"""
global global_system_app
system_app.register(Service)
global_system_app = system_app

View File

@@ -0,0 +1,73 @@
# Define your Pydantic schemas here
from typing import Optional
from dbgpt._private.pydantic import BaseModel, Field
class ServeRequest(BaseModel):
"""Prompt request model"""
chat_scene: Optional[str] = None
"""
The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.
"""
sub_chat_scene: Optional[str] = None
"""
The sub chat scene.
"""
prompt_type: Optional[str] = None
"""
The prompt type, either common or private.
"""
content: Optional[str] = None
"""
The prompt content.
"""
user_name: Optional[str] = None
"""
The user name.
"""
sys_code: Optional[str] = None
"""
System code
"""
prompt_name: Optional[str] = None
"""
The prompt name.
"""
class ServerResponse(BaseModel):
"""Prompt response model"""
id: int = None
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
chat_scene: str = None
"""sub_chat_scene: sub chat scene"""
sub_chat_scene: str = None
"""prompt_type: common or private"""
prompt_type: str = None
"""content: prompt content"""
content: str = None
"""user_name: user name"""
user_name: str = None
sys_code: Optional[str] = None
"""
System code
"""
"""prompt_name: prompt name"""
prompt_name: str = None
gmt_created: str = None
gmt_modified: str = None

View File

@@ -0,0 +1,19 @@
from dataclasses import dataclass
from dbgpt.serve.core import BaseServeConfig
APP_NAME = "prompt"
SERVE_APP_NAME = "dbgpt_serve_prompt"
SERVE_APP_NAME_HUMP = "dbgpt_serve_Prompt"
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.prompt."
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
# Database table name
SERVER_APP_TABLE_NAME = "dbgpt_serve_prompt"
@dataclass
class ServeConfig(BaseServeConfig):
"""Parameters for the serve command"""
# TODO: add your own parameters here

View File

@@ -0,0 +1 @@
# Define your dependencies here

View File

@@ -0,0 +1,2 @@
# This is an auto-generated __init__.py file
# generated by `dbgpt new serve prompt`

View File

@@ -0,0 +1,95 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
from typing import Union, Any, Dict
from datetime import datetime
from sqlalchemy import Column, Integer, String, Index, Text, DateTime, UniqueConstraint
from dbgpt.storage.metadata import Model, BaseDao, db
from ..api.schemas import ServeRequest, ServerResponse
from ..config import ServeConfig, SERVER_APP_TABLE_NAME
class ServeEntity(Model):
__tablename__ = "prompt_manage"
__table_args__ = (
UniqueConstraint("prompt_name", "sys_code", name="uk_prompt_name_sys_code"),
)
id = Column(Integer, primary_key=True, comment="Auto increment id")
chat_scene = Column(String(100))
sub_chat_scene = Column(String(100))
prompt_type = Column(String(100))
prompt_name = Column(String(512))
content = Column(Text)
user_name = Column(String(128))
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
def __repr__(self):
return f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"""The DAO class for Prompt"""
def __init__(self, serve_config: ServeConfig):
super().__init__()
self._serve_config = serve_config
def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
"""Convert the request to an entity
Args:
request (Union[ServeRequest, Dict[str, Any]]): The request
Returns:
T: The entity
"""
request_dict = request.dict() if isinstance(request, ServeRequest) else request
entity = ServeEntity(**request_dict)
return entity
def to_request(self, entity: ServeEntity) -> ServeRequest:
"""Convert the entity to a request
Args:
entity (T): The entity
Returns:
REQ: The request
"""
return ServeRequest(
chat_scene=entity.chat_scene,
sub_chat_scene=entity.sub_chat_scene,
prompt_type=entity.prompt_type,
prompt_name=entity.prompt_name,
content=entity.content,
user_name=entity.user_name,
sys_code=entity.sys_code,
)
def to_response(self, entity: ServeEntity) -> ServerResponse:
"""Convert the entity to a response
Args:
entity (T): The entity
Returns:
RES: The response
"""
# TODO implement your own logic here, transfer the entity to a response
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
return ServerResponse(
id=entity.id,
chat_scene=entity.chat_scene,
sub_chat_scene=entity.sub_chat_scene,
prompt_type=entity.prompt_type,
prompt_name=entity.prompt_name,
content=entity.content,
user_name=entity.user_name,
sys_code=entity.sys_code,
gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str,
)

View File

@@ -0,0 +1,36 @@
from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
from .api.endpoints import router, init_endpoints
from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME
class Serve(BaseComponent):
name = SERVE_APP_NAME
def __init__(
self,
system_app: SystemApp,
api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}",
tags: Optional[List[str]] = None,
):
if tags is None:
tags = [SERVE_APP_NAME_HUMP]
self._system_app = None
self._api_prefix = api_prefix
self._tags = tags
def init_app(self, system_app: SystemApp):
self._system_app = system_app
self._system_app.app.include_router(
router, prefix=self._api_prefix, tags=self._tags
)
init_endpoints(self._system_app)
def before_start(self):
"""Called before the start of the application.
You can do some initialization here.
"""
# import your own module here to ensure the module is loaded before the application starts
from .models.models import ServeEntity

View File

View File

@@ -0,0 +1,117 @@
from typing import Optional, List
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.storage.metadata import BaseDao
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.serve.core import BaseService
from ..models.models import ServeDao, ServeEntity
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVE_SERVICE_COMPONENT_NAME, SERVE_CONFIG_KEY_PREFIX, ServeConfig
class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""The service class for Prompt"""
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = None
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
"""Initialize the service
Args:
system_app (SystemApp): The system app
"""
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = ServeDao(self._serve_config)
self._system_app = system_app
@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
"""Returns the internal DAO."""
return self._dao
@property
def config(self) -> ServeConfig:
"""Returns the internal ServeConfig."""
return self._serve_config
def update(self, request: ServeRequest) -> ServerResponse:
"""Update a Prompt entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
# Build the query request from the request
query_request = {
"prompt_name": request.prompt_name,
"sys_code": request.sys_code,
}
return self.dao.update(query_request, update_request=request)
def get(self, request: ServeRequest) -> Optional[ServerResponse]:
"""Get a Prompt entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = request
return self.dao.get_one(query_request)
def delete(self, request: ServeRequest) -> None:
"""Delete a Prompt entity
Args:
request (ServeRequest): The request
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = {
"prompt_name": request.prompt_name,
"sys_code": request.sys_code,
}
self.dao.delete(query_request)
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
"""Get a list of Prompt entities
Args:
request (ServeRequest): The request
Returns:
List[ServerResponse]: The response
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = request
return self.dao.get_list(query_request)
def get_list_by_page(
self, request: ServeRequest, page: int, page_size: int
) -> PaginationResult[ServerResponse]:
"""Get a list of Prompt entities by page
Args:
request (ServeRequest): The request
page (int): The page number
page_size (int): The page size
Returns:
List[ServerResponse]: The response
"""
query_request = request
return self.dao.get_list_page(query_request, page, page_size)