diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 5a29dfea1..c6ba341c4 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -18,9 +18,10 @@ import requests
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
-from pilot.commands.command import execute_ai_response_json
 from pilot.commands.command_mange import CommandRegistry
-from pilot.commands.exception_not_commands import NotCommands
+
+from pilot.scene.base_chat import BaseChat
+
 from pilot.configs.config import Config
 from pilot.configs.model_config import (
     DATASETS_DIR,
@@ -29,7 +30,6 @@ from pilot.configs.model_config import (
     LOGDIR,
     VECTOR_SEARCH_TOP_K,
 )
-from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.connections.mysql import MySQLOperator
 from pilot.conversation import (
     SeparatorStyle,
@@ -41,15 +41,23 @@ from pilot.conversation import (
 )
 from pilot.plugins import scan_plugins
 from pilot.prompts.auto_mode_prompt import AutoModePrompt
+from pilot.prompts.generator import PromptGenerator
 from pilot.server.gradio_css import code_highlight_css
 from pilot.server.gradio_patch import Chatbot as grChatbot
+from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 from pilot.utils import build_logger, server_error_msg
 from pilot.vector_store.extract_tovec import (
     get_vector_storelist,
     knownledge_tovec_st,
+    load_knownledge_from_doc,
 )
+from pilot.commands.command import execute_ai_response_json
+from pilot.commands.exception_not_commands import NotCommands
+from pilot.scene.base import ChatScene
+from pilot.scene.chat_factory import ChatFactory
+
 logger = build_logger("webserver", LOGDIR + "webserver.log")
 headers = {"User-Agent": "dbgpt Client"}
@@ -69,6 +76,7 @@ priority = {"vicuna-13b": "aaa"}
 # Load plugins
 CFG = Config()
+CHAT_FACTORY = ChatFactory()
 
 DB_SETTINGS = {
     "user": CFG.LOCAL_DB_USER,
@@ -125,6 +133,10 @@ def load_demo(url_params, request: gr.Request):
         gr.Dropdown.update(choices=dbs)
 
     state = default_conversation.copy()
+
+    unique_id = uuid.uuid1()
+    state.conv_id = str(unique_id)
+
     return (
         state,
         dropdown_update,
@@ -166,6 +178,8 @@ def add_text(state, text, request: gr.Request):
     state.append_message(state.roles[0], text)
     state.append_message(state.roles[1], None)
     state.skip_next = False
+    ### TODO
+    state.last_user_input = text
     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
@@ -180,255 +194,271 @@ def post_process_code(code):
     return code
 
+def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
+    if mode == conversation_types["default_knownledge"] and not db_selector:
+        return ChatScene.ChatKnowledge
+    elif mode == conversation_types["custome"] and not db_selector:
+        return ChatScene.ChatNewKnowledge
+    elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
+        return ChatScene.ChatWithDb
+
+    elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
+        return ChatScene.ChatExecution
+    else:
+        return ChatScene.ChatNormal
+
+
 def http_bot(
     state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
 ):
-    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-        print("AUTO DB-GPT mode.")
-    if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
-        print("Standard DB-GPT mode.")
-    print("Whether this is AUTO-GPT mode:", autogpt)
-
+    logger.info(f"User message sent! {state.conv_id},{sql_mode},{db_selector}")
     start_tstamp = time.time()
+    scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
print(f"当前对话模式:{scene.value}") model_name = CFG.LLM_MODEL - dbname = db_selector - # TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化 - if state.skip_next: - # This generate call is skipped due to invalid inputs - yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 - return - - cfg = Config() - auto_prompt = AutoModePrompt() - auto_prompt.command_registry = cfg.command_registry - - # TODO when tab mode is AUTO_GPT, Prompt need to rebuild. - if len(state.messages) == state.offset + 2: - query = state.messages[-2][1] - # 第一轮对话需要加入提示Prompt - if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: - # autogpt模式的第一轮对话需要 构建专属prompt - system_prompt = auto_prompt.construct_first_prompt( - fisrt_message=[query], db_schemes=gen_sqlgen_conversation(dbname) - ) - logger.info("[TEST]:" + system_prompt) - template_name = "auto_dbgpt_one_shot" - new_state = conv_templates[template_name].copy() - new_state.append_message(role="USER", message=system_prompt) - # new_state.append_message(new_state.roles[0], query) - new_state.append_message(new_state.roles[1], None) - else: - template_name = "conv_one_shot" - new_state = conv_templates[template_name].copy() - # prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文? - # 如果用户侧的问题跨度很大, 应该每一轮都加提示。 - if db_selector: - new_state.append_message( - new_state.roles[0], gen_sqlgen_conversation(dbname) + query - ) - new_state.append_message(new_state.roles[1], None) - else: - new_state.append_message(new_state.roles[0], query) - new_state.append_message(new_state.roles[1], None) - - new_state.conv_id = uuid.uuid4().hex - state = new_state - else: - ### 后续对话 - query = state.messages[-2][1] - # 第一轮对话需要加入提示Prompt - if mode == conversation_types["custome"]: - template_name = "conv_one_shot" - new_state = conv_templates[template_name].copy() - # prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文? 
-            # If the user's questions vary widely, it should be added every round.
-            if db_selector:
-                new_state.append_message(
-                    new_state.roles[0], gen_sqlgen_conversation(dbname) + query
-                )
-                new_state.append_message(new_state.roles[1], None)
-            else:
-                new_state.append_message(new_state.roles[0], query)
-                new_state.append_message(new_state.roles[1], None)
-            state = new_state
-        elif sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-            ## Get the return value of the last plugin call
-            follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
-            state.messages[0][0] = ""
-            state.messages[0][1] = ""
-            state.messages[-2][1] = follow_up_prompt
-    prompt = state.get_prompt()
-    skip_echo_len = len(prompt.replace("</s>", " ")) + 1
-    if mode == conversation_types["default_knownledge"] and not db_selector:
-        vector_store_config = {
-            "vector_store_name": "default",
-            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
-        }
+    if ChatScene.ChatWithDb == scene:
+        logger.info("DB-based chat goes through the new scene pipeline!")
+        chat_param = {
+            "chat_session_id": state.conv_id,
+            "db_name": db_selector,
+            "user_input": state.last_user_input,
+        }
-        knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
-            model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
-            vector_store_config=vector_store_config,
-        )
-        query = state.messages[-2][1]
-        docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
-        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
-        state.messages[-2][1] = query
-        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
-
-    if mode == conversation_types["custome"] and not db_selector:
-        print("vector store name: ", vector_store_name["vs_name"])
-        vector_store_config = {
-            "vector_store_name": vector_store_name["vs_name"],
-            "text_field": "content",
-            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
-        }
-        knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
-            model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
-            vector_store_config=vector_store_config,
-        )
-        query = state.messages[-2][1]
-        docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
-        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
-
-        state.messages[-2][1] = query
-        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
-
-    # Make requests
-    payload = {
-        "model": model_name,
-        "prompt": prompt,
-        "temperature": float(temperature),
-        "max_new_tokens": int(max_new_tokens),
-        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
-    }
-    logger.info(f"Request: \n{payload}")
-
-    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-        response = requests.post(
-            urljoin(CFG.MODEL_SERVER, "generate"),
-            headers=headers,
-            json=payload,
-            timeout=120,
-        )
-
-        print(response.json())
-        print(str(response))
-        try:
-            text = response.text.strip()
-            text = text.rstrip()
-            respObj = json.loads(text)
-
-            xx = respObj["response"]
-            xx = xx.strip(b"\x00".decode())
-            respObj_ex = json.loads(xx)
-            if respObj_ex["error_code"] == 0:
-                ai_response = None
-                all_text = respObj_ex["text"]
-                ### Parse the returned text and extract the AI reply
-                tmpResp = all_text.split(state.sep)
-                last_index = -1
-                for i in range(len(tmpResp)):
-                    if tmpResp[i].find("ASSISTANT:") != -1:
-                        last_index = i
-                ai_response = tmpResp[last_index]
-                ai_response = ai_response.replace("ASSISTANT:", "")
-                ai_response = ai_response.replace("\n", "")
-                ai_response = ai_response.replace("\_", "_")
-
-                print(ai_response)
-                if ai_response == None:
-                    state.messages[-1][-1] = "ASSISTANT failed to reply correctly; the raw reply was:\n" + all_text
-                    yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-                else:
-                    plugin_resp = execute_ai_response_json(
-                        auto_prompt.prompt_generator, ai_response
-                    )
-                    cfg.set_last_plugin_return(plugin_resp)
-                    print(plugin_resp)
-                    state.messages[-1][-1] = (
-                        "Model inference info:\n" + ai_response + "\n\nDB-GPT execution result:\n" + plugin_resp
-                    )
-                    yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-        except NotCommands as e:
-            print("Command execution: " + e.message)
-            state.messages[-1][-1] = (
-                "Command execution: " + e.message + "\nModel output:\n" + str(ai_response)
-            )
-            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-    else:
-        # Streaming output
-        state.messages[-1][-1] = "▌"
-        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
-
-        try:
-            # Stream output
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                headers=headers,
-                json=payload,
-                stream=True,
-                timeout=20,
-            )
-            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-                if chunk:
-                    data = json.loads(chunk.decode())
-
-                    """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
-                    """
-                    if data["error_code"] == 0:
-                        if "vicuna" in CFG.LLM_MODEL:
-                            output = data["text"][skip_echo_len:].strip()
-                        else:
-                            output = data["text"].strip()
-
-                        output = post_process_code(output)
-                        state.messages[-1][-1] = output + "▌"
-                        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
-                    else:
-                        output = data["text"] + f" (error_code: {data['error_code']})"
-                        state.messages[-1][-1] = output
-                        yield (state, state.to_gradio_chatbot()) + (
-                            disable_btn,
-                            disable_btn,
-                            disable_btn,
-                            enable_btn,
-                            enable_btn,
-                        )
-                        return
-
-        except requests.exceptions.RequestException as e:
-            state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-            yield (state, state.to_gradio_chatbot()) + (
-                disable_btn,
-                disable_btn,
-                disable_btn,
-                enable_btn,
-                enable_btn,
-            )
-            return
-
-        state.messages[-1][-1] = state.messages[-1][-1][:-1]
+        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
+        chat.call()
+        state.messages[-1][-1] = f"{chat.current_ai_response()}"
         yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
-    # Write the run log
-    finish_tstamp = time.time()
-    logger.info(f"{output}")
+    else:
+        dbname = db_selector
+        # TODO: the request here needs to incorporate the existing knowledge base so the model answers from it, so the prompt still needs refinement
+        if state.skip_next:
+            # This generate call is skipped due to invalid inputs
+            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+            return
 
-    with open(get_conv_log_filename(), "a") as fout:
-        data = {
-            "tstamp": round(finish_tstamp, 4),
-            "type": "chat",
-            "model": model_name,
-            "start": round(start_tstamp, 4),
-            "finish": round(start_tstamp, 4),
-            "state": state.dict(),
-            "ip": request.client.host,
-        }
-        fout.write(json.dumps(data) + "\n")
+        cfg = Config()
+        auto_prompt = AutoModePrompt()
+        auto_prompt.command_registry = cfg.command_registry
+
+        # TODO when tab mode is AUTO_GPT, the prompt needs to be rebuilt.
+        if len(state.messages) == state.offset + 2:
+            query = state.messages[-2][1]
+            # The first round of conversation needs the prompt added
+            if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+                # The first round in autogpt mode needs a dedicated prompt
+                system_prompt = auto_prompt.construct_first_prompt(
+                    fisrt_message=[query], db_schemes=gen_sqlgen_conversation(dbname)
+                )
+                logger.info("[TEST]:" + system_prompt)
+                template_name = "auto_dbgpt_one_shot"
+                new_state = conv_templates[template_name].copy()
+                new_state.append_message(role="USER", message=system_prompt)
+                # new_state.append_message(new_state.roles[0], query)
+                new_state.append_message(new_state.roles[1], None)
+            else:
+                template_name = "conv_one_shot"
+                new_state = conv_templates[template_name].copy()
+                # Add a context hint to the prompt so the model answers from existing knowledge. Should the hint go only in the first round, or in every round?
+                # If the user's questions vary widely, it should be added every round.
+                if db_selector:
+                    new_state.append_message(
+                        new_state.roles[0], gen_sqlgen_conversation(dbname) + query
+                    )
+                    new_state.append_message(new_state.roles[1], None)
+                else:
+                    new_state.append_message(new_state.roles[0], query)
+                    new_state.append_message(new_state.roles[1], None)
+
+            new_state.conv_id = uuid.uuid4().hex
+            state = new_state
+        else:
+            ### Follow-up conversation
+            query = state.messages[-2][1]
+            # The first round of conversation needs the prompt added
+            if mode == conversation_types["custome"]:
+                template_name = "conv_one_shot"
+                new_state = conv_templates[template_name].copy()
+                # Add a context hint to the prompt so the model answers from existing knowledge. Should the hint go only in the first round, or in every round?
+                # If the user's questions vary widely, it should be added every round.
+                if db_selector:
+                    new_state.append_message(
+                        new_state.roles[0], gen_sqlgen_conversation(dbname) + query
+                    )
+                    new_state.append_message(new_state.roles[1], None)
+                else:
+                    new_state.append_message(new_state.roles[0], query)
+                    new_state.append_message(new_state.roles[1], None)
+                state = new_state
+            elif sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+                ## Get the return value of the last plugin call
+                follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
+                state.messages[0][0] = ""
+                state.messages[0][1] = ""
+                state.messages[-2][1] = follow_up_prompt
+        prompt = state.get_prompt()
+        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+        if mode == conversation_types["default_knownledge"] and not db_selector:
+            vector_store_config = {
+                "vector_store_name": "default",
+                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+            }
+            knowledge_embedding_client = KnowledgeEmbedding(
+                file_path="",
+                model_name=LLM_MODEL_CONFIG["text2vec"],
+                local_persist=False,
+                vector_store_config=vector_store_config,
+            )
+            query = state.messages[-2][1]
+            docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
+            prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
+            state.messages[-2][1] = query
+            skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+
+        if mode == conversation_types["custome"] and not db_selector:
+            print("vector store name: ", vector_store_name["vs_name"])
+            vector_store_config = {
+                "vector_store_name": vector_store_name["vs_name"],
+                "text_field": "content",
+                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+            }
+            knowledge_embedding_client = KnowledgeEmbedding(
+                file_path="",
+                model_name=LLM_MODEL_CONFIG["text2vec"],
+                local_persist=False,
+                vector_store_config=vector_store_config,
+            )
+            query = state.messages[-2][1]
+            docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
+            prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
+
+            state.messages[-2][1] = query
+            skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+
+        # Make requests
+        payload = {
+            "model": model_name,
"prompt": prompt, + "temperature": float(temperature), + "max_new_tokens": int(max_new_tokens), + "stop": state.sep + if state.sep_style == SeparatorStyle.SINGLE + else state.sep2, + } + logger.info(f"Requert: \n{payload}") + + if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: + response = requests.post( + urljoin(CFG.MODEL_SERVER, "generate"), + headers=headers, + json=payload, + timeout=120, + ) + + print(response.json()) + print(str(response)) + try: + text = response.text.strip() + text = text.rstrip() + respObj = json.loads(text) + + xx = respObj["response"] + xx = xx.strip(b"\x00".decode()) + respObj_ex = json.loads(xx) + if respObj_ex["error_code"] == 0: + ai_response = None + all_text = respObj_ex["text"] + ### 解析返回文本,获取AI回复部分 + tmpResp = all_text.split(state.sep) + last_index = -1 + for i in range(len(tmpResp)): + if tmpResp[i].find("ASSISTANT:") != -1: + last_index = i + ai_response = tmpResp[last_index] + ai_response = ai_response.replace("ASSISTANT:", "") + ai_response = ai_response.replace("\n", "") + ai_response = ai_response.replace("\_", "_") + + print(ai_response) + if ai_response == None: + state.messages[-1][-1] = "ASSISTANT未能正确回复,回复结果为:\n" + all_text + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + else: + plugin_resp = execute_ai_response_json( + auto_prompt.prompt_generator, ai_response + ) + cfg.set_last_plugin_return(plugin_resp) + print(plugin_resp) + state.messages[-1][-1] = ( + "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp + ) + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + except NotCommands as e: + print("命令执行:" + e.message) + state.messages[-1][-1] = ( + "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response) + ) + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + else: + # 流式输出 + state.messages[-1][-1] = "▌" + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + + try: + # Stream output + response = requests.post( + urljoin(CFG.MODEL_SERVER, "generate_stream"), + headers=headers, + json=payload, + stream=True, + timeout=20, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + + """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. 
+ """ + if data["error_code"] == 0: + if "vicuna" in CFG.LLM_MODEL: + output = data["text"][skip_echo_len:].strip() + else: + output = data["text"].strip() + + output = post_process_code(output) + state.messages[-1][-1] = output + "▌" + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f" (error_code: {data['error_code']})" + state.messages[-1][-1] = output + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + + except requests.exceptions.RequestException as e: + state.messages[-1][-1] = server_error_msg + f" (error_code: 4)" + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + + state.messages[-1][-1] = state.messages[-1][-1][:-1] + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + + # 记录运行日志 + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "start": round(start_tstamp, 4), + "finish": round(start_tstamp, 4), + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") block_css = ( @@ -685,7 +715,8 @@ if __name__ == "__main__": # 配置初始化 cfg = Config() - # dbs = get_database_list() + dbs = cfg.local_db.get_database_list() + cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) # 加载插件可执行命令