mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 05:31:40 +00:00
feat(core): Support multi round conversation operator (#986)
This commit is contained in:
@@ -40,7 +40,7 @@ def server_init(param: "WebServerParameters", system_app: SystemApp):
|
||||
cfg = Config()
|
||||
cfg.SYSTEM_APP = system_app
|
||||
# Initialize db storage first
|
||||
_initialize_db_storage(param)
|
||||
_initialize_db_storage(param, system_app)
|
||||
|
||||
# load_native_plugins(cfg)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
@@ -86,12 +86,14 @@ def _create_model_start_listener(system_app: SystemApp):
|
||||
return startup_event
|
||||
|
||||
|
||||
def _initialize_db_storage(param: "WebServerParameters"):
|
||||
def _initialize_db_storage(param: "WebServerParameters", system_app: SystemApp):
|
||||
"""Initialize the db storage.
|
||||
|
||||
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
|
||||
"""
|
||||
_initialize_db(try_to_create_db=not param.disable_alembic_upgrade)
|
||||
_initialize_db(
|
||||
try_to_create_db=not param.disable_alembic_upgrade, system_app=system_app
|
||||
)
|
||||
|
||||
|
||||
def _migration_db_storage(param: "WebServerParameters"):
|
||||
@@ -114,7 +116,9 @@ def _migration_db_storage(param: "WebServerParameters"):
|
||||
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)
|
||||
|
||||
|
||||
def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
|
||||
def _initialize_db(
|
||||
try_to_create_db: Optional[bool] = False, system_app: Optional[SystemApp] = None
|
||||
) -> str:
|
||||
"""Initialize the database
|
||||
|
||||
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
|
||||
@@ -147,7 +151,11 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
|
||||
"pool_recycle": 3600,
|
||||
"pool_pre_ping": True,
|
||||
}
|
||||
initialize_db(db_url, db_name, engine_args)
|
||||
db = initialize_db(db_url, db_name, engine_args)
|
||||
if system_app:
|
||||
from dbgpt.storage.metadata import UnifiedDBManagerFactory
|
||||
|
||||
system_app.register(UnifiedDBManagerFactory, db)
|
||||
return default_meta_data_path
|
||||
|
||||
|
||||
@@ -273,3 +281,9 @@ class WebServerParameters(BaseParameters):
|
||||
"help": "Whether to disable alembic to initialize and upgrade database metadata",
|
||||
},
|
||||
)
|
||||
awel_dirs: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The directories to search awel files, split by `,`",
|
||||
},
|
||||
)
|
||||
|
@@ -46,9 +46,9 @@ def initialize_components(
|
||||
param, system_app, embedding_model_name, embedding_model_path
|
||||
)
|
||||
_initialize_model_cache(system_app)
|
||||
_initialize_awel(system_app)
|
||||
_initialize_awel(system_app, param)
|
||||
# Register serve apps
|
||||
register_serve_apps(system_app)
|
||||
register_serve_apps(system_app, CFG)
|
||||
|
||||
|
||||
def _initialize_model_cache(system_app: SystemApp):
|
||||
@@ -64,8 +64,14 @@ def _initialize_model_cache(system_app: SystemApp):
|
||||
initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)
|
||||
|
||||
|
||||
def _initialize_awel(system_app: SystemApp):
|
||||
def _initialize_awel(system_app: SystemApp, param: WebServerParameters):
|
||||
from dbgpt.core.awel import initialize_awel
|
||||
from dbgpt.configs.model_config import _DAG_DEFINITION_DIR
|
||||
|
||||
initialize_awel(system_app, _DAG_DEFINITION_DIR)
|
||||
# Add default dag definition dir
|
||||
dag_dirs = [_DAG_DEFINITION_DIR]
|
||||
if param.awel_dirs:
|
||||
dag_dirs += param.awel_dirs.strip().split(",")
|
||||
dag_dirs = [x.strip() for x in dag_dirs]
|
||||
|
||||
initialize_awel(system_app, dag_dirs)
|
||||
|
@@ -146,14 +146,13 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
|
||||
mount_routers(app)
|
||||
model_start_listener = _create_model_start_listener(system_app)
|
||||
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
|
||||
system_app.on_init()
|
||||
|
||||
# Before start, after initialize_components
|
||||
# TODO: initialize_worker_manager_in_client as a component register in system_app
|
||||
system_app.before_start()
|
||||
# Migration db storage, so you db models must be imported before this
|
||||
_migration_db_storage(param)
|
||||
|
||||
model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
|
||||
# TODO: initialize_worker_manager_in_client as a component register in system_app
|
||||
if not param.light:
|
||||
print("Model Unified Deployment Mode!")
|
||||
if not param.remote_embedding:
|
||||
@@ -186,6 +185,9 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
|
||||
CFG.SERVER_LIGHT_MODE = True
|
||||
|
||||
mount_static_files(app)
|
||||
|
||||
# Before start, after on_init
|
||||
system_app.before_start()
|
||||
return param
|
||||
|
||||
|
||||
|
@@ -1,13 +1,28 @@
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
|
||||
def register_serve_apps(system_app: SystemApp):
|
||||
def register_serve_apps(system_app: SystemApp, cfg: Config):
|
||||
"""Register serve apps"""
|
||||
from dbgpt.serve.prompt.serve import Serve as PromptServe, SERVE_CONFIG_KEY_PREFIX
|
||||
system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE)
|
||||
|
||||
# ################################ Prompt Serve Register Begin ######################################
|
||||
from dbgpt.serve.prompt.serve import (
|
||||
Serve as PromptServe,
|
||||
SERVE_CONFIG_KEY_PREFIX as PROMPT_SERVE_CONFIG_KEY_PREFIX,
|
||||
)
|
||||
|
||||
# Replace old prompt serve
|
||||
# Set config
|
||||
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_user", "dbgpt")
|
||||
system_app.config.set(f"{SERVE_CONFIG_KEY_PREFIX}default_sys_code", "dbgpt")
|
||||
system_app.config.set(f"{PROMPT_SERVE_CONFIG_KEY_PREFIX}default_user", "dbgpt")
|
||||
system_app.config.set(f"{PROMPT_SERVE_CONFIG_KEY_PREFIX}default_sys_code", "dbgpt")
|
||||
# Register serve app
|
||||
system_app.register(PromptServe, api_prefix="/prompt")
|
||||
# ################################ Prompt Serve Register End ########################################
|
||||
|
||||
# ################################ Conversation Serve Register Begin ######################################
|
||||
from dbgpt.serve.conversation.serve import Serve as ConversationServe
|
||||
|
||||
# Register serve app
|
||||
system_app.register(ConversationServe)
|
||||
# ################################ Conversation Serve Register End ########################################
|
||||
|
@@ -217,6 +217,10 @@ async def dialogue_list(
|
||||
model_name = item.get("model_name", CFG.LLM_MODEL)
|
||||
user_name = item.get("user_name")
|
||||
sys_code = item.get("sys_code")
|
||||
if not item.get("messages"):
|
||||
# Skip the empty messages
|
||||
# TODO support new conversation and message mode
|
||||
continue
|
||||
|
||||
messages = json.loads(item.get("messages"))
|
||||
last_round = max(messages, key=lambda x: x["chat_order"])
|
||||
|
Reference in New Issue
Block a user