feat(core): Support multi round conversation operator (#986)

This commit is contained in:
Fangyin Cheng
2023-12-27 23:26:28 +08:00
committed by GitHub
parent 9aec636b02
commit b13d3f6d92
63 changed files with 2011 additions and 314 deletions

View File

@@ -43,6 +43,15 @@ class ServeRequest(BaseModel):
"You are a data analysis expert.",
],
)
prompt_desc: Optional[str] = Field(
None,
description="The prompt description.",
examples=[
"This is a prompt for code assistant.",
"This is a prompt for joker.",
"This is a prompt for data analysis expert.",
],
)
user_name: Optional[str] = Field(
None,

View File

@@ -48,6 +48,7 @@ class ServeEntity(Model):
default="f-string",
comment="Prompt format(eg: f-string, jinja2)",
)
prompt_desc = Column(String(512), nullable=True, comment="Prompt description")
user_name = Column(String(128), index=True, nullable=True, comment="User name")
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
@@ -96,6 +97,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
prompt_type=entity.prompt_type,
prompt_name=entity.prompt_name,
content=entity.content,
prompt_desc=entity.prompt_desc,
user_name=entity.user_name,
sys_code=entity.sys_code,
)
@@ -119,6 +121,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
prompt_type=entity.prompt_type,
prompt_name=entity.prompt_name,
content=entity.content,
prompt_desc=entity.prompt_desc,
user_name=entity.user_name,
sys_code=entity.sys_code,
gmt_created=gmt_created_str,

View File

@@ -3,10 +3,11 @@ from typing import List, Optional, Union
from sqlalchemy import URL
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.component import SystemApp
from dbgpt.core import PromptManager
from ...storage.metadata import DatabaseManager
from dbgpt.storage.metadata import DatabaseManager
from dbgpt.serve.core import BaseServe
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
@@ -20,7 +21,7 @@ from .models.prompt_template_adapter import PromptTemplateAdapter
logger = logging.getLogger(__name__)
class Serve(BaseComponent):
class Serve(BaseServe):
"""Serve component
Examples:
@@ -37,6 +38,7 @@ class Serve(BaseComponent):
app = FastAPI()
system_app = SystemApp(app)
system_app.register(Serve, api_prefix="/api/v1/prompt")
system_app.on_init()
# Run before start hook
system_app.before_start()
@@ -61,6 +63,7 @@ class Serve(BaseComponent):
app = FastAPI()
system_app = SystemApp(app)
system_app.register(Serve, api_prefix="/api/v1/prompt", db_url_or_db="sqlite:///:memory:", try_create_tables=True)
system_app.on_init()
# Run before start hook
system_app.before_start()
@@ -81,31 +84,41 @@ class Serve(BaseComponent):
self,
system_app: SystemApp,
api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}",
tags: Optional[List[str]] = None,
api_tags: Optional[List[str]] = None,
db_url_or_db: Union[str, URL, DatabaseManager] = None,
try_create_tables: Optional[bool] = False,
):
if tags is None:
tags = [SERVE_APP_NAME_HUMP]
self._system_app = None
self._api_prefix = api_prefix
self._tags = tags
if api_tags is None:
api_tags = [SERVE_APP_NAME_HUMP]
super().__init__(
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
)
self._prompt_manager = None
self._db_url_or_db = db_url_or_db
self._try_create_tables = try_create_tables
self._db_manager: Optional[DatabaseManager] = None
def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
return
self._system_app = system_app
self._system_app.app.include_router(
router, prefix=self._api_prefix, tags=self._tags
router, prefix=self._api_prefix, tags=self._api_tags
)
init_endpoints(self._system_app)
self._app_has_initiated = True
@property
def prompt_manager(self) -> PromptManager:
"""Get the prompt manager of the serve app with db storage"""
return self._prompt_manager
def on_init(self):
"""Called before the start of the application.
You can do some initialization here.
"""
# import your own module here to ensure the module is loaded before the application starts
from .models.models import ServeEntity
def before_start(self):
"""Called before the start of the application.
@@ -113,23 +126,16 @@ class Serve(BaseComponent):
"""
# import your own module here to ensure the module is loaded before the application starts
from dbgpt.core.interface.prompt import PromptManager
from dbgpt.storage.metadata import Model, db
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
from .models.models import ServeEntity
init_db = self._db_url_or_db or db
init_db = DatabaseManager.build_from(init_db, base=Model)
if self._try_create_tables:
try:
init_db.create_all()
except Exception as e:
logger.warning(f"Failed to create tables: {e}")
self._db_manager = self.create_or_get_db_manager()
storage_adapter = PromptTemplateAdapter()
serializer = JsonSerializer()
storage = SQLAlchemyStorage(
init_db,
self._db_manager,
ServeEntity,
storage_adapter,
serializer,