mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-16 06:30:02 +00:00
feat(core): Support multi round conversation operator (#986)
This commit is contained in:
@@ -43,6 +43,15 @@ class ServeRequest(BaseModel):
|
||||
"You are a data analysis expert.",
|
||||
],
|
||||
)
|
||||
prompt_desc: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt description.",
|
||||
examples=[
|
||||
"This is a prompt for code assistant.",
|
||||
"This is a prompt for joker.",
|
||||
"This is a prompt for data analysis expert.",
|
||||
],
|
||||
)
|
||||
|
||||
user_name: Optional[str] = Field(
|
||||
None,
|
||||
|
@@ -48,6 +48,7 @@ class ServeEntity(Model):
|
||||
default="f-string",
|
||||
comment="Prompt format(eg: f-string, jinja2)",
|
||||
)
|
||||
prompt_desc = Column(String(512), nullable=True, comment="Prompt description")
|
||||
user_name = Column(String(128), index=True, nullable=True, comment="User name")
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
|
||||
@@ -96,6 +97,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
prompt_type=entity.prompt_type,
|
||||
prompt_name=entity.prompt_name,
|
||||
content=entity.content,
|
||||
prompt_desc=entity.prompt_desc,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
)
|
||||
@@ -119,6 +121,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
prompt_type=entity.prompt_type,
|
||||
prompt_name=entity.prompt_name,
|
||||
content=entity.content,
|
||||
prompt_desc=entity.prompt_desc,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
gmt_created=gmt_created_str,
|
||||
|
@@ -3,10 +3,11 @@ from typing import List, Optional, Union
|
||||
|
||||
from sqlalchemy import URL
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core import PromptManager
|
||||
|
||||
from ...storage.metadata import DatabaseManager
|
||||
from dbgpt.storage.metadata import DatabaseManager
|
||||
from dbgpt.serve.core import BaseServe
|
||||
from .api.endpoints import init_endpoints, router
|
||||
from .config import (
|
||||
APP_NAME,
|
||||
@@ -20,7 +21,7 @@ from .models.prompt_template_adapter import PromptTemplateAdapter
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Serve(BaseComponent):
|
||||
class Serve(BaseServe):
|
||||
"""Serve component
|
||||
|
||||
Examples:
|
||||
@@ -37,6 +38,7 @@ class Serve(BaseComponent):
|
||||
app = FastAPI()
|
||||
system_app = SystemApp(app)
|
||||
system_app.register(Serve, api_prefix="/api/v1/prompt")
|
||||
system_app.on_init()
|
||||
# Run before start hook
|
||||
system_app.before_start()
|
||||
|
||||
@@ -61,6 +63,7 @@ class Serve(BaseComponent):
|
||||
app = FastAPI()
|
||||
system_app = SystemApp(app)
|
||||
system_app.register(Serve, api_prefix="/api/v1/prompt", db_url_or_db="sqlite:///:memory:", try_create_tables=True)
|
||||
system_app.on_init()
|
||||
# Run before start hook
|
||||
system_app.before_start()
|
||||
|
||||
@@ -81,31 +84,41 @@ class Serve(BaseComponent):
|
||||
self,
|
||||
system_app: SystemApp,
|
||||
api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}",
|
||||
tags: Optional[List[str]] = None,
|
||||
api_tags: Optional[List[str]] = None,
|
||||
db_url_or_db: Union[str, URL, DatabaseManager] = None,
|
||||
try_create_tables: Optional[bool] = False,
|
||||
):
|
||||
if tags is None:
|
||||
tags = [SERVE_APP_NAME_HUMP]
|
||||
self._system_app = None
|
||||
self._api_prefix = api_prefix
|
||||
self._tags = tags
|
||||
if api_tags is None:
|
||||
api_tags = [SERVE_APP_NAME_HUMP]
|
||||
super().__init__(
|
||||
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
|
||||
)
|
||||
self._prompt_manager = None
|
||||
self._db_url_or_db = db_url_or_db
|
||||
self._try_create_tables = try_create_tables
|
||||
self._db_manager: Optional[DatabaseManager] = None
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
if self._app_has_initiated:
|
||||
return
|
||||
self._system_app = system_app
|
||||
self._system_app.app.include_router(
|
||||
router, prefix=self._api_prefix, tags=self._tags
|
||||
router, prefix=self._api_prefix, tags=self._api_tags
|
||||
)
|
||||
init_endpoints(self._system_app)
|
||||
self._app_has_initiated = True
|
||||
|
||||
@property
|
||||
def prompt_manager(self) -> PromptManager:
|
||||
"""Get the prompt manager of the serve app with db storage"""
|
||||
return self._prompt_manager
|
||||
|
||||
def on_init(self):
|
||||
"""Called before the start of the application.
|
||||
|
||||
You can do some initialization here.
|
||||
"""
|
||||
# import your own module here to ensure the module is loaded before the application starts
|
||||
from .models.models import ServeEntity
|
||||
|
||||
def before_start(self):
|
||||
"""Called before the start of the application.
|
||||
|
||||
@@ -113,23 +126,16 @@ class Serve(BaseComponent):
|
||||
"""
|
||||
# import your own module here to ensure the module is loaded before the application starts
|
||||
from dbgpt.core.interface.prompt import PromptManager
|
||||
from dbgpt.storage.metadata import Model, db
|
||||
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
from .models.models import ServeEntity
|
||||
|
||||
init_db = self._db_url_or_db or db
|
||||
init_db = DatabaseManager.build_from(init_db, base=Model)
|
||||
if self._try_create_tables:
|
||||
try:
|
||||
init_db.create_all()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create tables: {e}")
|
||||
self._db_manager = self.create_or_get_db_manager()
|
||||
storage_adapter = PromptTemplateAdapter()
|
||||
serializer = JsonSerializer()
|
||||
storage = SQLAlchemyStorage(
|
||||
init_db,
|
||||
self._db_manager,
|
||||
ServeEntity,
|
||||
storage_adapter,
|
||||
serializer,
|
||||
|
Reference in New Issue
Block a user