#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Gradio-based web server for DB-GPT.

Provides a chat UI with several modes: plain LLM chat, DB-aware SQL chat,
default knowledge-base QA, custom (user-uploaded) knowledge-base QA, and
plugin auto-execution. Model generation is streamed from a separate model
server over HTTP.
"""
import argparse
import datetime
import json
import os
import shutil
import sys
import time
import uuid
from urllib.parse import urljoin

import gradio as gr
import requests
from langchain import PromptTemplate

# Make the project root importable before pulling in pilot.* modules.
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

from pilot.commands.command_mange import CommandRegistry
from pilot.scene.base_chat import BaseChat
from pilot.configs.config import Config
from pilot.configs.model_config import (
    DATASETS_DIR,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
    LLM_MODEL_CONFIG,
    LOGDIR,
    VECTOR_SEARCH_TOP_K,
)
from pilot.connections.mysql import MySQLOperator
from pilot.conversation import (
    SeparatorStyle,
    conv_qa_prompt_template,
    conv_templates,
    conversation_sql_mode,
    conversation_types,
    default_conversation,
)
from pilot.plugins import scan_plugins
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.utils import build_logger, server_error_msg
from pilot.vector_store.extract_tovec import (
    get_vector_storelist,
    knownledge_tovec_st,
    load_knownledge_from_doc,
)
from pilot.commands.command import execute_ai_response_json
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory

logger = build_logger("webserver", LOGDIR + "webserver.log")
headers = {"User-Agent": "dbgpt Client"}

# Button-state updates used as `(btn_update,) * 5` in the generator outputs.
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
# FIX: was `interactive=True`, which meant the "disable" update never actually
# disabled the buttons while a response was streaming.
disable_btn = gr.Button.update(interactive=False)

enable_moderation = False
models = []
dbs = []
vs_list = ["新建知识库"] + get_vector_storelist()
autogpt = False
vector_store_client = None
# Mutable holder for the currently-selected custom knowledge-base name
# (shared between the UI callbacks below).
vector_store_name = {"vs_name": ""}
priority = {"vicuna-13b": "aaa"}

# Load plugins / global configuration.
CFG = Config()
CHAT_FACTORY = ChatFactory()

DB_SETTINGS = {
    "user": CFG.LOCAL_DB_USER,
    "password": CFG.LOCAL_DB_PASSWORD,
    "host": CFG.LOCAL_DB_HOST,
    "port": CFG.LOCAL_DB_PORT,
}


def get_simlar(q):
    """Return the page content of the single most similar chunk of plan.md."""
    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
    docs = docsearch.similarity_search_with_score(q, k=1)
    contents = [dc.page_content for dc, _ in docs]
    return "\n".join(contents)


def gen_sqlgen_conversation(dbname):
    """Build a Chinese-language prompt prefix describing *dbname*'s schema."""
    mo = MySQLOperator(**DB_SETTINGS)
    message = ""
    schemas = mo.get_schema(dbname)
    for s in schemas:
        message += s["schema_info"] + ";"
    return f"数据库{dbname}的Schema信息如下: {message}\n"


def get_database_list():
    """Return the list of database names available on the local MySQL server."""
    mo = MySQLOperator(**DB_SETTINGS)
    return mo.get_db_list()


# Client-side JS: read URL params and force the dark Gradio theme.
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    gradioURL = window.location.href
    if (!gradioURL.endsWith('?__theme=dark')) {
        window.location.replace(gradioURL + '?__theme=dark');
    }
    return url_params;
}
"""


def load_demo(url_params, request: gr.Request):
    """Initialize per-session state and reveal the hidden UI components."""
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    # dbs = get_database_list()
    dropdown_update = gr.Dropdown.update(visible=True)
    if dbs:
        gr.Dropdown.update(choices=dbs)

    state = default_conversation.copy()
    unique_id = uuid.uuid1()
    state.conv_id = str(unique_id)

    return (
        state,
        dropdown_update,
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=True),
    )


def get_conv_log_filename():
    """Return today's conversation-log path inside LOGDIR."""
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def regenerate(state, request: gr.Request):
    """Clear the last assistant message so http_bot regenerates it."""
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5


def clear_history(request: gr.Request):
    """Reset the conversation state and the chatbot display."""
    logger.info(f"clear_history. ip: {request.client.host}")
    state = None
    return (state, [], "") + (disable_btn,) * 5


def add_text(state, text, request: gr.Request):
    """Append the user's message to the conversation; skip empty input."""
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5

    # Default support 4000 tokens; if the input is too long, cut it off.
    text = text[:4000]
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    ### TODO
    state.last_user_input = text
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5


def post_process_code(code):
    """Un-escape `\\_` inside fenced code blocks of a model response."""
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
        code = sep.join(blocks)
    return code


def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
    """Map the UI radio selections to a ChatScene enum value."""
    if mode == conversation_types["default_knownledge"] and not db_selector:
        return ChatScene.ChatKnowledge
    elif mode == conversation_types["custome"] and not db_selector:
        return ChatScene.ChatNewKnowledge
    elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
        return ChatScene.ChatWithDb
    elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
        return ChatScene.ChatExecution
    else:
        return ChatScene.ChatNormal


def http_bot(
    state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
):
    """Generate the assistant reply for the current turn.

    Yields `(state, chatbot, *button_updates)` tuples as the response streams
    in from the model server, and appends a record to the conversation log.
    """
    logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
    start_tstamp = time.time()
    scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
    print(f"当前对话模式:{scene.value}")
    model_name = CFG.LLM_MODEL

    if ChatScene.ChatWithDb == scene:
        # DB-aware chat goes through the new BaseChat implementation.
        logger.info("基于DB对话走新的模式!")
        chat_param = {
            "chat_session_id": state.conv_id,
            "db_name": db_selector,
            "user_input": state.last_user_input,
        }
        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
        chat.call()
        state.messages[-1][-1] = f"{chat.current_ai_response()}"
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
    else:
        dbname = db_selector
        # TODO: the request should be combined with the existing knowledge
        # base so answers are grounded in it; the prompt still needs tuning.
        if state.skip_next:
            # This generate call is skipped due to invalid inputs.
            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
            return

        if len(state.messages) == state.offset + 2:
            # First round of conversation: start from the one-shot template.
            query = state.messages[-2][1]
            template_name = "conv_one_shot"
            new_state = conv_templates[template_name].copy()
            # The context hint could be added only in the first round, or in
            # every round; if the user's questions vary widely, every round
            # should carry it.
            if db_selector:
                new_state.append_message(
                    new_state.roles[0], gen_sqlgen_conversation(dbname) + query
                )
                new_state.append_message(new_state.roles[1], None)
            else:
                new_state.append_message(new_state.roles[0], query)
                new_state.append_message(new_state.roles[1], None)
            new_state.conv_id = uuid.uuid4().hex
            state = new_state

        prompt = state.get_prompt()
        skip_echo_len = len(prompt.replace("", " ")) + 1

        if mode == conversation_types["default_knownledge"] and not db_selector:
            # Swap the raw query for a knowledge-base-augmented prompt, build
            # the full prompt, then restore the original query in history.
            query = state.messages[-2][1]
            knqa = KnownLedgeBaseQA()
            state.messages[-2][1] = knqa.get_similar_answer(query)
            prompt = state.get_prompt()
            state.messages[-2][1] = query
            skip_echo_len = len(prompt.replace("", " ")) + 1

        if mode == conversation_types["custome"] and not db_selector:
            persist_dir = os.path.join(
                KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb"
            )
            print("向量数据库持久化地址: ", persist_dir)
            knowledge_embedding_client = KnowledgeEmbedding(
                file_path="",
                model_name=LLM_MODEL_CONFIG["sentence-transforms"],
                local_persist=False,
                vector_store_config={
                    "vector_store_name": vector_store_name["vs_name"],
                    "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
                },
            )
            query = state.messages[-2][1]
            docs = knowledge_embedding_client.similar_search(query, 1)
            context = [d.page_content for d in docs]
            prompt_template = PromptTemplate(
                template=conv_qa_prompt_template,
                input_variables=["context", "question"],
            )
            result = prompt_template.format(context="\n".join(context), question=query)
            state.messages[-2][1] = result
            prompt = state.get_prompt()
            state.messages[-2][1] = query
            skip_echo_len = len(prompt.replace("", " ")) + 1

        # Build the generation request payload.
        payload = {
            "model": model_name,
            "prompt": prompt,
            "temperature": float(temperature),
            "max_new_tokens": int(max_new_tokens),
            "stop": state.sep
            if state.sep_style == SeparatorStyle.SINGLE
            else state.sep2,
        }
        logger.info(f"Request: \n{payload}")

        # Streaming output: show a cursor while generating.
        state.messages[-1][-1] = "▌"
        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

        try:
            # Stream output from the model server, chunked on NUL bytes.
            response = requests.post(
                urljoin(CFG.MODEL_SERVER, "generate_stream"),
                headers=headers,
                json=payload,
                stream=True,
                timeout=20,
            )
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    data = json.loads(chunk.decode())
                    if data["error_code"] == 0:
                        output = data["text"][skip_echo_len:].strip()
                        output = post_process_code(output)
                        state.messages[-1][-1] = output + "▌"
                        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                    else:
                        output = data["text"] + f" (error_code: {data['error_code']})"
                        state.messages[-1][-1] = output
                        yield (state, state.to_gradio_chatbot()) + (
                            disable_btn,
                            disable_btn,
                            disable_btn,
                            enable_btn,
                            enable_btn,
                        )
                        return
        except requests.exceptions.RequestException:
            state.messages[-1][-1] = server_error_msg + " (error_code: 4)"
            yield (state, state.to_gradio_chatbot()) + (
                disable_btn,
                disable_btn,
                disable_btn,
                enable_btn,
                enable_btn,
            )
            return

        # Strip the trailing cursor character from the final message.
        state.messages[-1][-1] = state.messages[-1][-1][:-1]
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

        # Record the finished turn in the conversation log.
        finish_tstamp = time.time()
        logger.info(f"{output}")

        with open(get_conv_log_filename(), "a") as fout:
            data = {
                "tstamp": round(finish_tstamp, 4),
                "type": "chat",
                "model": model_name,
                "start": round(start_tstamp, 4),
                # FIX: previously logged start_tstamp as the finish time.
                "finish": round(finish_tstamp, 4),
                "state": state.dict(),
                "ip": request.client.host,
            }
            fout.write(json.dumps(data) + "\n")


block_css = (
    code_highlight_css
    + """
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;       /* Internet Explorer 5.5+ */
}
#notice_markdown th {
    display: none;
}
"""
)


def change_sql_mode(sql_mode):
    """Show the auto-execute explanation only in direct-execution mode."""
    if sql_mode in ["直接执行结果"]:
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)


def change_mode(mode):
    """Show the knowledge-base settings only for the custom-KB mode."""
    if mode in ["默认知识库对话", "LLM原生对话"]:
        return gr.update(visible=False)
    else:
        return gr.update(visible=True)


def change_tab():
    """Flag that the auto-GPT tab is active."""
    # FIX: without the global declaration this only set a local variable and
    # the module-level flag never changed.
    global autogpt
    autogpt = True


def build_single_model_ui():
    """Build the main chat UI and wire up all callbacks.

    Returns the components that load_demo needs to reveal on page load.
    """
    notice_markdown = """
    # DB-GPT

    [DB-GPT](https://github.com/csunny/DB-GPT) 是一个开源的以数据库为基础的GPT实验项目,使用本地化的GPT大模型与您的数据和环境进行交互,无数据泄露风险,100% 私密,100% 安全。
    """
    learn_more_markdown = """
    ### Licence
    The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B
    """

    state = gr.State()
    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Accordion("参数", open=False, visible=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        max_output_tokens = gr.Slider(
            minimum=0,
            maximum=1024,
            value=512,
            step=64,
            interactive=True,
            label="最大输出Token数",
        )

    tabs = gr.Tabs()
    with tabs:
        tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
        with tab_sql:
            # TODO A selector to choose database
            with gr.Row(elem_id="db_selector"):
                db_selector = gr.Dropdown(
                    label="请选择数据库",
                    choices=dbs,
                    # FIX: previously guarded the dbs[0] index with
                    # `len(models) > 0`, which could raise IndexError when
                    # models was non-empty but dbs was empty.
                    value=dbs[0] if len(dbs) > 0 else "",
                    interactive=True,
                    show_label=True,
                ).style(container=False)

            sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
            sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
            sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)

        tab_qa = gr.TabItem("知识问答", elem_id="QA")
        with tab_qa:
            mode = gr.Radio(
                ["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话"
            )
            vs_setting = gr.Accordion("配置知识库", open=False)
            mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
            with vs_setting:
                vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
                vs_add = gr.Button("添加为新知识库")
                with gr.Column() as doc2vec:
                    gr.Markdown("向知识库中添加文件")
                    with gr.Tab("上传文件"):
                        files = gr.File(
                            label="添加文件",
                            file_types=[".txt", ".md", ".docx", ".pdf"],
                            file_count="multiple",
                            allow_flagged_uploads=True,
                            show_label=False,
                        )
                        load_file_button = gr.Button("上传并加载到知识库")
                    with gr.Tab("上传文件夹"):
                        folder_files = gr.File(
                            label="添加文件夹",
                            accept_multiple_files=True,
                            file_count="directory",
                            show_label=False,
                        )
                        load_folder_button = gr.Button("上传并加载到知识库")

    with gr.Blocks():
        chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
        with gr.Row():
            with gr.Column(scale=20):
                textbox = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press ENTER",
                    visible=False,
                ).style(container=False)
            with gr.Column(scale=2, min_width=50):
                send_btn = gr.Button(value="发送", visible=False)

    with gr.Row(visible=False) as button_row:
        regenerate_btn = gr.Button(value="重新生成", interactive=False)
        clear_btn = gr.Button(value="清理", interactive=False)

    gr.Markdown(learn_more_markdown)

    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)

    textbox.submit(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    send_btn.click(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    vs_add.click(
        fn=save_vs_name, show_progress=True, inputs=[vs_name], outputs=[vs_name]
    )
    load_file_button.click(
        fn=knowledge_embedding_store,
        show_progress=True,
        inputs=[vs_name, files],
        outputs=[vs_name],
    )
    load_folder_button.click(
        fn=knowledge_embedding_store,
        show_progress=True,
        inputs=[vs_name, folder_files],
        outputs=[vs_name],
    )

    return state, chatbot, textbox, send_btn, button_row, parameter_row


def build_webdemo():
    """Assemble the full Gradio Blocks app and register the load hook."""
    with gr.Blocks(
        title="数据库智能助手",
        # theme=gr.themes.Base(),
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:
        url_params = gr.JSON(visible=False)
        (
            state,
            chatbot,
            textbox,
            send_btn,
            button_row,
            parameter_row,
        ) = build_single_model_ui()

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [
                    state,
                    chatbot,
                    textbox,
                    send_btn,
                    button_row,
                    parameter_row,
                ],
                _js=get_window_url_params,
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
    return demo


def save_vs_name(vs_name):
    """Remember the knowledge-base name chosen in the UI."""
    vector_store_name["vs_name"] = vs_name
    return vs_name


def knowledge_embedding_store(vs_id, files):
    """Move uploaded files into the knowledge-base folder and embed them.

    Returns the persistence path of the resulting vector store.
    """
    # vs_path = os.path.join(VS_ROOT_PATH, vs_id)
    if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
        os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id))
    for file in files:
        filename = os.path.split(file.name)[-1]
        shutil.move(
            file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
        )
        # FIX: embedding previously happened after the loop, so only the
        # LAST uploaded file was ever embedded; embed each file as it lands.
        knowledge_embedding_client = KnowledgeEmbedding(
            file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
            model_name=LLM_MODEL_CONFIG["text2vec"],
            local_persist=False,
            vector_store_config={
                "vector_store_name": vector_store_name["vs_name"],
                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
            },
        )
        knowledge_embedding_client.knowledge_embedding()

    logger.info("knowledge embedding success")
    return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument(
        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
    )
    parser.add_argument("--share", default=False, action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Configuration initialization.
    cfg = Config()

    dbs = cfg.local_db.get_database_list()

    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

    # Commands made executable by loaded plugins.
    command_categories = [
        "pilot.commands.audio_text",
        "pilot.commands.image_gen",
    ]
    # Exclude disabled command categories.
    command_categories = [
        x for x in command_categories if x not in cfg.disabled_command_categories
    ]
    command_registry = CommandRegistry()
    for command_category in command_categories:
        command_registry.import_commands(command_category)

    cfg.command_registry = command_registry

    logger.info(args)
    demo = build_webdemo()
    demo.queue(
        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=200,
    )