Merge Plugin

csunny 2023-05-15 22:16:28 +08:00
commit 5a5fba5d18
4 changed files with 46 additions and 46 deletions

View File

@@ -134,6 +134,7 @@ V100 | 16G | conversational inference works, with noticeable lag
docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa12345678 -dit mysql:latest
```
By default the vector database is Chroma, an in-memory store, so no special installation is required; if you need to connect a different vector store, follow our tutorial to install and configure it. The whole DB-GPT installation uses a miniconda3 virtual environment. Create the virtual environment and install the Python dependencies:
```
python>=3.10
conda create -n dbgpt_env python=3.10
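The paragraph above says Chroma is the default, in-memory vector store and that other stores can be connected instead. As a minimal sketch of what a Chroma-backed similarity search looks like (an illustration only, not DB-GPT's own helper: the LangChain wrappers, the embedding model name, and the sample texts are assumptions), mirroring how `get_simlar` later in this commit queries its store:
```python
# Illustrative only: build an in-memory Chroma index and query it.
# The embedding model and texts below are assumptions, not project defaults.
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

texts = [
    "DB-GPT can generate and execute SQL against a configured MySQL instance.",
    "Chroma is used as the default in-memory vector store.",
]
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

docsearch = Chroma.from_texts(texts, embeddings)  # pass persist_directory=... to keep it on disk
docs = docsearch.similarity_search_with_score("Which vector store is the default?", k=1)
print(docs[0][0].page_content, docs[0][1])  # returns (Document, score) pairs
```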

View File

@@ -69,7 +69,7 @@ def execute_ai_response_json(
        arguments,
        prompt,
    )
-    result = f"Command {command_name} returned: " f"{command_result}"
+    result = f"{command_result}"
    return result

View File

@@ -7,7 +7,7 @@ from pathlib import Path
import distro
import yaml
from pilot.configs.config import Config
-from pilot.prompts.prompt import build_default_prompt_generator, DEFAULT_PROMPT_OHTER
+from pilot.prompts.prompt import build_default_prompt_generator, DEFAULT_PROMPT_OHTER, DEFAULT_TRIGGERING_PROMPT
class AutoModePrompt:

View File

@@ -62,14 +62,15 @@ priority = {
    "vicuna-13b": "aaa"
}

def get_simlar(q):
    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
    docs = docsearch.similarity_search_with_score(q, k=1)
    contents = [dc.page_content for dc, _ in docs]
    return "\n".join(contents)

def gen_sqlgen_conversation(dbname):
    mo = MySQLOperator(
        **DB_SETTINGS
@@ -82,10 +83,12 @@ def gen_sqlgen_conversation(dbname):
        message += s["schema_info"] + ";"
    return f"数据库{dbname}的Schema信息如下: {message}\n"

def get_database_list():
    mo = MySQLOperator(**DB_SETTINGS)
    return mo.get_db_list()

get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
@@ -98,6 +101,8 @@ function() {
    return url_params;
}
"""

def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
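For orientation, `load_demo` above is the handler Gradio runs on page load, and `get_window_url_params` is the JavaScript snippet that supplies its `url_params` argument. A self-contained sketch of that wiring under Gradio 3.x conventions (the component names and the returned update are assumptions, not this project's actual layout):
```python
# Hypothetical, minimal Gradio 3.x app showing how a JS snippet feeds URL
# parameters into a load handler; component names are illustrative.
import gradio as gr

get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    return Object.fromEntries(params);
}
"""

def load_demo(url_params, request: gr.Request):
    print(f"load_demo. ip: {request.client.host}. params: {url_params}")
    return gr.Textbox.update(visible=True)

with gr.Blocks() as demo:
    url_params = gr.JSON(visible=False)
    textbox = gr.Textbox(visible=False, label="query")
    demo.load(load_demo, inputs=[url_params], outputs=[textbox], _js=get_window_url_params)

if __name__ == "__main__":
    demo.launch()
```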
@@ -115,6 +120,7 @@ def load_demo(url_params, request: gr.Request):
            gr.Row.update(visible=True),
            gr.Accordion.update(visible=True))

def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
@@ -127,9 +133,8 @@ def regenerate(state, request: gr.Request):
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5

def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = None
    return (state, [], "") + (disable_btn,) * 5
@@ -148,6 +153,7 @@ def add_text(state, text, request: gr.Request):
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5

def post_process_code(code):
    sep = "\n```"
    if sep in code:
@@ -158,6 +164,7 @@ def post_process_code(code):
        code = sep.join(blocks)
    return code

def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
        print("AUTO DB-GPT模式.")
@@ -185,7 +192,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
        # 第一轮对话需要加入提示Prompt
        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
            # autogpt模式的第一轮对话需要 构建专属prompt
-            system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname))
+            system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query],
+                                                               db_schemes=gen_sqlgen_conversation(dbname))
            logger.info("[TEST]:" + system_prompt)
            template_name = "auto_dbgpt_one_shot"
            new_state = conv_templates[template_name].copy()
@@ -217,13 +225,11 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
        state.messages[0][1] = ""
        state.messages[-2][1] = follow_up_prompt
    if mode == conversation_types["default_knownledge"] and not db_selector:
        query = state.messages[-2][1]
        knqa = KnownLedgeBaseQA()
        state.messages[-2][1] = knqa.get_similar_answer(query)

    prompt = state.get_prompt()
    skip_echo_len = len(prompt.replace("</s>", " ")) + 1
@@ -245,11 +251,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
        print(response.json())
        print(str(response))
        try:
-            # response = """{"thoughts":{"text":"thought","reasoning":"reasoning","plan":"- short bulleted\n- list that conveys\n- long-term plan","criticism":"constructive self-criticism","speak":"thoughts summary to say to user"},"command":{"name":"db_sql_executor","args":{"sql":"select count(*) as user_count from users u where create_time >= DATE_SUB(NOW(), INTERVAL 1 MONTH);"}}}"""
-            # response = response.replace("\n", "\\n")
-            # response = """{"thoughts":{"text":"In order to get the number of users who have grown in the last three days, I need to analyze the create\_time of each user and see if it is within the last three days. I will use the SQL query to filter the users who have created their account in the last three days.","reasoning":"I can use the SQL query to filter the users who have created their account in the last three days. I will get the current date and then subtract three days from it, and then use this as the filter for the query. This will give me the number of users who have created their account in the last three days.","plan":"- Get the current date and subtract three days from it\n- Use the SQL query to filter the users who have created their account in the last three days\n- Count the number of users who match the filter to get the number of users who have grown in the last three days","criticism":"None"},"command":{"name":"db_sql_executor","args":{"sql":"SELECT COUNT(DISTINCT(ID)) FROM users WHERE create_time >= DATE_SUB(NOW(), INTERVAL 3 DAY);"}}}"""
-            # response = response.replace("\n", "\\)
            text = response.text.strip()
            text = text.rstrip()
            respObj = json.loads(text)
@@ -279,11 +280,11 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
            plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response)
            cfg.set_last_plugin_return(plugin_resp)
            print(plugin_resp)
-            state.messages[-1][-1] = "Model推理信息:\n"+ ai_response +"\n\nDB-GPT执行结果:\n" + plugin_resp
+            state.messages[-1][-1] = "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp
            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        except NotCommands as e:
            print("命令执行:" + e.message)
-            state.messages[-1][-1] = "命令执行:" + e.message +"\n模型输出:\n" + str(ai_response)
+            state.messages[-1][-1] = "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
    else:
        # 流式输出
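The `else` branch above is the streaming path ("流式输出", streaming output); the loop itself lies outside this hunk. As a rough sketch of the usual pattern for consuming a FastChat-style streaming worker (the URL, payload fields, and NUL delimiter are assumptions about the deployment, not values taken from this diff):
```python
# Hypothetical streaming client: yields the accumulated text as the worker emits it.
import json
import requests

def stream_worker_output(prompt, temperature=0.7, max_new_tokens=512,
                         url="http://localhost:21002/worker_generate_stream"):
    payload = {
        "prompt": prompt,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "stop": "</s>",
    }
    response = requests.post(url, json=payload, stream=True, timeout=60)
    # FastChat-style workers send NUL-delimited JSON chunks carrying the text so far.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            yield data["text"]

# Usage sketch:
# for partial in stream_worker_output("SELECT 1;"):
#     print(partial, end="\r")
```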
@@ -350,23 +351,26 @@ block_css = (
    """
)

def change_sql_mode(sql_mode):
    if sql_mode in ["直接执行结果"]:
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)

def change_mode(mode):
    if mode in ["默认知识库对话", "LLM原生对话"]:
        return gr.update(visible=False)
    else:
        return gr.update(visible=True)

def change_tab():
    autogpt = True

def build_single_model_ui():
    notice_markdown = """
# DB-GPT
@@ -398,7 +402,7 @@ def build_single_model_ui():
            interactive=True,
            label="最大输出Token数",
        )
-    tabs= gr.Tabs()
+    tabs = gr.Tabs()
    with tabs:
        tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
        with tab_sql:
@@ -414,9 +418,6 @@ def build_single_model_ui():
            sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
            sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
            sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
-        tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto")
-        with tab_auto:
-            gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
        tab_qa = gr.TabItem("知识问答", elem_id="QA")
        with tab_qa:
@@ -442,7 +443,6 @@ def build_single_model_ui():
                                show_label=False)
            load_folder_button = gr.Button("上传并加载到知识库")

    with gr.Blocks():
        chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
        with gr.Row():
@@ -545,7 +545,6 @@ def knowledge_embedding_store(vs_id, files):
    logger.info("knowledge embedding success")
    return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -579,7 +578,7 @@ if __name__ == "__main__":
    for command_category in command_categories:
        command_registry.import_commands(command_category)
-    cfg.command_registry =command_registry
+    cfg.command_registry = command_registry
    logger.info(args)
    demo = build_webdemo()