mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 13:58:58 +00:00
兼容autogpt插件模式
This commit is contained in:
@@ -20,6 +20,9 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
|
||||
from pilot.plugins import scan_plugins
|
||||
from pilot.configs.config import Config
|
||||
from pilot.commands.command_mange import CommandRegistry
|
||||
from pilot.prompts.prompt import build_default_prompt_generator
|
||||
|
||||
from pilot.prompts.first_conversation_prompt import FirstPrompt
|
||||
|
||||
from pilot.conversation import (
|
||||
default_conversation,
|
||||
@@ -28,7 +31,7 @@ from pilot.conversation import (
|
||||
SeparatorStyle
|
||||
)
|
||||
|
||||
from fastchat.utils import (
|
||||
from pilot.utils import (
|
||||
build_logger,
|
||||
server_error_msg,
|
||||
violates_moderation,
|
||||
@@ -49,6 +52,7 @@ enable_moderation = False
|
||||
models = []
|
||||
dbs = []
|
||||
vs_list = ["新建知识库"] + get_vector_storelist()
|
||||
autogpt = False
|
||||
|
||||
priority = {
|
||||
"vicuna-13b": "aaa"
|
||||
@@ -62,8 +66,6 @@ def get_simlar(q):
|
||||
contents = [dc.page_content for dc, _ in docs]
|
||||
return "\n".join(contents)
|
||||
|
||||
|
||||
|
||||
def gen_sqlgen_conversation(dbname):
|
||||
mo = MySQLOperator(
|
||||
**DB_SETTINGS
|
||||
@@ -122,6 +124,8 @@ def regenerate(state, request: gr.Request):
|
||||
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
||||
|
||||
def clear_history(request: gr.Request):
|
||||
|
||||
|
||||
logger.info(f"clear_history. ip: {request.client.host}")
|
||||
state = None
|
||||
return (state, [], "") + (disable_btn,) * 5
|
||||
@@ -139,7 +143,7 @@ def add_text(state, text, request: gr.Request):
|
||||
return (state, state.to_gradio_chatbot(), moderation_msg) + (
|
||||
no_change_btn,) * 5
|
||||
|
||||
text = text[:1536] # Hard cut-off
|
||||
text = text[:4000] # Hard cut-off
|
||||
state.append_message(state.roles[0], text)
|
||||
state.append_message(state.roles[1], None)
|
||||
state.skip_next = False
|
||||
@@ -156,6 +160,8 @@ def post_process_code(code):
|
||||
return code
|
||||
|
||||
def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
|
||||
|
||||
print("是否是AUTO-GPT模式.", autogpt)
|
||||
start_tstamp = time.time()
|
||||
model_name = LLM_MODEL
|
||||
|
||||
@@ -166,7 +172,8 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
|
||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||
return
|
||||
|
||||
|
||||
|
||||
# TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
|
||||
if len(state.messages) == state.offset + 2:
|
||||
# 第一轮对话需要加入提示Prompt
|
||||
|
||||
@@ -255,29 +262,28 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
|
||||
block_css = (
|
||||
code_highlight_css
|
||||
+ """
|
||||
pre {
|
||||
white-space: pre-wrap; /* Since CSS 2.1 */
|
||||
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
|
||||
white-space: -pre-wrap; /* Opera 4-6 */
|
||||
white-space: -o-pre-wrap; /* Opera 7 */
|
||||
word-wrap: break-word; /* Internet Explorer 5.5+ */
|
||||
}
|
||||
#notice_markdown th {
|
||||
display: none;
|
||||
}
|
||||
"""
|
||||
pre {
|
||||
white-space: pre-wrap; /* Since CSS 2.1 */
|
||||
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
|
||||
white-space: -pre-wrap; /* Opera 4-6 */
|
||||
white-space: -o-pre-wrap; /* Opera 7 */
|
||||
word-wrap: break-word; /* Internet Explorer 5.5+ */
|
||||
}
|
||||
#notice_markdown th {
|
||||
display: none;
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def change_tab(tab):
|
||||
pass
|
||||
|
||||
def change_mode(mode):
|
||||
if mode in ["默认知识库对话", "LLM原生对话"]:
|
||||
return gr.update(visible=False)
|
||||
else:
|
||||
return gr.update(visible=True)
|
||||
|
||||
|
||||
def change_tab():
|
||||
autogpt = True
|
||||
|
||||
def build_single_model_ui():
|
||||
|
||||
notice_markdown = """
|
||||
@@ -305,16 +311,17 @@ def build_single_model_ui():
|
||||
|
||||
max_output_tokens = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
maximum=4096,
|
||||
value=2048,
|
||||
step=64,
|
||||
interactive=True,
|
||||
label="最大输出Token数",
|
||||
)
|
||||
tabs = gr.Tabs()
|
||||
tabs= gr.Tabs()
|
||||
with tabs:
|
||||
with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
|
||||
# TODO A selector to choose database
|
||||
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
|
||||
with tab_sql:
|
||||
# TODO A selector to choose database
|
||||
with gr.Row(elem_id="db_selector"):
|
||||
db_selector = gr.Dropdown(
|
||||
label="请选择数据库",
|
||||
@@ -322,9 +329,12 @@ def build_single_model_ui():
|
||||
value=dbs[0] if len(models) > 0 else "",
|
||||
interactive=True,
|
||||
show_label=True).style(container=False)
|
||||
tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto")
|
||||
with tab_auto:
|
||||
gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
|
||||
|
||||
with gr.TabItem("知识问答", elem_id="QA"):
|
||||
|
||||
tab_qa = gr.TabItem("知识问答", elem_id="QA")
|
||||
with tab_qa:
|
||||
mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
|
||||
vs_setting = gr.Accordion("配置知识库", open=False)
|
||||
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
|
||||
@@ -364,9 +374,7 @@ def build_single_model_ui():
|
||||
regenerate_btn = gr.Button(value="重新生成", interactive=False)
|
||||
clear_btn = gr.Button(value="清理", interactive=False)
|
||||
|
||||
|
||||
gr.Markdown(learn_more_markdown)
|
||||
|
||||
btn_list = [regenerate_btn, clear_btn]
|
||||
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
|
||||
http_bot,
|
||||
@@ -448,16 +456,16 @@ if __name__ == "__main__":
|
||||
|
||||
# 加载插件
|
||||
cfg = Config()
|
||||
|
||||
cfg.plugins_dir = "123"
|
||||
|
||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
||||
|
||||
# 加载插件可执行命令
|
||||
command_registry = CommandRegistry()
|
||||
command_categories = [
|
||||
"autogpt.commands.audio_text",
|
||||
"autogpt.commands.file_operations",
|
||||
"autogpt.commands.image_gen",
|
||||
"autogpt.commands.web_selenium",
|
||||
"autogpt.commands.write_tests",
|
||||
"pilot.commands.audio_text",
|
||||
"pilot.commands.image_gen",
|
||||
]
|
||||
# 排除禁用命令
|
||||
command_categories = [
|
||||
@@ -468,6 +476,13 @@ if __name__ == "__main__":
|
||||
|
||||
|
||||
|
||||
first_prompt =FirstPrompt(cfg= cfg)
|
||||
first_prompt.command_registry = command_registry
|
||||
|
||||
system_prompt = first_prompt.construct_first_prompt( fisrt_message=["this is a test goal"])
|
||||
|
||||
logger.info("[TEST]:" + system_prompt)
|
||||
|
||||
logger.info(args)
|
||||
demo = build_webdemo()
|
||||
demo.queue(
|
||||
|
Reference in New Issue
Block a user