From 1280af19da1767989bc82d2a85af0c81a9e1b6c6 Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 14 May 2023 21:34:02 +0800 Subject: [PATCH 1/3] update readme file --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 845e1ffa4..a209c30e9 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ A Open Database-GPT Experiment, A fully localized project. ## 项目方案 - + [DB-GPT](https://github.com/csunny/DB-GPT) is an experimental open-source application that builds upon the [FastChat](https://github.com/lm-sys/FastChat) model and uses vicuna as its base model. Additionally, it looks like this application incorporates langchain and llama-index embedding knowledge to improve Database-QA capabilities. @@ -20,25 +20,25 @@ Overall, it appears to be a sophisticated and innovative tool for working with d Run on an RTX 4090 GPU (The origin mov not sped up!, [YouTube地址](https://www.youtube.com/watch?v=1PWI6F89LPo)) - 运行演示 -![](https://github.com/csunny/DB-GPT/blob/main/asserts/演示.gif) +![](https://github.com/csunny/DB-GPT/blob/main/assers/演示.gif) - SQL生成示例 首先选择对应的数据库, 然后模型即可根据对应的数据库Schema信息生成SQL - + The Generated SQL is runable. - + - 数据库QA示例 - + 基于默认内置知识库QA - + # Dependencies 1. First you need to install python requirements. From d5b5fc4f9ab5fcd93564f13145ca0dbc0627e091 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Sun, 14 May 2023 21:59:40 +0800 Subject: [PATCH 2/3] bug fix 3 --- pilot/prompts/auto_mode_prompt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pilot/prompts/auto_mode_prompt.py b/pilot/prompts/auto_mode_prompt.py index 08595f4a9..ec5918582 100644 --- a/pilot/prompts/auto_mode_prompt.py +++ b/pilot/prompts/auto_mode_prompt.py @@ -7,7 +7,7 @@ from pathlib import Path import distro import yaml from pilot.configs.config import Config -from pilot.prompts.prompt import build_default_prompt_generator, DEFAULT_PROMPT_OHTER +from pilot.prompts.prompt import build_default_prompt_generator, DEFAULT_PROMPT_OHTER, DEFAULT_TRIGGERING_PROMPT class AutoModePrompt: From e2750fcea012a4f30f742e464c580e521a2fbf1a Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Mon, 15 May 2023 14:18:08 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E6=8F=92=E4=BB=B6=E8=BE=93=E5=87=BA?= =?UTF-8?q?=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pilot/commands/command.py | 2 +- pilot/server/webserver.py | 84 ++++++++++++++++++++------------------- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/pilot/commands/command.py b/pilot/commands/command.py index 84e384347..134e93e1d 100644 --- a/pilot/commands/command.py +++ b/pilot/commands/command.py @@ -69,7 +69,7 @@ def execute_ai_response_json( arguments, prompt, ) - result = f"Command {command_name} returned: " f"{command_result}" + result = f"{command_result}" return result diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 15caf5d40..522d0fe2d 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -20,7 +20,7 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D from pilot.plugins import scan_plugins from pilot.configs.config import Config from pilot.commands.command_mange import CommandRegistry -from pilot.prompts.auto_mode_prompt import AutoModePrompt +from pilot.prompts.auto_mode_prompt import AutoModePrompt from pilot.prompts.generator import PromptGenerator from pilot.commands.exception_not_commands import NotCommands @@ -60,14 +60,15 @@ priority = { "vicuna-13b": "aaa" } + def get_simlar(q): - docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md")) docs = docsearch.similarity_search_with_score(q, k=1) contents = [dc.page_content for dc, _ in docs] return "\n".join(contents) - + + def gen_sqlgen_conversation(dbname): mo = MySQLOperator( **DB_SETTINGS @@ -80,10 +81,12 @@ def gen_sqlgen_conversation(dbname): message += s["schema_info"] + ";" return f"数据库{dbname}的Schema信息如下: {message}\n" + def get_database_list(): mo = MySQLOperator(**DB_SETTINGS) return mo.get_db_list() + get_window_url_params = """ function() { const params = new URLSearchParams(window.location.search); @@ -96,6 +99,8 @@ function() { return url_params; } """ + + def load_demo(url_params, request: gr.Request): logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") @@ -113,6 +118,7 @@ def load_demo(url_params, request: gr.Request): gr.Row.update(visible=True), gr.Accordion.update(visible=True)) + def get_conv_log_filename(): t = datetime.datetime.now() name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") @@ -125,9 +131,8 @@ def regenerate(state, request: gr.Request): state.skip_next = False return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 -def clear_history(request: gr.Request): - +def clear_history(request: gr.Request): logger.info(f"clear_history. ip: {request.client.host}") state = None return (state, [], "") + (disable_btn,) * 5 @@ -140,12 +145,13 @@ def add_text(state, text, request: gr.Request): return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5 """ Default support 4000 tokens, if tokens too lang, we will cut off """ - text = text[:4000] + text = text[:4000] state.append_message(state.roles[0], text) state.append_message(state.roles[1], None) state.skip_next = False return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 + def post_process_code(code): sep = "\n```" if sep in code: @@ -156,6 +162,7 @@ def post_process_code(code): code = sep.join(blocks) return code + def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request): if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: print("AUTO DB-GPT模式.") @@ -183,7 +190,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re # 第一轮对话需要加入提示Prompt if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: # autogpt模式的第一轮对话需要 构建专属prompt - system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname)) + system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query], + db_schemes=gen_sqlgen_conversation(dbname)) logger.info("[TEST]:" + system_prompt) template_name = "auto_dbgpt_one_shot" new_state = conv_templates[template_name].copy() @@ -210,20 +218,18 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re # 第一轮对话需要加入提示Prompt if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: ## 获取最后一次插件的返回 - follow_up_prompt = auto_prompt.construct_follow_up_prompt([query]) + follow_up_prompt = auto_prompt.construct_follow_up_prompt([query]) state.messages[0][0] = "" state.messages[0][1] = "" state.messages[-2][1] = follow_up_prompt - if mode == conversation_types["default_knownledge"] and not db_selector: query = state.messages[-2][1] knqa = KnownLedgeBaseQA() state.messages[-2][1] = knqa.get_similar_answer(query) - prompt = state.get_prompt() - + skip_echo_len = len(prompt.replace("", " ")) + 1 # Make requests @@ -243,11 +249,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re print(response.json()) print(str(response)) try: - # response = """{"thoughts":{"text":"thought","reasoning":"reasoning","plan":"- short bulleted\n- list that conveys\n- long-term plan","criticism":"constructive self-criticism","speak":"thoughts summary to say to user"},"command":{"name":"db_sql_executor","args":{"sql":"select count(*) as user_count from users u where create_time >= DATE_SUB(NOW(), INTERVAL 1 MONTH);"}}}""" - # response = response.replace("\n", "\\n") - - # response = """{"thoughts":{"text":"In order to get the number of users who have grown in the last three days, I need to analyze the create\_time of each user and see if it is within the last three days. I will use the SQL query to filter the users who have created their account in the last three days.","reasoning":"I can use the SQL query to filter the users who have created their account in the last three days. I will get the current date and then subtract three days from it, and then use this as the filter for the query. This will give me the number of users who have created their account in the last three days.","plan":"- Get the current date and subtract three days from it\n- Use the SQL query to filter the users who have created their account in the last three days\n- Count the number of users who match the filter to get the number of users who have grown in the last three days","criticism":"None"},"command":{"name":"db_sql_executor","args":{"sql":"SELECT COUNT(DISTINCT(ID)) FROM users WHERE create_time >= DATE_SUB(NOW(), INTERVAL 3 DAY);"}}}""" - # response = response.replace("\n", "\\) text = response.text.strip() text = text.rstrip() respObj = json.loads(text) @@ -257,7 +258,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re respObj_ex = json.loads(xx) if respObj_ex['error_code'] == 0: ai_response = None - all_text = respObj_ex['text'] + all_text = respObj_ex['text'] ### 解析返回文本,获取AI回复部分 tmpResp = all_text.split(state.sep) last_index = -1 @@ -277,11 +278,11 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response) cfg.set_last_plugin_return(plugin_resp) print(plugin_resp) - state.messages[-1][-1] = "Model推理信息:\n"+ ai_response +"\n\nDB-GPT执行结果:\n" + plugin_resp + state.messages[-1][-1] = "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 except NotCommands as e: print("命令执行:" + e.message) - state.messages[-1][-1] = "命令执行:" + e.message +"\n模型输出:\n" + str(ai_response) + state.messages[-1][-1] = "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response) yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 else: # 流式输出 @@ -304,7 +305,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re output = data["text"] + f" (error_code: {data['error_code']})" state.messages[-1][-1] = output yield (state, state.to_gradio_chatbot()) + ( - disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) return except requests.exceptions.RequestException as e: @@ -333,8 +334,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re block_css = ( - code_highlight_css - + """ + code_highlight_css + + """ pre { white-space: pre-wrap; /* Since CSS 2.1 */ white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ @@ -348,23 +349,26 @@ block_css = ( """ ) + def change_sql_mode(sql_mode): if sql_mode in ["直接执行结果"]: return gr.update(visible=True) else: return gr.update(visible=False) + def change_mode(mode): if mode in ["默认知识库对话", "LLM原生对话"]: return gr.update(visible=False) else: return gr.update(visible=True) + def change_tab(): - autogpt = True - + autogpt = True + + def build_single_model_ui(): - notice_markdown = """ # DB-GPT @@ -396,7 +400,7 @@ def build_single_model_ui(): interactive=True, label="最大输出Token数", ) - tabs= gr.Tabs() + tabs = gr.Tabs() with tabs: tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL") with tab_sql: @@ -427,7 +431,7 @@ def build_single_model_ui(): with gr.Column() as doc2vec: gr.Markdown("向知识库中添加文件") with gr.Tab("上传文件"): - files = gr.File(label="添加文件", + files = gr.File(label="添加文件", file_types=[".txt", ".md", ".docx", ".pdf"], file_count="multiple", show_label=False @@ -436,11 +440,10 @@ def build_single_model_ui(): load_file_button = gr.Button("上传并加载到知识库") with gr.Tab("上传文件夹"): folder_files = gr.File(label="添加文件", - file_count="directory", - show_label=False) + file_count="directory", + show_label=False) load_folder_button = gr.Button("上传并加载到知识库") - - + with gr.Blocks(): chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550) with gr.Row(): @@ -449,9 +452,9 @@ def build_single_model_ui(): show_label=False, placeholder="Enter text and press ENTER", visible=False, - ).style(container=False) + ).style(container=False) with gr.Column(scale=2, min_width=50): - send_btn = gr.Button(value="发送", visible=False) + send_btn = gr.Button(value="发送", visible=False) with gr.Row(visible=False) as button_row: regenerate_btn = gr.Button(value="重新生成", interactive=False) @@ -465,7 +468,7 @@ def build_single_model_ui(): [state, chatbot] + btn_list, ) clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) - + textbox.submit( add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( @@ -487,10 +490,10 @@ def build_single_model_ui(): def build_webdemo(): with gr.Blocks( - title="数据库智能助手", - # theme=gr.themes.Base(), - theme=gr.themes.Default(), - css=block_css, + title="数据库智能助手", + # theme=gr.themes.Base(), + theme=gr.themes.Default(), + css=block_css, ) as demo: url_params = gr.JSON(visible=False) ( @@ -520,6 +523,7 @@ def build_webdemo(): raise ValueError(f"Unknown model list mode: {args.model_list_mode}") return demo + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="0.0.0.0") @@ -553,7 +557,7 @@ if __name__ == "__main__": for command_category in command_categories: command_registry.import_commands(command_category) - cfg.command_registry =command_registry + cfg.command_registry = command_registry logger.info(args) demo = build_webdemo() @@ -561,4 +565,4 @@ if __name__ == "__main__": concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False ).launch( server_name=args.host, server_port=args.port, share=args.share, max_threads=200, - ) \ No newline at end of file + )