mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-02 08:40:36 +00:00
Merge Plugin
This commit is contained in:
commit
5a5fba5d18
@ -134,6 +134,7 @@ V100 | 16G |可以进行对话推理,有明显卡顿
|
|||||||
docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa12345678 -dit mysql:latest
|
docker run --name=mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=aa12345678 -dit mysql:latest
|
||||||
```
|
```
|
||||||
向量数据库我们默认使用的是Chroma内存数据库,所以无需特殊安装,如果有需要连接其他的同学,可以按照我们的教程进行安装配置。整个DB-GPT的安装过程,我们使用的是miniconda3的虚拟环境。创建虚拟环境,并安装python依赖包
|
向量数据库我们默认使用的是Chroma内存数据库,所以无需特殊安装,如果有需要连接其他的同学,可以按照我们的教程进行安装配置。整个DB-GPT的安装过程,我们使用的是miniconda3的虚拟环境。创建虚拟环境,并安装python依赖包
|
||||||
|
|
||||||
```
|
```
|
||||||
python>=3.10
|
python>=3.10
|
||||||
conda create -n dbgpt_env python=3.10
|
conda create -n dbgpt_env python=3.10
|
||||||
|
@ -69,7 +69,7 @@ def execute_ai_response_json(
|
|||||||
arguments,
|
arguments,
|
||||||
prompt,
|
prompt,
|
||||||
)
|
)
|
||||||
result = f"Command {command_name} returned: " f"{command_result}"
|
result = f"{command_result}"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from pathlib import Path
|
|||||||
import distro
|
import distro
|
||||||
import yaml
|
import yaml
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.prompts.prompt import build_default_prompt_generator, DEFAULT_PROMPT_OHTER
|
from pilot.prompts.prompt import build_default_prompt_generator, DEFAULT_PROMPT_OHTER, DEFAULT_TRIGGERING_PROMPT
|
||||||
|
|
||||||
|
|
||||||
class AutoModePrompt:
|
class AutoModePrompt:
|
||||||
|
@ -62,14 +62,15 @@ priority = {
|
|||||||
"vicuna-13b": "aaa"
|
"vicuna-13b": "aaa"
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_simlar(q):
|
|
||||||
|
|
||||||
|
def get_simlar(q):
|
||||||
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
|
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
|
||||||
docs = docsearch.similarity_search_with_score(q, k=1)
|
docs = docsearch.similarity_search_with_score(q, k=1)
|
||||||
|
|
||||||
contents = [dc.page_content for dc, _ in docs]
|
contents = [dc.page_content for dc, _ in docs]
|
||||||
return "\n".join(contents)
|
return "\n".join(contents)
|
||||||
|
|
||||||
|
|
||||||
def gen_sqlgen_conversation(dbname):
|
def gen_sqlgen_conversation(dbname):
|
||||||
mo = MySQLOperator(
|
mo = MySQLOperator(
|
||||||
**DB_SETTINGS
|
**DB_SETTINGS
|
||||||
@ -82,10 +83,12 @@ def gen_sqlgen_conversation(dbname):
|
|||||||
message += s["schema_info"] + ";"
|
message += s["schema_info"] + ";"
|
||||||
return f"数据库{dbname}的Schema信息如下: {message}\n"
|
return f"数据库{dbname}的Schema信息如下: {message}\n"
|
||||||
|
|
||||||
|
|
||||||
def get_database_list():
|
def get_database_list():
|
||||||
mo = MySQLOperator(**DB_SETTINGS)
|
mo = MySQLOperator(**DB_SETTINGS)
|
||||||
return mo.get_db_list()
|
return mo.get_db_list()
|
||||||
|
|
||||||
|
|
||||||
get_window_url_params = """
|
get_window_url_params = """
|
||||||
function() {
|
function() {
|
||||||
const params = new URLSearchParams(window.location.search);
|
const params = new URLSearchParams(window.location.search);
|
||||||
@ -98,6 +101,8 @@ function() {
|
|||||||
return url_params;
|
return url_params;
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def load_demo(url_params, request: gr.Request):
|
def load_demo(url_params, request: gr.Request):
|
||||||
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
||||||
|
|
||||||
@ -115,6 +120,7 @@ def load_demo(url_params, request: gr.Request):
|
|||||||
gr.Row.update(visible=True),
|
gr.Row.update(visible=True),
|
||||||
gr.Accordion.update(visible=True))
|
gr.Accordion.update(visible=True))
|
||||||
|
|
||||||
|
|
||||||
def get_conv_log_filename():
|
def get_conv_log_filename():
|
||||||
t = datetime.datetime.now()
|
t = datetime.datetime.now()
|
||||||
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
||||||
@ -127,9 +133,8 @@ def regenerate(state, request: gr.Request):
|
|||||||
state.skip_next = False
|
state.skip_next = False
|
||||||
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
||||||
|
|
||||||
|
|
||||||
def clear_history(request: gr.Request):
|
def clear_history(request: gr.Request):
|
||||||
|
|
||||||
|
|
||||||
logger.info(f"clear_history. ip: {request.client.host}")
|
logger.info(f"clear_history. ip: {request.client.host}")
|
||||||
state = None
|
state = None
|
||||||
return (state, [], "") + (disable_btn,) * 5
|
return (state, [], "") + (disable_btn,) * 5
|
||||||
@ -148,6 +153,7 @@ def add_text(state, text, request: gr.Request):
|
|||||||
state.skip_next = False
|
state.skip_next = False
|
||||||
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
||||||
|
|
||||||
|
|
||||||
def post_process_code(code):
|
def post_process_code(code):
|
||||||
sep = "\n```"
|
sep = "\n```"
|
||||||
if sep in code:
|
if sep in code:
|
||||||
@ -158,6 +164,7 @@ def post_process_code(code):
|
|||||||
code = sep.join(blocks)
|
code = sep.join(blocks)
|
||||||
return code
|
return code
|
||||||
|
|
||||||
|
|
||||||
def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
|
def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
|
||||||
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
|
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
|
||||||
print("AUTO DB-GPT模式.")
|
print("AUTO DB-GPT模式.")
|
||||||
@ -185,7 +192,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
|||||||
# 第一轮对话需要加入提示Prompt
|
# 第一轮对话需要加入提示Prompt
|
||||||
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
|
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
|
||||||
# autogpt模式的第一轮对话需要 构建专属prompt
|
# autogpt模式的第一轮对话需要 构建专属prompt
|
||||||
system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname))
|
system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query],
|
||||||
|
db_schemes=gen_sqlgen_conversation(dbname))
|
||||||
logger.info("[TEST]:" + system_prompt)
|
logger.info("[TEST]:" + system_prompt)
|
||||||
template_name = "auto_dbgpt_one_shot"
|
template_name = "auto_dbgpt_one_shot"
|
||||||
new_state = conv_templates[template_name].copy()
|
new_state = conv_templates[template_name].copy()
|
||||||
@ -217,13 +225,11 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
|||||||
state.messages[0][1] = ""
|
state.messages[0][1] = ""
|
||||||
state.messages[-2][1] = follow_up_prompt
|
state.messages[-2][1] = follow_up_prompt
|
||||||
|
|
||||||
|
|
||||||
if mode == conversation_types["default_knownledge"] and not db_selector:
|
if mode == conversation_types["default_knownledge"] and not db_selector:
|
||||||
query = state.messages[-2][1]
|
query = state.messages[-2][1]
|
||||||
knqa = KnownLedgeBaseQA()
|
knqa = KnownLedgeBaseQA()
|
||||||
state.messages[-2][1] = knqa.get_similar_answer(query)
|
state.messages[-2][1] = knqa.get_similar_answer(query)
|
||||||
|
|
||||||
|
|
||||||
prompt = state.get_prompt()
|
prompt = state.get_prompt()
|
||||||
|
|
||||||
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
|
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
|
||||||
@ -245,11 +251,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
|||||||
print(response.json())
|
print(response.json())
|
||||||
print(str(response))
|
print(str(response))
|
||||||
try:
|
try:
|
||||||
# response = """{"thoughts":{"text":"thought","reasoning":"reasoning","plan":"- short bulleted\n- list that conveys\n- long-term plan","criticism":"constructive self-criticism","speak":"thoughts summary to say to user"},"command":{"name":"db_sql_executor","args":{"sql":"select count(*) as user_count from users u where create_time >= DATE_SUB(NOW(), INTERVAL 1 MONTH);"}}}"""
|
|
||||||
# response = response.replace("\n", "\\n")
|
|
||||||
|
|
||||||
# response = """{"thoughts":{"text":"In order to get the number of users who have grown in the last three days, I need to analyze the create\_time of each user and see if it is within the last three days. I will use the SQL query to filter the users who have created their account in the last three days.","reasoning":"I can use the SQL query to filter the users who have created their account in the last three days. I will get the current date and then subtract three days from it, and then use this as the filter for the query. This will give me the number of users who have created their account in the last three days.","plan":"- Get the current date and subtract three days from it\n- Use the SQL query to filter the users who have created their account in the last three days\n- Count the number of users who match the filter to get the number of users who have grown in the last three days","criticism":"None"},"command":{"name":"db_sql_executor","args":{"sql":"SELECT COUNT(DISTINCT(ID)) FROM users WHERE create_time >= DATE_SUB(NOW(), INTERVAL 3 DAY);"}}}"""
|
|
||||||
# response = response.replace("\n", "\\)
|
|
||||||
text = response.text.strip()
|
text = response.text.strip()
|
||||||
text = text.rstrip()
|
text = text.rstrip()
|
||||||
respObj = json.loads(text)
|
respObj = json.loads(text)
|
||||||
@ -279,11 +280,11 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
|||||||
plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response)
|
plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response)
|
||||||
cfg.set_last_plugin_return(plugin_resp)
|
cfg.set_last_plugin_return(plugin_resp)
|
||||||
print(plugin_resp)
|
print(plugin_resp)
|
||||||
state.messages[-1][-1] = "Model推理信息:\n"+ ai_response +"\n\nDB-GPT执行结果:\n" + plugin_resp
|
state.messages[-1][-1] = "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp
|
||||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||||
except NotCommands as e:
|
except NotCommands as e:
|
||||||
print("命令执行:" + e.message)
|
print("命令执行:" + e.message)
|
||||||
state.messages[-1][-1] = "命令执行:" + e.message +"\n模型输出:\n" + str(ai_response)
|
state.messages[-1][-1] = "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
|
||||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||||
else:
|
else:
|
||||||
# 流式输出
|
# 流式输出
|
||||||
@ -350,23 +351,26 @@ block_css = (
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def change_sql_mode(sql_mode):
|
def change_sql_mode(sql_mode):
|
||||||
if sql_mode in ["直接执行结果"]:
|
if sql_mode in ["直接执行结果"]:
|
||||||
return gr.update(visible=True)
|
return gr.update(visible=True)
|
||||||
else:
|
else:
|
||||||
return gr.update(visible=False)
|
return gr.update(visible=False)
|
||||||
|
|
||||||
|
|
||||||
def change_mode(mode):
|
def change_mode(mode):
|
||||||
if mode in ["默认知识库对话", "LLM原生对话"]:
|
if mode in ["默认知识库对话", "LLM原生对话"]:
|
||||||
return gr.update(visible=False)
|
return gr.update(visible=False)
|
||||||
else:
|
else:
|
||||||
return gr.update(visible=True)
|
return gr.update(visible=True)
|
||||||
|
|
||||||
|
|
||||||
def change_tab():
|
def change_tab():
|
||||||
autogpt = True
|
autogpt = True
|
||||||
|
|
||||||
def build_single_model_ui():
|
|
||||||
|
|
||||||
|
def build_single_model_ui():
|
||||||
notice_markdown = """
|
notice_markdown = """
|
||||||
# DB-GPT
|
# DB-GPT
|
||||||
|
|
||||||
@ -398,7 +402,7 @@ def build_single_model_ui():
|
|||||||
interactive=True,
|
interactive=True,
|
||||||
label="最大输出Token数",
|
label="最大输出Token数",
|
||||||
)
|
)
|
||||||
tabs= gr.Tabs()
|
tabs = gr.Tabs()
|
||||||
with tabs:
|
with tabs:
|
||||||
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
|
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
|
||||||
with tab_sql:
|
with tab_sql:
|
||||||
@ -414,9 +418,6 @@ def build_single_model_ui():
|
|||||||
sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
|
sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
|
||||||
sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
|
sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
|
||||||
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
|
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
|
||||||
tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto")
|
|
||||||
with tab_auto:
|
|
||||||
gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
|
|
||||||
|
|
||||||
tab_qa = gr.TabItem("知识问答", elem_id="QA")
|
tab_qa = gr.TabItem("知识问答", elem_id="QA")
|
||||||
with tab_qa:
|
with tab_qa:
|
||||||
@ -442,7 +443,6 @@ def build_single_model_ui():
|
|||||||
show_label=False)
|
show_label=False)
|
||||||
load_folder_button = gr.Button("上传并加载到知识库")
|
load_folder_button = gr.Button("上传并加载到知识库")
|
||||||
|
|
||||||
|
|
||||||
with gr.Blocks():
|
with gr.Blocks():
|
||||||
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
|
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -545,7 +545,6 @@ def knowledge_embedding_store(vs_id, files):
|
|||||||
logger.info("knowledge embedding success")
|
logger.info("knowledge embedding success")
|
||||||
return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")
|
return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||||
@ -579,7 +578,7 @@ if __name__ == "__main__":
|
|||||||
for command_category in command_categories:
|
for command_category in command_categories:
|
||||||
command_registry.import_commands(command_category)
|
command_registry.import_commands(command_category)
|
||||||
|
|
||||||
cfg.command_registry =command_registry
|
cfg.command_registry = command_registry
|
||||||
|
|
||||||
logger.info(args)
|
logger.info(args)
|
||||||
demo = build_webdemo()
|
demo = build_webdemo()
|
||||||
|
Loading…
Reference in New Issue
Block a user