mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-10 13:29:35 +00:00
插件启动接入
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
|
||||
print("Setting random seed to 42")
|
||||
random.seed(42)
|
||||
|
||||
# Load the users .env file into environment variables
|
||||
load_dotenv(verbose=True, override=True)
|
||||
|
||||
del load_dotenv
|
||||
|
@@ -20,8 +20,6 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
|
||||
from pilot.plugins import scan_plugins
|
||||
from pilot.configs.config import Config
|
||||
from pilot.commands.command_mange import CommandRegistry
|
||||
from pilot.prompts.prompt import build_default_prompt_generator
|
||||
|
||||
from pilot.prompts.first_conversation_prompt import FirstPrompt
|
||||
|
||||
from pilot.conversation import (
|
||||
@@ -39,6 +37,8 @@ from pilot.utils import (
|
||||
from pilot.server.gradio_css import code_highlight_css
|
||||
from pilot.server.gradio_patch import Chatbot as grChatbot
|
||||
|
||||
from pilot.commands.command import execute_ai_response_json
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||
headers = {"User-Agent": "dbgpt Client"}
|
||||
|
||||
@@ -172,12 +172,11 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
|
||||
if len(state.messages) == state.offset + 2:
|
||||
query = state.messages[-2][1]
|
||||
# 第一轮对话需要加入提示Prompt
|
||||
cfg = Config()
|
||||
first_prompt = FirstPrompt()
|
||||
first_prompt.command_registry = cfg.command_registry
|
||||
if(autogpt):
|
||||
# autogpt模式的第一轮对话需要 构建专属prompt
|
||||
cfg = Config()
|
||||
first_prompt = FirstPrompt()
|
||||
first_prompt.command_registry = cfg.command_registry
|
||||
|
||||
system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
|
||||
logger.info("[TEST]:" + system_prompt)
|
||||
template_name = "auto_dbgpt_one_shot"
|
||||
@@ -456,7 +455,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
logger.info(f"args: {args}")
|
||||
|
||||
dbs = get_database_list()
|
||||
# dbs = get_database_list()
|
||||
|
||||
# 加载插件
|
||||
cfg = Config()
|
||||
@@ -464,7 +463,6 @@ if __name__ == "__main__":
|
||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
||||
|
||||
# 加载插件可执行命令
|
||||
command_registry = CommandRegistry()
|
||||
command_categories = [
|
||||
"pilot.commands.audio_text",
|
||||
"pilot.commands.image_gen",
|
||||
@@ -473,11 +471,11 @@ if __name__ == "__main__":
|
||||
command_categories = [
|
||||
x for x in command_categories if x not in cfg.disabled_command_categories
|
||||
]
|
||||
command_registry = CommandRegistry()
|
||||
for command_category in command_categories:
|
||||
command_registry.import_commands(command_category)
|
||||
|
||||
cfg.command_registry =command_category
|
||||
|
||||
cfg.command_registry =command_registry
|
||||
|
||||
logger.info(args)
|
||||
demo = build_webdemo()
|
||||
|
Reference in New Issue
Block a user