DB-GPT/pilot/server/webserver.py (commit fa73a4fae6, "feat: new knowledge chat scene", 2023-06-27)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import threading
import traceback
import argparse
import datetime
import os
import shutil
import signal
import sys
import uuid
import gradio as gr
from pilot.embedding_engine.knowledge_type import KnowledgeType
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from pilot.summary.db_summary_client import DBSummaryClient
from pilot.scene.base_chat import BaseChat
from pilot.configs.config import Config
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
)
from pilot.conversation import (
conversation_sql_mode,
conversation_types,
chat_mode_title,
default_conversation,
)
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.utils import build_logger
from pilot.vector_store.extract_tovec import (
get_vector_storelist,
knownledge_tovec_st,
)
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.language.translation_handler import get_lang_text
from pilot.server.webserver_base import server_init
import uvicorn
from fastapi import BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler
# Load plugins
CFG = Config()
logger = build_logger("webserver", LOGDIR + "webserver.log")
headers = {"User-Agent": "dbgpt Client"}
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
enable_moderation = False
models = []
dbs = []
vs_list = [get_lang_text("create_knowledge_base")] + get_vector_storelist()
autogpt = False
vector_store_client = None
vector_store_name = {"vs_name": ""}
# db_summary = {"dbsummary": ""}
priority = {"vicuna-13b": "aaa"}
CHAT_FACTORY = ChatFactory()
DB_SETTINGS = {
"user": CFG.LOCAL_DB_USER,
"password": CFG.LOCAL_DB_PASSWORD,
"host": CFG.LOCAL_DB_HOST,
"port": CFG.LOCAL_DB_PORT,
}
llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
default_knowledge_base_dialogue = get_lang_text(
"knowledge_qa_type_default_knowledge_base_dialogue"
)
add_knowledge_base_dialogue = get_lang_text(
"knowledge_qa_type_add_knowledge_base_dialogue"
)
url_knowledge_dialogue = get_lang_text("knowledge_qa_type_url_knowledge_dialogue")
knowledge_qa_type_list = [
llm_native_dialogue,
default_knowledge_base_dialogue,
add_knowledge_base_dialogue,
]
def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html(
*args, **kwargs,
swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
)
applications.get_swagger_ui_html = swagger_monkey_patch
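# Note: the patch above only swaps Swagger UI's static assets for a BootCDN
# mirror (presumably so /docs loads reliably for users in mainland China);
# the behavior of the /docs route is otherwise unchanged.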
app = FastAPI()
origins = ["*"]
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
# app.mount("static", StaticFiles(directory="static"), name="static")
app.include_router(api_v1)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
def get_simlar(q):
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
docs = docsearch.similarity_search_with_score(q, k=1)
contents = [dc.page_content for dc, _ in docs]
return "\n".join(contents)
def gen_sqlgen_conversation(dbname):
message = ""
db_connect = CFG.local_db.get_session(dbname)
schemas = CFG.local_db.table_simple_info(db_connect)
for s in schemas:
message += s + ";"
return get_lang_text("sql_schema_info").format(dbname, message)
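# Illustrative output (exact wording comes from the "sql_schema_info" template;
# the table list is whatever table_simple_info returns for the database):
#   "mydb: users(id,name);orders(id,user_id,amount);"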
def plugins_select_info():
plugins_infos: dict = {}
for plugin in CFG.plugins:
        plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
return plugins_infos
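# Builds the display-label -> plugin-name mapping shown in the plugin dropdown,
# e.g. {"【weather】=>query weather info": "weather"} (example values are
# illustrative, not actual plugins).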
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log(url_params);
gradioURL = window.location.href
if (!gradioURL.endsWith('?__theme=dark')) {
window.location.replace(gradioURL + '?__theme=dark');
}
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
# dbs = get_database_list()
    dropdown_update = gr.Dropdown.update(visible=True)
    if dbs:
        dropdown_update = gr.Dropdown.update(choices=dbs, visible=True)
state = default_conversation.copy()
unique_id = uuid.uuid1()
state.conv_id = str(unique_id)
return (
state,
dropdown_update,
gr.Chatbot.update(visible=True),
gr.Textbox.update(visible=True),
gr.Button.update(visible=True),
gr.Row.update(visible=True),
gr.Accordion.update(visible=True),
)
def get_conv_log_filename():
t = datetime.datetime.now()
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
return name
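# e.g. LOGDIR/2023-06-27-conv.json (one conversation log file per day).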
def regenerate(state, request: gr.Request):
logger.info(f"regenerate. ip: {request.client.host}")
state.messages[-1][-1] = None
state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = None
return (state, [], "") + (disable_btn,) * 5
def add_text(state, text, request: gr.Request):
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
if len(text) <= 0:
state.skip_next = True
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
""" Default support 4000 tokens, if tokens too lang, we will cut off """
text = text[:4000]
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
state.skip_next = False
### TODO
state.last_user_input = text
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def post_process_code(code):
sep = "\n```"
if sep in code:
blocks = code.split(sep)
if len(blocks) % 2 == 1:
for i in range(1, len(blocks), 2):
blocks[i] = blocks[i].replace("\\_", "_")
code = sep.join(blocks)
return code
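# Illustrative example: models sometimes escape underscores inside fenced code
# blocks; when the "\n```" split yields an odd number of parts, the escapes
# inside the fenced segments are undone, e.g.
#   post_process_code("a\n```\nmy\\_var = 1\n```")
# returns "a\n```\nmy_var = 1\n```".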
def get_chat_mode(selected, param=None) -> ChatScene:
if chat_mode_title["chat_use_plugin"] == selected:
return ChatScene.ChatExecution
elif chat_mode_title["sql_generate_diagnostics"] == selected:
sql_mode = param
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
return ChatScene.ChatWithDbExecute
else:
return ChatScene.ChatWithDbQA
else:
mode = param
if mode == conversation_types["default_knownledge"]:
return ChatScene.ChatDefaultKnowledge
elif mode == conversation_types["custome"]:
return ChatScene.ChatNewKnowledge
elif mode == conversation_types["url"]:
return ChatScene.ChatUrlKnowledge
else:
return ChatScene.ChatNormal
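# Selection -> scene, as implemented above:
#   chat_use_plugin tab                            -> ChatExecution
#   sql_generate_diagnostics + auto-execute mode   -> ChatWithDbExecute
#   sql_generate_diagnostics otherwise             -> ChatWithDbQA
#   QA tab: default KB / custom KB / URL / else    -> ChatDefaultKnowledge /
#           ChatNewKnowledge / ChatUrlKnowledge / ChatNormal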
def chatbot_callback(state, message):
print(f"chatbot_callback:{message}")
state.messages[-1][-1] = f"{message}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
def http_bot(
state,
selected,
temperature,
max_new_tokens,
plugin_selector,
mode,
sql_mode,
db_selector,
url_input,
knowledge_name,
):
    logger.info(
        f"User message sent! {state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}"
    )
if chat_mode_title["sql_generate_diagnostics"] == selected:
scene: ChatScene = get_chat_mode(selected, sql_mode)
elif chat_mode_title["chat_use_plugin"] == selected:
scene: ChatScene = get_chat_mode(selected)
else:
scene: ChatScene = get_chat_mode(selected, mode)
print(f"chat scene:{scene.value}")
if ChatScene.ChatWithDbExecute == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"db_name": db_selector,
"user_input": state.last_user_input,
}
elif ChatScene.ChatWithDbQA == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"db_name": db_selector,
"user_input": state.last_user_input,
}
elif ChatScene.ChatExecution == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"plugin_selector": plugin_selector,
"user_input": state.last_user_input,
}
elif ChatScene.ChatNormal == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
}
elif ChatScene.ChatDefaultKnowledge == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
}
elif ChatScene.ChatNewKnowledge == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
"knowledge_name": knowledge_name,
}
elif ChatScene.ChatUrlKnowledge == scene:
chat_param = {
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
"url": url_input,
}
    else:
        state.messages[-1][-1] = f"ERROR: Unsupported chat scene: {scene}"
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
        return
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
if not chat.prompt_template.stream_out:
logger.info("not stream out, wait model response!")
state.messages[-1][-1] = chat.nostream_call()
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
else:
logger.info("stream out start!")
try:
response = chat.stream_call()
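            # The worker streams NUL-delimited (b"\0") chunks; each parsed chunk
            # appears to hold the message accumulated so far, so the UI simply
            # overwrites state.messages[-1][-1] on every yield.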
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
                    msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                        chunk, chat.skip_echo_len
                    )
                    state.messages[-1][-1] = msg
chat.current_message.add_ai_message(msg)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
chat.memory.append(chat.current_message)
except Exception as e:
print(traceback.format_exc())
state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
block_css = (
code_highlight_css
+ """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}
#notice_markdown th {
display: none;
}
"""
)
def change_sql_mode(sql_mode):
if sql_mode in [get_lang_text("sql_generate_mode_direct")]:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def change_mode(mode):
if mode in [add_knowledge_base_dialogue]:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def build_single_model_ui():
notice_markdown = get_lang_text("db_gpt_introduction")
learn_more_markdown = get_lang_text("learn_more_markdown")
state = gr.State()
gr.Markdown(notice_markdown, elem_id="notice_markdown")
with gr.Accordion(
get_lang_text("model_control_param"), open=False, visible=False
) as parameter_row:
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.1,
interactive=True,
label="Temperature",
)
max_output_tokens = gr.Slider(
minimum=0,
maximum=1024,
value=512,
step=64,
interactive=True,
label=get_lang_text("max_input_token_size"),
)
tabs = gr.Tabs()
def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
return evt.value
selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
tabs.select(on_select, None, selected)
with tabs:
tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
with tab_qa:
mode = gr.Radio(
[
llm_native_dialogue,
default_knowledge_base_dialogue,
add_knowledge_base_dialogue,
url_knowledge_dialogue,
],
show_label=False,
value=llm_native_dialogue,
)
vs_setting = gr.Accordion(
get_lang_text("configure_knowledge_base"), open=False, visible=False
)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
url_input = gr.Textbox(
label=get_lang_text("url_input_label"),
lines=1,
interactive=True,
visible=False,
)
def show_url_input(evt: gr.SelectData):
if evt.value == url_knowledge_dialogue:
return gr.update(visible=True)
else:
return gr.update(visible=False)
mode.select(fn=show_url_input, inputs=None, outputs=url_input)
with vs_setting:
vs_name = gr.Textbox(
label=get_lang_text("new_klg_name"), lines=1, interactive=True
)
vs_add = gr.Button(get_lang_text("add_as_new_klg"))
with gr.Column() as doc2vec:
gr.Markdown(get_lang_text("add_file_to_klg"))
with gr.Tab(get_lang_text("upload_file")):
files = gr.File(
label=get_lang_text("add_file"),
file_types=[".txt", ".md", ".docx", ".pdf"],
file_count="multiple",
allow_flagged_uploads=True,
show_label=False,
)
load_file_button = gr.Button(
get_lang_text("upload_and_load_to_klg")
)
with gr.Tab(get_lang_text("upload_folder")):
folder_files = gr.File(
label=get_lang_text("add_folder"),
accept_multiple_files=True,
file_count="directory",
show_label=False,
)
load_folder_button = gr.Button(
get_lang_text("upload_and_load_to_klg")
)
tab_sql = gr.TabItem(get_lang_text("sql_generate_diagnostics"), elem_id="SQL")
with tab_sql:
# TODO A selector to choose database
with gr.Row(elem_id="db_selector"):
db_selector = gr.Dropdown(
label=get_lang_text("please_choose_database"),
choices=dbs,
                    value=dbs[0] if len(dbs) > 0 else "",
interactive=True,
show_label=True,
).style(container=False)
# db_selector.change(fn=db_selector_changed, inputs=db_selector)
sql_mode = gr.Radio(
[
get_lang_text("sql_generate_mode_direct"),
get_lang_text("sql_generate_mode_none"),
],
show_label=False,
value=get_lang_text("sql_generate_mode_none"),
)
sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
tab_plugin = gr.TabItem(get_lang_text("chat_use_plugin"), elem_id="PLUGIN")
# tab_plugin.select(change_func)
with tab_plugin:
print("tab_plugin in...")
with gr.Row(elem_id="plugin_selector"):
# TODO
plugin_selector = gr.Dropdown(
label=get_lang_text("select_plugin"),
choices=list(plugins_select_info().keys()),
value="",
interactive=True,
show_label=True,
type="value",
).style(container=False)
def plugin_change(
evt: gr.SelectData,
): # SelectData is a subclass of EventData
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
print(f"user plugin:{plugins_select_info().get(evt.value)}")
return plugins_select_info().get(evt.value)
plugin_selected = gr.Textbox(
show_label=False, visible=False, placeholder="Selected"
)
plugin_selector.select(plugin_change, None, plugin_selected)
with gr.Blocks():
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
with gr.Row():
with gr.Column(scale=20):
textbox = gr.Textbox(
show_label=False,
placeholder="Enter text and press ENTER",
visible=False,
).style(container=False)
with gr.Column(scale=2, min_width=50):
send_btn = gr.Button(value=get_lang_text("send"), visible=False)
with gr.Row(visible=False) as button_row:
regenerate_btn = gr.Button(value=get_lang_text("regenerate"), interactive=False)
clear_btn = gr.Button(value=get_lang_text("clear_box"), interactive=False)
gr.Markdown(learn_more_markdown)
params = [plugin_selected, mode, sql_mode, db_selector, url_input, vs_name]
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, selected, temperature, max_output_tokens] + params,
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
textbox.submit(
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, selected, temperature, max_output_tokens] + params,
[state, chatbot] + btn_list,
)
send_btn.click(
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, selected, temperature, max_output_tokens] + params,
[state, chatbot] + btn_list,
)
vs_add.click(
fn=save_vs_name, show_progress=True, inputs=[vs_name], outputs=[vs_name]
)
load_file_button.click(
fn=knowledge_embedding_store,
show_progress=True,
inputs=[vs_name, files],
outputs=[vs_name],
)
load_folder_button.click(
fn=knowledge_embedding_store,
show_progress=True,
inputs=[vs_name, folder_files],
outputs=[vs_name],
)
return state, chatbot, textbox, send_btn, button_row, parameter_row
def build_webdemo():
with gr.Blocks(
title=get_lang_text("database_smart_assistant"),
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
) as demo:
url_params = gr.JSON(visible=False)
(
state,
chatbot,
textbox,
send_btn,
button_row,
parameter_row,
) = build_single_model_ui()
if args.model_list_mode == "once":
demo.load(
load_demo,
[url_params],
[
state,
chatbot,
textbox,
send_btn,
button_row,
parameter_row,
],
_js=get_window_url_params,
)
else:
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
return demo
def save_vs_name(vs_name):
vector_store_name["vs_name"] = vs_name
return vs_name
def knowledge_embedding_store(vs_id, files):
# vs_path = os.path.join(VS_ROOT_PATH, vs_id)
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id))
    for file in files:
        filename = os.path.split(file.name)[-1]
        shutil.move(
            file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
        )
        knowledge_embedding_client = KnowledgeEmbedding(
            knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
            knowledge_type=KnowledgeType.DOCUMENT.value,
            model_name=LLM_MODEL_CONFIG["text2vec"],
            vector_store_config={
                "vector_store_name": vector_store_name["vs_name"],
                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
            },
        )
        knowledge_embedding_client.knowledge_embedding()
logger.info("knowledge embedding success")
return vs_id
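# Illustrative flow ("my_kb" is a hypothetical name): the user enters "my_kb"
# and clicks vs_add (save_vs_name stores it in vector_store_name), then uploads
# files; each file is moved to KNOWLEDGE_UPLOAD_ROOT_PATH/my_kb/ and embedded
# with the text2vec model into the vector store of that name.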
def async_db_summery():
client = DBSummaryClient()
thread = threading.Thread(target=client.init_db_summary)
thread.start()
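# Builds the database summary on a background thread so that the (potentially
# slow) summarization does not block webserver startup; the summary is
# presumably what the DB QA scenes draw on.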
def signal_handler(sig, frame):
print("in order to avoid chroma db atexit problem")
os._exit(0)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')
# old version server config
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true")
# init server config
    args = parser.parse_args()
    server_init(args)
    signal.signal(signal.SIGINT, signal_handler)
if args.new:
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)
else:
        ### Compatibility mode: start the old-version Gradio server by default
demo = build_webdemo()
demo.queue(
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
).launch(
server_name=args.host,
server_port=args.port,
share=args.share,
max_threads=200,
)
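# Illustrative invocations (flags as defined by the argparse setup above):
#   python pilot/server/webserver.py          # Gradio UI on CFG.WEB_SERVER_PORT
#   python pilot/server/webserver.py --new    # FastAPI via uvicorn on port 5000
#   python pilot/server/webserver.py --share  # Gradio with a public share link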