Files
DB-GPT/pilot/server/webserver.py
yhjun1026 dd5fc529e2 Merge branch 'main' into dev
# Conflicts:
#	pilot/connections/mysql.py
#	pilot/prompts/prompt_generator.py
#	pilot/scene/base_chat.py
#	pilot/scene/chat_db/chat.py
#	pilot/scene/chat_db/out_parser.py
#	pilot/scene/chat_execution/chat.py
#	pilot/server/webserver.py
2023-05-30 10:57:57 +08:00

786 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import datetime
import json
import os
import shutil
import sys
import time
import uuid
from urllib.parse import urljoin
import gradio as gr
import requests
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from pilot.commands.command_mange import CommandRegistry
from pilot.scene.base_chat import BaseChat
from pilot.configs.config import Config
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.conversation import (
SeparatorStyle,
conv_qa_prompt_template,
conv_templates,
conversation_sql_mode,
conversation_types,
default_conversation,
)
from pilot.common.plugins import scan_plugins
from pilot.prompts.generator import PluginPromptGenerator
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.utils import build_logger, server_error_msg
from pilot.vector_store.extract_tovec import (
get_vector_storelist,
knownledge_tovec_st,
load_knownledge_from_doc,
)
from pilot.commands.command import execute_ai_response_json
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.language.translation_handler import get_lang_text
# 加载插件
CFG = Config()
logger = build_logger("webserver", LOGDIR + "webserver.log")
headers = {"User-Agent": "dbgpt Client"}
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=True)
enable_moderation = False
models = []
dbs = []
vs_list = [get_lang_text("create_knowledge_base")] + get_vector_storelist()
autogpt = False
vector_store_client = None
vector_store_name = {"vs_name": ""}
priority = {"vicuna-13b": "aaa"}
CHAT_FACTORY = ChatFactory()
DB_SETTINGS = {
"user": CFG.LOCAL_DB_USER,
"password": CFG.LOCAL_DB_PASSWORD,
"host": CFG.LOCAL_DB_HOST,
"port": CFG.LOCAL_DB_PORT,
}
llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
default_knowledge_base_dialogue = get_lang_text(
"knowledge_qa_type_default_knowledge_base_dialogue"
)
add_knowledge_base_dialogue = get_lang_text(
"knowledge_qa_type_add_knowledge_base_dialogue"
)
knowledge_qa_type_list = [
llm_native_dialogue,
default_knowledge_base_dialogue,
add_knowledge_base_dialogue,
]
def get_simlar(q):
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
docs = docsearch.similarity_search_with_score(q, k=1)
contents = [dc.page_content for dc, _ in docs]
return "\n".join(contents)
def gen_sqlgen_conversation(dbname):
message = ""
db_connect = CFG.local_db.get_session(dbname)
schemas = CFG.local_db.table_simple_info(db_connect)
for s in schemas:
message += s["schema_info"] + ";"
return get_lang_text("sql_schema_info").format(dbname, message)
def plugins_select_info():
plugins_infos: dict = {}
for plugin in CFG.plugins:
plugins_infos.update({f"{plugin._name}】=>{plugin._description}": plugin._name})
return plugins_infos
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log(url_params);
gradioURL = window.location.href
if (!gradioURL.endsWith('?__theme=dark')) {
window.location.replace(gradioURL + '?__theme=dark');
}
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
# dbs = get_database_list()
dropdown_update = gr.Dropdown.update(visible=True)
if dbs:
gr.Dropdown.update(choices=dbs)
state = default_conversation.copy()
unique_id = uuid.uuid1()
state.conv_id = str(unique_id)
return (
state,
dropdown_update,
gr.Chatbot.update(visible=True),
gr.Textbox.update(visible=True),
gr.Button.update(visible=True),
gr.Row.update(visible=True),
gr.Accordion.update(visible=True),
)
def get_conv_log_filename():
t = datetime.datetime.now()
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
return name
def regenerate(state, request: gr.Request):
logger.info(f"regenerate. ip: {request.client.host}")
state.messages[-1][-1] = None
state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = None
return (state, [], "") + (disable_btn,) * 5
def add_text(state, text, request: gr.Request):
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
if len(text) <= 0:
state.skip_next = True
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
""" Default support 4000 tokens, if tokens too lang, we will cut off """
text = text[:4000]
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
state.skip_next = False
### TODO
state.last_user_input = text
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def post_process_code(code):
sep = "\n```"
if sep in code:
blocks = code.split(sep)
if len(blocks) % 2 == 1:
for i in range(1, len(blocks), 2):
blocks[i] = blocks[i].replace("\\_", "_")
code = sep.join(blocks)
return code
def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene:
if "插件模式" == selected:
return ChatScene.ChatExecution
elif "知识问答" == selected:
if mode == conversation_types["default_knownledge"]:
return ChatScene.ChatKnowledge
elif mode == conversation_types["custome"]:
return ChatScene.ChatNewKnowledge
else:
if sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
return ChatScene.ChatWithDb
return ChatScene.ChatNormal
def http_bot(
state, selected, plugin_selector, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
):
logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}")
start_tstamp = time.time()
scene:ChatScene = get_chat_mode(mode, sql_mode, db_selector)
print(f"当前对话模式:{scene.value}")
model_name = CFG.LLM_MODEL
if ChatScene.ChatWithDb == scene:
logger.info("基于DB对话走新的模式")
chat_param = {
"chat_session_id": state.conv_id,
"db_name": db_selector,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat.call()
state.messages[-1][-1] = f"{chat.current_ai_response()}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
elif ChatScene.ChatExecution == scene:
logger.info("插件模式对话走新的模式!")
chat_param = {
"chat_session_id": state.conv_id,
"plugin_selector": plugin_selector,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat.call()
state.messages[-1][-1] = f"{chat.current_ai_response()}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
else:
dbname = db_selector
# TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化
if state.skip_next:
# This generate call is skipped due to invalid inputs
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
if len(state.messages) == state.offset + 2:
query = state.messages[-2][1]
template_name = "conv_one_shot"
new_state = conv_templates[template_name].copy()
# prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文?
# 如果用户侧的问题跨度很大, 应该每一轮都加提示。
if db_selector:
new_state.append_message(
new_state.roles[0], gen_sqlgen_conversation(dbname) + query
)
new_state.append_message(new_state.roles[1], None)
else:
new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
new_state.conv_id = uuid.uuid4().hex
state = new_state
prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["default_knownledge"] and not db_selector:
query = state.messages[-2][1]
knqa = KnownLedgeBaseQA()
state.messages[-2][1] = knqa.get_similar_answer(query)
prompt = state.get_prompt()
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["custome"] and not db_selector:
persist_dir = os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb"
)
print("向量数据库持久化地址: ", persist_dir)
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["sentence-transforms"],
vector_store_config={
"vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
},
)
print("vector store name: ", vector_store_name["vs_name"])
vector_store_config = {
"vector_store_name": vector_store_name["vs_name"],
"text_field": "content",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
# Make requests
payload = {
"model": model_name,
"prompt": prompt,
"temperature": float(temperature),
"max_new_tokens": int(max_new_tokens),
"stop": state.sep
if state.sep_style == SeparatorStyle.SINGLE
else state.sep2,
}
logger.info(f"Requert: \n{payload}")
# 流式输出
state.messages[-1][-1] = ""
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
try:
# Stream output
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=20,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"][skip_echo_len:].strip()
output = post_process_code(output)
state.messages[-1][-1] = output + ""
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
else:
output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
try:
# Stream output
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=20,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
if data["error_code"] == 0:
if "vicuna" in CFG.LLM_MODEL:
output = data["text"][skip_echo_len:].strip()
else:
output = data["text"].strip()
output = post_process_code(output)
state.messages[-1][-1] = output + ""
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
) * 5
else:
output = (
data["text"] + f" (error_code: {data['error_code']})"
)
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
# 记录运行日志
finish_tstamp = time.time()
logger.info(f"{output}")
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"start": round(start_tstamp, 4),
"finish": round(start_tstamp, 4),
"state": state.dict(),
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
block_css = (
code_highlight_css
+ """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}
#notice_markdown th {
display: none;
}
"""
)
def change_sql_mode(sql_mode):
if sql_mode in [get_lang_text("sql_generate_mode_direct")]:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def change_mode(mode):
if mode in [default_knowledge_base_dialogue, llm_native_dialogue]:
return gr.update(visible=False)
else:
return gr.update(visible=True)
def change_tab():
autogpt = True
def build_single_model_ui():
notice_markdown = get_lang_text("db_gpt_introduction")
learn_more_markdown = get_lang_text("learn_more_markdown")
state = gr.State()
gr.Markdown(notice_markdown, elem_id="notice_markdown")
with gr.Accordion(
get_lang_text("model_control_param"), open=False, visible=False
) as parameter_row:
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.1,
interactive=True,
label="Temperature",
)
max_output_tokens = gr.Slider(
minimum=0,
maximum=1024,
value=512,
step=64,
interactive=True,
label=get_lang_text("max_input_token_size"),
)
tabs = gr.Tabs()
def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
return evt.value
selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
tabs.select(on_select, None, selected)
with tabs:
tab_sql = gr.TabItem(get_lang_text("sql_generate_diagnostics"), elem_id="SQL")
with tab_sql:
# TODO A selector to choose database
with gr.Row(elem_id="db_selector"):
db_selector = gr.Dropdown(
label=get_lang_text("please_choose_database"),
choices=dbs,
value=dbs[0] if len(models) > 0 else "",
interactive=True,
show_label=True,
).style(container=False)
sql_mode = gr.Radio(
[
get_lang_text("sql_generate_mode_direct"),
get_lang_text("sql_generate_mode_none"),
],
show_label=False,
value=get_lang_text("sql_generate_mode_none"),
)
sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
# tab_plugin.select(change_func)
with tab_plugin:
print("tab_plugin in...")
with gr.Row(elem_id="plugin_selector"):
# TODO
plugin_selector = gr.Dropdown(
label="请选择插件",
choices=list(plugins_select_info().keys()),
value="",
interactive=True,
show_label=True,
type="value"
).style(container=False)
def plugin_change(evt: gr.SelectData): # SelectData is a subclass of EventData
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
return plugins_select_info().get(evt.value)
plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
plugin_selector.select(plugin_change, None, plugin_selected)
tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
with tab_qa:
mode = gr.Radio(
[
llm_native_dialogue,
default_knowledge_base_dialogue,
add_knowledge_base_dialogue,
],
show_label=False,
value=llm_native_dialogue,
)
vs_setting = gr.Accordion(
get_lang_text("configure_knowledge_base"), open=False
)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
with vs_setting:
vs_name = gr.Textbox(
label=get_lang_text("new_klg_name"), lines=1, interactive=True
)
vs_add = gr.Button(get_lang_text("add_as_new_klg"))
with gr.Column() as doc2vec:
gr.Markdown(get_lang_text("add_file_to_klg"))
with gr.Tab(get_lang_text("upload_file")):
files = gr.File(
label=get_lang_text("add_file"),
file_types=[".txt", ".md", ".docx", ".pdf"],
file_count="multiple",
allow_flagged_uploads=True,
show_label=False,
)
load_file_button = gr.Button(
get_lang_text("upload_and_load_to_klg")
)
with gr.Tab(get_lang_text("upload_folder")):
folder_files = gr.File(
label=get_lang_text("add_folder"),
accept_multiple_files=True,
file_count="directory",
show_label=False,
)
load_folder_button = gr.Button(
get_lang_text("upload_and_load_to_klg")
)
with gr.Blocks():
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
with gr.Row():
with gr.Column(scale=20):
textbox = gr.Textbox(
show_label=False,
placeholder="Enter text and press ENTER",
visible=False,
).style(container=False)
with gr.Column(scale=2, min_width=50):
send_btn = gr.Button(value=get_lang_text("send"), visible=False)
with gr.Row(visible=False) as button_row:
regenerate_btn = gr.Button(value=get_lang_text("regenerate"), interactive=False)
clear_btn = gr.Button(value=get_lang_text("clear_box"), interactive=False)
gr.Markdown(learn_more_markdown)
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
textbox.submit(
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
send_btn.click(
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
vs_add.click(
fn=save_vs_name, show_progress=True, inputs=[vs_name], outputs=[vs_name]
)
load_file_button.click(
fn=knowledge_embedding_store,
show_progress=True,
inputs=[vs_name, files],
outputs=[vs_name],
)
load_folder_button.click(
fn=knowledge_embedding_store,
show_progress=True,
inputs=[vs_name, folder_files],
outputs=[vs_name],
)
return state, chatbot, textbox, send_btn, button_row, parameter_row
def build_webdemo():
with gr.Blocks(
title=get_lang_text("database_smart_assistant"),
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
) as demo:
url_params = gr.JSON(visible=False)
(
state,
chatbot,
textbox,
send_btn,
button_row,
parameter_row,
) = build_single_model_ui()
if args.model_list_mode == "once":
demo.load(
load_demo,
[url_params],
[
state,
chatbot,
textbox,
send_btn,
button_row,
parameter_row,
],
_js=get_window_url_params,
)
else:
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
return demo
def save_vs_name(vs_name):
vector_store_name["vs_name"] = vs_name
return vs_name
def knowledge_embedding_store(vs_id, files):
# vs_path = os.path.join(VS_ROOT_PATH, vs_id)
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id))
for file in files:
filename = os.path.split(file.name)[-1]
shutil.move(
file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
)
knowledge_embedding_client = KnowledgeEmbedding(
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config={
"vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
},
)
knowledge_embedding_client.knowledge_embedding()
logger.info("knowledge embedding success")
return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int)
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument(
"--model-list-mode", type=str, default="once", choices=["once", "reload"]
)
parser.add_argument("--share", default=False, action="store_true")
args = parser.parse_args()
logger.info(f"args: {args}")
# 配置初始化
cfg = Config()
dbs = cfg.local_db.get_database_list()
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
# 加载插件可执行命令
command_categories = [
"pilot.commands.audio_text",
"pilot.commands.image_gen",
]
# 排除禁用命令
command_categories = [
x for x in command_categories if x not in cfg.disabled_command_categories
]
command_registry = CommandRegistry()
for command_category in command_categories:
command_registry.import_commands(command_category)
cfg.command_registry = command_registry
logger.info(args)
demo = build_webdemo()
demo.queue(
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
).launch(
server_name=args.host,
server_port=args.port,
share=args.share,
max_threads=200,
)