Compatible with the AutoGPT plugin mode

tuyang.yhj
2023-05-12 16:12:28 +08:00
30 changed files with 773 additions and 790 deletions


@@ -20,6 +20,9 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
from pilot.plugins import scan_plugins
from pilot.configs.config import Config
from pilot.commands.command_mange import CommandRegistry
from pilot.prompts.prompt import build_default_prompt_generator
from pilot.prompts.first_conversation_prompt import FirstPrompt
from pilot.conversation import (
default_conversation,
@@ -28,7 +31,7 @@ from pilot.conversation import (
SeparatorStyle
)
from fastchat.utils import (
from pilot.utils import (
build_logger,
server_error_msg,
violates_moderation,
@@ -49,6 +52,7 @@ enable_moderation = False
models = []
dbs = []
vs_list = ["新建知识库"] + get_vector_storelist()
autogpt = False
priority = {
"vicuna-13b": "aaa"
@@ -62,8 +66,6 @@ def get_simlar(q):
contents = [dc.page_content for dc, _ in docs]
return "\n".join(contents)
def gen_sqlgen_conversation(dbname):
mo = MySQLOperator(
**DB_SETTINGS
@@ -122,6 +124,8 @@ def regenerate(state, request: gr.Request):
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = None
return (state, [], "") + (disable_btn,) * 5
@@ -139,7 +143,7 @@ def add_text(state, text, request: gr.Request):
return (state, state.to_gradio_chatbot(), moderation_msg) + (
no_change_btn,) * 5
text = text[:1536] # Hard cut-off
text = text[:4000] # Hard cut-off
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
state.skip_next = False
@@ -156,6 +160,8 @@ def post_process_code(code):
return code
def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
print("是否是AUTO-GPT模式.", autogpt)
start_tstamp = time.time()
model_name = LLM_MODEL
@@ -166,7 +172,8 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
# TODO: when the tab mode is AUTO_GPT, the prompt needs to be rebuilt.
if len(state.messages) == state.offset + 2:
# The first round of conversation needs to include the prompt
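The TODO above notes that the prompt needs rebuilding when the AUTO-GPT tab is active. A minimal sketch of one way to do that with the newly imported FirstPrompt; the flag check, the location of the user's goal inside state.messages, and the availability of cfg and command_registry at this point are assumptions, not part of this commit:
# Sketch only: rebuild the first-turn prompt when the AUTO-GPT tab is selected.
if autogpt and len(state.messages) == state.offset + 2:
    first_prompt = FirstPrompt(cfg=Config())
    first_prompt.command_registry = command_registry  # assumed to be reachable here
    goal = state.messages[-2][1]  # assumed: the user's first message holds the goal
    # Keyword spelled as in the construct_first_prompt call shown later in this diff.
    system_prompt = first_prompt.construct_first_prompt(fisrt_message=[goal])
    state.messages[-2][1] = system_prompt  # replace the raw goal with the full prompt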
@@ -255,29 +262,28 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.
block_css = (
code_highlight_css
+ """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}
#notice_markdown th {
display: none;
}
"""
)
def change_tab(tab):
pass
def change_mode(mode):
if mode in ["默认知识库对话", "LLM原生对话"]:
return gr.update(visible=False)
else:
return gr.update(visible=True)
def change_tab():
autogpt = True
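As written, change_tab only binds a local name, so the module-level autogpt flag defined near the top of the file never changes. A minimal fix, assuming the flag is meant to stay a module-level global rather than per-session Gradio state:
def change_tab():
    # Rebind the module-level flag; without the global statement the
    # assignment creates a function-local variable that is discarded.
    global autogpt
    autogpt = True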
def build_single_model_ui():
notice_markdown = """
@@ -305,16 +311,17 @@ def build_single_model_ui():
max_output_tokens = gr.Slider(
minimum=0,
maximum=1024,
value=512,
maximum=4096,
value=2048,
step=64,
interactive=True,
label="最大输出Token数",
)
tabs = gr.Tabs()
tabs= gr.Tabs()
with tabs:
with gr.TabItem("SQL生成与诊断", elem_id="SQL"):
# TODO A selector to choose database
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
with tab_sql:
# TODO A selector to choose database
with gr.Row(elem_id="db_selector"):
db_selector = gr.Dropdown(
label="请选择数据库",
@@ -322,9 +329,12 @@ def build_single_model_ui():
value=dbs[0] if len(models) > 0 else "",
interactive=True,
show_label=True).style(container=False)
tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto")
with tab_auto:
gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
with gr.TabItem("知识问答", elem_id="QA"):
tab_qa = gr.TabItem("知识问答", elem_id="QA")
with tab_qa:
mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
vs_setting = gr.Accordion("配置知识库", open=False)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
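The diff does not show how change_tab is attached to the new AUTO-GPT tab. One possible hookup is sketched below; whether tab items expose a select event depends on the installed Gradio version, so treat this as an assumption rather than part of the commit:
# Assumed wiring; event name and availability depend on the Gradio version in use.
tab_auto.select(fn=change_tab, inputs=None, outputs=None)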
@@ -364,9 +374,7 @@ def build_single_model_ui():
regenerate_btn = gr.Button(value="重新生成", interactive=False)
clear_btn = gr.Button(value="清理", interactive=False)
gr.Markdown(learn_more_markdown)
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
@@ -448,16 +456,16 @@ if __name__ == "__main__":
# Load plugins
cfg = Config()
cfg.plugins_dir = "123"
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
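Once scan_plugins has populated cfg.plugins, the usual AutoGPT convention is to give every loaded plugin a chance to extend the prompt. A sketch of that pattern using the build_default_prompt_generator import added above; the hook names follow the AutoGPT plugin template and are assumed to apply here, since this wiring is not visible in the hunk:
# Assumed pattern, following the AutoGPT plugin-template hooks.
prompt_generator = build_default_prompt_generator()
for plugin in cfg.plugins:
    if plugin.can_handle_post_prompt():
        # Let each plugin append its own commands, resources, and constraints.
        prompt_generator = plugin.post_prompt(prompt_generator)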
# Load the executable commands provided by plugins
command_registry = CommandRegistry()
command_categories = [
"autogpt.commands.audio_text",
"autogpt.commands.file_operations",
"autogpt.commands.image_gen",
"autogpt.commands.web_selenium",
"autogpt.commands.write_tests",
"pilot.commands.audio_text",
"pilot.commands.image_gen",
]
# Exclude disabled commands
command_categories = [
@@ -468,6 +476,13 @@ if __name__ == "__main__":
first_prompt = FirstPrompt(cfg=cfg)
first_prompt.command_registry = command_registry
system_prompt = first_prompt.construct_first_prompt( fisrt_message=["this is a test goal"])
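The execution-side counterpart of this prompt is dispatching whatever command the model picks through the registry. A short sketch, assuming CommandRegistry here keeps AutoGPT's call interface; the command name and arguments are placeholders, not real output:
# Assumed dispatch step; command name and kwargs are illustrative only.
command_name, arguments = "generate_image", {"prompt": "a small test image"}
result = command_registry.call(command_name, **arguments)
logger.info("[TEST] command result: " + str(result))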
logger.info("[TEST]:" + system_prompt)
logger.info(args)
demo = build_webdemo()
demo.queue(