mirror of
https://github.com/csunny/DB-GPT.git
synced 2026-01-20 01:19:54 +00:00
156 lines
4.9 KiB
Python
156 lines
4.9 KiB
Python
import signal
|
|
import os
|
|
import threading
|
|
import sys
|
|
from typing import Optional, Any
|
|
from dataclasses import dataclass, field
|
|
|
|
from pilot.configs.config import Config
|
|
from pilot.component import SystemApp
|
|
from pilot.utils.parameter_utils import BaseParameters
|
|
from pilot.base_modules.meta_data.meta_data import ddl_init_and_upgrade
|
|
|
|
# Project root: three directory levels above this file. It is appended to
# sys.path so that ``pilot.*`` imports resolve when this module is loaded
# from a location where the root is not already importable.
ROOT_PATH = os.path.abspath(
    os.path.join(os.path.abspath(__file__), os.pardir, os.pardir, os.pardir)
)
sys.path.append(ROOT_PATH)
|
|
|
|
|
|
def signal_handler(sig, frame):
    """Terminate the process immediately when a signal is received.

    Has the ``(signum, frame)`` signature expected by ``signal.signal``;
    neither argument is used.

    The process is ended with ``os._exit(0)`` rather than a normal exit, so
    ``atexit`` callbacks are skipped (the printed message states this is to
    avoid a chroma db atexit problem).
    """
    # os._exit() does not flush stdio buffers, so flush explicitly; otherwise
    # the message is lost whenever stdout is buffered (pipe, redirected file).
    print("in order to avoid chroma db atexit problem", flush=True)
    os._exit(0)
|
|
|
|
|
|
def async_db_summary(system_app: SystemApp):
    """Run ``DBSummaryClient.init_db_summary`` on a background thread.

    Original description: "async db schema into vector db". The thread is
    started and not joined, so this function returns right away.

    Args:
        system_app: Passed to ``DBSummaryClient`` as ``system_app``.
    """
    # Function-scope import, as in the original.
    from pilot.summary.db_summary_client import DBSummaryClient

    summary_client = DBSummaryClient(system_app=system_app)
    threading.Thread(target=summary_client.init_db_summary).start()
|
|
|
|
|
|
def server_init(param: "WebWerverParameters", system_app: SystemApp):
    """Prepare global config, database metadata and command registries.

    Steps, in order:
      1. Store ``system_app`` on the ``Config`` object as ``SYSTEM_APP``.
      2. Call ``ddl_init_and_upgrade`` with ``param.disable_alembic_upgrade``.
      3. Install ``signal_handler`` for SIGINT.
      4. Register the built-in command modules (minus those listed in
         ``disabled_command_categories``) as ``command_registry``.
      5. Register the display command modules as ``command_disply``.

    Args:
        param: Web server parameters; only ``disable_alembic_upgrade`` is
            read here.
        system_app: The application object to attach to the config.
    """
    from pilot.base_modules.agent.commands.command_mange import CommandRegistry

    def _registry_for(module_names):
        # Create a registry and import commands from each listed module.
        registry = CommandRegistry()
        for module_name in module_names:
            registry.import_commands(module_name)
        return registry

    config = Config()
    config.SYSTEM_APP = system_app

    ddl_init_and_upgrade(param.disable_alembic_upgrade)

    signal.signal(signal.SIGINT, signal_handler)

    # Built-in command modules, skipping any category the config disables.
    built_in_modules = (
        "pilot.base_modules.agent.commands.built_in.audio_text",
        "pilot.base_modules.agent.commands.built_in.image_gen",
    )
    enabled_modules = []
    for category in built_in_modules:
        if category in config.disabled_command_categories:
            continue
        enabled_modules.append(category)
    config.command_registry = _registry_for(enabled_modules)

    # Display commands are always registered; no exclusion list applies.
    display_modules = (
        "pilot.base_modules.agent.commands.disply_type.show_chart_gen",
        "pilot.base_modules.agent.commands.disply_type.show_table_gen",
        "pilot.base_modules.agent.commands.disply_type.show_text_gen",
    )
    config.command_disply = _registry_for(display_modules)
|
|
|
|
|
|
def _create_model_start_listener(system_app: SystemApp):
    """Build the callback that sets up connections and the DB summary.

    The ``Config`` object is obtained when this factory is called, not when
    the returned callback runs.

    Args:
        system_app: Passed to ``ConnectManager`` and ``async_db_summary``.

    Returns:
        A one-argument callable. When invoked it creates a ``ConnectManager``,
        stores it on the config as ``LOCAL_DB_MANAGE`` and starts the
        background DB summary.
    """
    from pilot.connections.manages.connection_manager import ConnectManager

    config = Config()

    def startup_event(wh):
        # ``wh`` is accepted but not used by this callback.
        print("begin run _add_app_startup_event")
        config.LOCAL_DB_MANAGE = ConnectManager(system_app)
        async_db_summary(system_app)

    return startup_event
|
|
|
|
|
|
# Parameters for the web server process. Each field's ``help`` metadata
# describes the option; ``valid_values`` (where present) lists the accepted
# values.
# NOTE: a ``#`` comment is used instead of a class docstring so that the
# ``__doc__`` generated by ``@dataclass`` is left unchanged.
@dataclass
class WebWerverParameters(BaseParameters):
    host: Optional[str] = field(
        default="0.0.0.0", metadata={"help": "Webserver deploy host"}
    )
    port: Optional[int] = field(
        default=5000, metadata={"help": "Webserver deploy port"}
    )
    daemon: Optional[bool] = field(
        default=False, metadata={"help": "Run Webserver in background"}
    )
    controller_addr: Optional[str] = field(
        default=None,
        metadata={
            "help": "The Model controller address to connect. If None, read model controller address from environment key `MODEL_SERVER`."
        },
    )
    # Annotated Optional[str] because the default is None, matching the other
    # nullable fields in this class.
    model_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The default model name to use. If None, read model name from environment key `LLM_MODEL`.",
            "tags": "fixed",
        },
    )
    share: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. "
        },
    )
    remote_embedding: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to enable remote embedding models. If it is True, you need to start a embedding model through `dbgpt start worker --worker_type text2vec --model_name xxx --model_path xxx`"
        },
    )
    log_level: Optional[str] = field(
        default=None,
        metadata={
            "help": "Logging level",
            # "WARNING" was listed twice; each level now appears once.
            "valid_values": [
                "FATAL",
                "ERROR",
                "WARNING",
                "INFO",
                "DEBUG",
                "NOTSET",
            ],
        },
    )
    light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})
    log_file: Optional[str] = field(
        default="dbgpt_webserver.log",
        metadata={
            "help": "The filename to store log",
        },
    )
    tracer_file: Optional[str] = field(
        default="dbgpt_webserver_tracer.jsonl",
        metadata={
            "help": "The filename to store tracer span records",
        },
    )
    disable_alembic_upgrade: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to disable alembic to initialize and upgrade database metadata",
        },
    )
|