Merge branch 'main' into dev

# Conflicts:
#	pilot/connections/mysql.py
#	pilot/prompts/prompt_generator.py
#	pilot/scene/base_chat.py
#	pilot/scene/chat_db/chat.py
#	pilot/scene/chat_db/out_parser.py
#	pilot/scene/chat_execution/chat.py
#	pilot/server/webserver.py
Committed by yhjun1026 on 2023-05-30 10:57:57 +08:00
21 changed files with 420 additions and 204 deletions


@@ -5,6 +5,7 @@ from langchain.prompts import PromptTemplate
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
from pilot.conversation import conv_qa_prompt_template
from pilot.logs import logger
from pilot.model.vicuna_llm import VicunaLLM
from pilot.vector_store.file_loader import KnownLedge2Vector
@@ -28,3 +29,27 @@ class KnownLedgeBaseQA:
context = [d.page_content for d in docs]
result = prompt.format(context="\n".join(context), question=query)
return result
@staticmethod
def build_knowledge_prompt(query, docs, state):
    prompt_template = PromptTemplate(
        template=conv_qa_prompt_template, input_variables=["context", "question"]
    )
    context = [d.page_content for d in docs]
    result = prompt_template.format(context="\n".join(context), question=query)
    state.messages[-2][1] = result
    prompt = state.get_prompt()
    if len(prompt) > 4000:
        logger.info("prompt length greater than 4000, rebuild")
        context = context[:2000]
        prompt_template = PromptTemplate(
            template=conv_qa_prompt_template,
            input_variables=["context", "question"],
        )
        result = prompt_template.format(context="\n".join(context), question=query)
        state.messages[-2][1] = result
        prompt = state.get_prompt()
        print("new prompt length:" + str(len(prompt)))
    return prompt
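A caveat on the rebuild branch above: `context = context[:2000]` truncates the *list* of retrieved documents to 2,000 entries, not the joined text to 2,000 characters, so with a small top-k it rarely shrinks anything. A character-budget truncation would look roughly like this sketch (illustrative only, not part of the commit):

# Sketch only: character-budget truncation of retrieved context.
def truncate_context_chars(context, max_chars=2000):
    kept, used = [], 0
    for doc in context:
        if used + len(doc) + 1 > max_chars:
            break
        kept.append(doc)      # keep whole documents while they fit
        used += len(doc) + 1  # +1 accounts for the "\n" separator
    # Fall back to a hard slice if even the first document is too long.
    return kept if kept else ["\n".join(context)[:max_chars]]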


@@ -13,7 +13,7 @@ from urllib.parse import urljoin
import gradio as gr
import requests
from langchain import PromptTemplate
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -56,7 +56,11 @@ from pilot.vector_store.extract_tovec import (
from pilot.commands.command import execute_ai_response_json
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.language.translation_handler import get_lang_text
# Load plugins
CFG = Config()
logger = build_logger("webserver", LOGDIR + "webserver.log")
headers = {"User-Agent": "dbgpt Client"}
@@ -67,15 +71,13 @@ disable_btn = gr.Button.update(interactive=True)
enable_moderation = False
models = []
dbs = []
vs_list = ["新建知识库"] + get_vector_storelist()
vs_list = [get_lang_text("create_knowledge_base")] + get_vector_storelist()
autogpt = False
vector_store_client = None
vector_store_name = {"vs_name": ""}
priority = {"vicuna-13b": "aaa"}
# Load plugins
CFG = Config()
CHAT_FACTORY = ChatFactory()
DB_SETTINGS = {
@@ -86,6 +88,20 @@ DB_SETTINGS = {
}
llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
default_knowledge_base_dialogue = get_lang_text(
"knowledge_qa_type_default_knowledge_base_dialogue"
)
add_knowledge_base_dialogue = get_lang_text(
"knowledge_qa_type_add_knowledge_base_dialogue"
)
knowledge_qa_type_list = [
llm_native_dialogue,
default_knowledge_base_dialogue,
add_knowledge_base_dialogue,
]
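All of these labels are resolved through get_lang_text from pilot.language.translation_handler. The handler's implementation is not shown in this diff; presumably it is a key lookup into per-language resource tables, roughly like this sketch (keys and shape assumed):

# Sketch of an assumed translation handler; the real one lives in
# pilot/language/translation_handler.py and is not shown in this commit.
LANG_DICTS = {
    "en": {
        "create_knowledge_base": "Create Knowledge Base",
        "sql_schema_info": "The Schema information of database {0} is: {1}",
    },
    "zh": {"create_knowledge_base": "新建知识库"},
}
CURRENT_LANG = "en"  # presumably driven by a config value

def get_lang_text(key: str) -> str:
    # Fall back to the key itself when a translation is missing.
    return LANG_DICTS.get(CURRENT_LANG, {}).get(key, key)

Template strings such as sql_schema_info would then be filled positionally, matching the get_lang_text("sql_schema_info").format(dbname, message) call later in this diff.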
def get_simlar(q):
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
docs = docsearch.similarity_search_with_score(q, k=1)
@@ -100,7 +116,7 @@ def gen_sqlgen_conversation(dbname):
schemas = CFG.local_db.table_simple_info(db_connect)
for s in schemas:
message += s["schema_info"] + ";"
return f"数据库{dbname}的Schema信息如下: {message}\n"
return get_lang_text("sql_schema_info").format(dbname, message)
def plugins_select_info():
@@ -127,6 +143,7 @@ function() {
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
# dbs = get_database_list()
dropdown_update = gr.Dropdown.update(visible=True)
if dbs:
gr.Dropdown.update(choices=dbs)
@@ -213,7 +230,7 @@ def http_bot(
):
logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}")
start_tstamp = time.time()
scene: ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector)
scene:ChatScene = get_chat_mode(mode, sql_mode, db_selector)
print(f"当前对话模式:{scene.value}")
model_name = CFG.LLM_MODEL
@@ -222,7 +239,7 @@ def http_bot(
chat_param = {
"chat_session_id": state.conv_id,
"db_name": db_selector,
"current_user_input": state.last_user_input,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat.call()
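The hunk above renames the chat parameter from current_user_input to user_input and resolves the chat implementation through ChatFactory. The factory itself is not shown in this diff; a registry keyed by scene value is the usual shape, sketched here under that assumption:

# Hypothetical sketch of a scene-keyed chat factory (the real one is in
# pilot/scene/chat_factory.py, which this commit does not show).
class ChatFactory:
    _registry: dict = {}

    @classmethod
    def register(cls, scene_value: str, chat_cls) -> None:
        cls._registry[scene_value] = chat_cls

    def get_implementation(self, scene_value: str, **kwargs):
        chat_cls = self._registry.get(scene_value)
        if chat_cls is None:
            raise ValueError(f"Unsupported chat scene: {scene_value}")
        return chat_cls(**kwargs)  # e.g. chat_session_id, db_name, user_input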
@@ -241,6 +258,7 @@ def http_bot(
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
else:
dbname = db_selector
# TODO: This request needs to incorporate the existing knowledge base so answers are grounded in it; the prompt still needs further tuning.
if state.skip_next:
@@ -290,18 +308,25 @@ def http_bot(
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
},
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, 1)
context = [d.page_content for d in docs]
prompt_template = PromptTemplate(
    template=conv_qa_prompt_template,
    input_variables=["context", "question"],
)
result = prompt_template.format(context="\n".join(context), question=query)
state.messages[-2][1] = result
prompt = state.get_prompt()
print("vector store name: ", vector_store_name["vs_name"])
vector_store_config = {
    "vector_store_name": vector_store_name["vs_name"],
    "text_field": "content",
    "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = KnowledgeEmbedding(
    file_path="",
    model_name=LLM_MODEL_CONFIG["text2vec"],
    local_persist=False,
    vector_store_config=vector_store_config,
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
state.messages[-2][1] = query
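Note the ordering in the new code: build_knowledge_prompt temporarily writes the knowledge-stuffed prompt into state.messages[-2][1] so that state.get_prompt() picks it up, and the final line restores the raw query so the visible chat history shows the user's question rather than the retrieved context.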
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
# Make requests
payload = {
"model": model_name,
@@ -346,7 +371,56 @@ def http_bot(
enable_btn,
)
return
try:
# Stream output
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=20,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
if data["error_code"] == 0:
if "vicuna" in CFG.LLM_MODEL:
output = data["text"][skip_echo_len:].strip()
else:
output = data["text"].strip()
output = post_process_code(output)
state.messages[-1][-1] = output + ""
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
) * 5
else:
output = (
data["text"] + f" (error_code: {data['error_code']})"
)
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
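The TODO above asks for an adapter-style rewrite of the per-model output handling. A minimal sketch of that idea, with the vicuna echo-stripping as one registered adapter (names are illustrative, not from this commit):

# Sketch of the adapter-style output handling the TODO suggests.
def vicuna_adapter(data: dict, skip_echo_len: int) -> str:
    # Vicuna echoes the prompt, so strip it before returning the completion.
    return data["text"][skip_echo_len:].strip()

def default_adapter(data: dict, skip_echo_len: int) -> str:
    return data["text"].strip()

OUTPUT_ADAPTERS = {"vicuna": vicuna_adapter}

def adapt_output(model_name: str, data: dict, skip_echo_len: int) -> str:
    for key, adapter in OUTPUT_ADAPTERS.items():
        if key in model_name:
            return adapter(data, skip_echo_len)
    return default_adapter(data, skip_echo_len)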
@@ -358,24 +432,24 @@ def http_bot(
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

# Record the run log
finish_tstamp = time.time()
logger.info(f"{output}")

with open(get_conv_log_filename(), "a") as fout:
    data = {
        "tstamp": round(finish_tstamp, 4),
        "type": "chat",
        "model": model_name,
        "start": round(start_tstamp, 4),
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),
        "ip": request.client.host,
    }
    fout.write(json.dumps(data) + "\n")
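Since each turn is appended as one JSON object per line, the conversation log reads back as JSONL; a small sketch (the path comes from get_conv_log_filename(), hard-coded here for illustration):

import json

# Sketch: read the conversation log back as JSONL.
with open("conv.log") as f:  # replace with get_conv_log_filename()
    records = [json.loads(line) for line in f if line.strip()]
for rec in records:
    print(rec["model"], rec["finish"] - rec["start"])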
block_css = (
@@ -396,14 +470,14 @@ block_css = (
def change_sql_mode(sql_mode):
if sql_mode in ["直接执行结果"]:
if sql_mode in [get_lang_text("sql_generate_mode_direct")]:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def change_mode(mode):
if mode in ["默认知识库对话", "LLM原生对话"]:
if mode in [default_knowledge_base_dialogue, llm_native_dialogue]:
return gr.update(visible=False)
else:
return gr.update(visible=True)
@@ -413,27 +487,16 @@ def change_tab():
autogpt = True
def change_func(xx):
print("123")
print(str(xx))
def build_single_model_ui():
notice_markdown = """
# DB-GPT
[DB-GPT](https://github.com/csunny/DB-GPT) is an open-source, database-centric GPT experimental project. It uses localized GPT large language models to interact with your data and environment, with no risk of data leakage: 100% private, 100% secure.
"""
learn_more_markdown = """
### Licence
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B.
"""
notice_markdown = get_lang_text("db_gpt_introduction")
learn_more_markdown = get_lang_text("learn_more_markdown")
state = gr.State()
gr.Markdown(notice_markdown, elem_id="notice_markdown")
with gr.Accordion("参数", open=False, visible=False) as parameter_row:
with gr.Accordion(
get_lang_text("model_control_param"), open=False, visible=False
) as parameter_row:
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
@@ -449,7 +512,7 @@ def build_single_model_ui():
value=512,
step=64,
interactive=True,
label="最大输出Token数",
label=get_lang_text("max_input_token_size"),
)
tabs = gr.Tabs()
@@ -462,24 +525,30 @@ def build_single_model_ui():
tabs.select(on_select, None, selected)
with tabs:
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
tab_sql.select(on_select, None, None)
tab_sql = gr.TabItem(get_lang_text("sql_generate_diagnostics"), elem_id="SQL")
with tab_sql:
print("tab_sql in...")
# TODO A selector to choose database
with gr.Row(elem_id="db_selector"):
db_selector = gr.Dropdown(
label="请选择数据库",
label=get_lang_text("please_choose_database"),
choices=dbs,
value=dbs[0] if len(models) > 0 else "",
interactive=True,
show_label=True,
).style(container=False)
sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
sql_mode = gr.Radio(
[
get_lang_text("sql_generate_mode_direct"),
get_lang_text("sql_generate_mode_none"),
],
show_label=False,
value=get_lang_text("sql_generate_mode_none"),
)
sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
# tab_plugin.select(change_func)
with tab_plugin:
@@ -502,37 +571,50 @@ def build_single_model_ui():
plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
plugin_selector.select(plugin_change, None, plugin_selected)
tab_qa = gr.TabItem("知识问答", elem_id="QA")
tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
with tab_qa:
print("tab_qa in...")
mode = gr.Radio(
["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话"
[
llm_native_dialogue,
default_knowledge_base_dialogue,
add_knowledge_base_dialogue,
],
show_label=False,
value=llm_native_dialogue,
)
vs_setting = gr.Accordion(
get_lang_text("configure_knowledge_base"), open=False
)
vs_setting = gr.Accordion("配置知识库", open=False)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
with vs_setting:
vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
vs_add = gr.Button("添加为新知识库")
vs_name = gr.Textbox(
label=get_lang_text("new_klg_name"), lines=1, interactive=True
)
vs_add = gr.Button(get_lang_text("add_as_new_klg"))
with gr.Column() as doc2vec:
gr.Markdown("向知识库中添加文件")
with gr.Tab("上传文件"):
gr.Markdown(get_lang_text("add_file_to_klg"))
with gr.Tab(get_lang_text("upload_file")):
files = gr.File(
label="添加文件",
label=get_lang_text("add_file"),
file_types=[".txt", ".md", ".docx", ".pdf"],
file_count="multiple",
allow_flagged_uploads=True,
show_label=False,
)
load_file_button = gr.Button("上传并加载到知识库")
with gr.Tab("上传文件夹"):
load_file_button = gr.Button(
get_lang_text("upload_and_load_to_klg")
)
with gr.Tab(get_lang_text("upload_folder")):
folder_files = gr.File(
label="添加文件夹",
label=get_lang_text("add_folder"),
accept_multiple_files=True,
file_count="directory",
show_label=False,
)
load_folder_button = gr.Button("上传并加载到知识库")
load_folder_button = gr.Button(
get_lang_text("upload_and_load_to_klg")
)
with gr.Blocks():
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
@@ -544,11 +626,11 @@ def build_single_model_ui():
visible=False,
).style(container=False)
with gr.Column(scale=2, min_width=50):
send_btn = gr.Button(value="发送", visible=False)
send_btn = gr.Button(value=get_lang_text("send"), visible=False)
with gr.Row(visible=False) as button_row:
regenerate_btn = gr.Button(value="重新生成", interactive=False)
clear_btn = gr.Button(value="清理", interactive=False)
regenerate_btn = gr.Button(value=get_lang_text("regenerate"), interactive=False)
clear_btn = gr.Button(value=get_lang_text("clear_box"), interactive=False)
gr.Markdown(learn_more_markdown)
btn_list = [regenerate_btn, clear_btn]
@@ -594,10 +676,10 @@ def build_single_model_ui():
def build_webdemo():
with gr.Blocks(
title="数据库智能助手",
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
title=get_lang_text("database_smart_assistant"),
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
) as demo:
url_params = gr.JSON(visible=False)
(