mirror of https://github.com/csunny/DB-GPT.git
Merge branch 'main' into dev
# Conflicts:
#	pilot/connections/mysql.py
#	pilot/prompts/prompt_generator.py
#	pilot/scene/base_chat.py
#	pilot/scene/chat_db/chat.py
#	pilot/scene/chat_db/out_parser.py
#	pilot/scene/chat_execution/chat.py
#	pilot/server/webserver.py
@@ -5,6 +5,7 @@ from langchain.prompts import PromptTemplate
+from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
 from pilot.conversation import conv_qa_prompt_template
 from pilot.logs import logger
 from pilot.model.vicuna_llm import VicunaLLM
 from pilot.vector_store.file_loader import KnownLedge2Vector

@@ -28,3 +29,27 @@ class KnownLedgeBaseQA:
         context = [d.page_content for d in docs]
         result = prompt.format(context="\n".join(context), question=query)
         return result
+
+    @staticmethod
+    def build_knowledge_prompt(query, docs, state):
+        prompt_template = PromptTemplate(
+            template=conv_qa_prompt_template, input_variables=["context", "question"]
+        )
+        context = [d.page_content for d in docs]
+        result = prompt_template.format(context="\n".join(context), question=query)
+        state.messages[-2][1] = result
+        prompt = state.get_prompt()
+
+        if len(prompt) > 4000:
+            logger.info("prompt length greater than 4000, rebuild")
+            context = context[:2000]
+            prompt_template = PromptTemplate(
+                template=conv_qa_prompt_template,
+                input_variables=["context", "question"],
+            )
+            result = prompt_template.format(context="\n".join(context), question=query)
+            state.messages[-2][1] = result
+            prompt = state.get_prompt()
+            print("new prompt length:" + str(len(prompt)))
+
+        return prompt
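Note: the added `build_knowledge_prompt` stuffs the retrieved documents into the QA template, then rebuilds the prompt once if the rendered string exceeds 4,000 characters. A standalone sketch of that rebuild pattern follows (illustrative names, not repo code); note that the committed `context = context[:2000]` slices the document *list* to 2,000 items, whereas a character budget would be cut as below:

```python
# Sketch of the truncate-and-rebuild pattern above (illustrative, not repo code).
TEMPLATE = "Based on the context: {context}\nAnswer the question: {question}"
MAX_PROMPT_CHARS = 4000

def build_prompt(query, docs):
    context = "\n".join(docs)
    prompt = TEMPLATE.format(context=context, question=query)
    if len(prompt) > MAX_PROMPT_CHARS:
        # Cut the joined context by however many characters we are over budget.
        overflow = len(prompt) - MAX_PROMPT_CHARS
        context = context[: max(0, len(context) - overflow)]
        prompt = TEMPLATE.format(context=context, question=query)
    return prompt

print(len(build_prompt("What is DB-GPT?", ["lorem ipsum " * 500])))  # -> 4000
```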
@@ -13,7 +13,7 @@ from urllib.parse import urljoin

 import gradio as gr
 import requests
-from langchain import PromptTemplate

 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
@@ -56,7 +56,11 @@ from pilot.vector_store.extract_tovec import (
 from pilot.commands.command import execute_ai_response_json
+from pilot.scene.base import ChatScene
+from pilot.scene.chat_factory import ChatFactory
+from pilot.language.translation_handler import get_lang_text
+

 # 加载插件 (load plugins)
 CFG = Config()
 logger = build_logger("webserver", LOGDIR + "webserver.log")
 headers = {"User-Agent": "dbgpt Client"}
@@ -67,15 +71,13 @@ disable_btn = gr.Button.update(interactive=True)
 enable_moderation = False
 models = []
 dbs = []
-vs_list = ["新建知识库"] + get_vector_storelist()
+vs_list = [get_lang_text("create_knowledge_base")] + get_vector_storelist()
 autogpt = False
 vector_store_client = None
 vector_store_name = {"vs_name": ""}

 priority = {"vicuna-13b": "aaa"}

-# 加载插件 (load plugins)
-CFG = Config()
+CHAT_FACTORY = ChatFactory()

 DB_SETTINGS = {
@@ -86,6 +88,20 @@ DB_SETTINGS = {
 }


+llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
+default_knowledge_base_dialogue = get_lang_text(
+    "knowledge_qa_type_default_knowledge_base_dialogue"
+)
+add_knowledge_base_dialogue = get_lang_text(
+    "knowledge_qa_type_add_knowledge_base_dialogue"
+)
+knowledge_qa_type_list = [
+    llm_native_dialogue,
+    default_knowledge_base_dialogue,
+    add_knowledge_base_dialogue,
+]
+
+
 def get_simlar(q):
     docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
     docs = docsearch.similarity_search_with_score(q, k=1)
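Note: `get_lang_text` resolves UI strings by key from `pilot/language/translation_handler`. A minimal sketch of how such a key-based lookup is typically wired (the dictionaries and keys below are illustrative, not the repo's actual tables):

```python
# Illustrative key-based UI translation lookup (not the repo's implementation).
LANG_DICTS = {
    "en": {"create_knowledge_base": "Create Knowledge Base"},
    "zh": {"create_knowledge_base": "新建知识库"},
}
CURRENT_LANG = "en"

def get_lang_text(key: str) -> str:
    # Fall back to the key itself so a missing entry is visible in the UI.
    return LANG_DICTS.get(CURRENT_LANG, {}).get(key, key)

print(get_lang_text("create_knowledge_base"))  # -> "Create Knowledge Base"
```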
@@ -100,7 +116,7 @@ def gen_sqlgen_conversation(dbname):
     schemas = CFG.local_db.table_simple_info(db_connect)
     for s in schemas:
         message += s["schema_info"] + ";"
-    return f"数据库{dbname}的Schema信息如下: {message}\n"
+    return get_lang_text("sql_schema_info").format(dbname, message)


 def plugins_select_info():
@@ -127,6 +143,7 @@ function() {
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

+    # dbs = get_database_list()
     dropdown_update = gr.Dropdown.update(visible=True)
     if dbs:
         gr.Dropdown.update(choices=dbs)
@@ -213,7 +230,7 @@ def http_bot(
 ):
     logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}")
     start_tstamp = time.time()
-    scene:ChatScene = get_chat_mode(mode, sql_mode, db_selector)
+    scene: ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector)
     print(f"当前对话模式:{scene.value}")
     model_name = CFG.LLM_MODEL
@@ -222,7 +239,7 @@ def http_bot(
         chat_param = {
             "chat_session_id": state.conv_id,
             "db_name": db_selector,
-            "user_input": state.last_user_input,
+            "current_user_input": state.last_user_input,
         }
         chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
         chat.call()
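Note: the scene flow resolves a chat implementation from a factory keyed by `scene.value` and invokes it with keyword parameters. A generic sketch of that shape (class names and parameter keys are illustrative, not the repo's ChatFactory):

```python
# Illustrative chat-factory shape (not the repo's ChatFactory).
class BaseChat:
    def __init__(self, chat_session_id, db_name, user_input):
        self.chat_session_id = chat_session_id
        self.db_name = db_name
        self.user_input = user_input

    def call(self):
        raise NotImplementedError

class ChatWithDb(BaseChat):
    def call(self):
        print(f"[{self.db_name}] answering: {self.user_input}")

REGISTRY = {"chat_with_db_execute": ChatWithDb}

def get_implementation(scene_value, **chat_param):
    # Look up the concrete chat class by scene value and construct it.
    return REGISTRY[scene_value](**chat_param)

chat = get_implementation(
    "chat_with_db_execute",
    chat_session_id="s1", db_name="mysql", user_input="show tables",
)
chat.call()
```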
@@ -241,6 +258,7 @@ def http_bot(
         yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

     else:
+        dbname = db_selector
         # TODO: the request here should be augmented with the existing knowledge base, so answers are grounded in it; the prompt still needs tuning
         if state.skip_next:
@@ -290,18 +308,25 @@ def http_bot(
                     "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
                 },
             )
-            query = state.messages[-2][1]
-            docs = knowledge_embedding_client.similar_search(query, 1)
-            context = [d.page_content for d in docs]
-            prompt_template = PromptTemplate(
-                template=conv_qa_prompt_template,
-                input_variables=["context", "question"],
-            )
-            result = prompt_template.format(context="\n".join(context), question=query)
-            state.messages[-2][1] = result
-            prompt = state.get_prompt()
+            print("vector store name: ", vector_store_name["vs_name"])
+            vector_store_config = {
+                "vector_store_name": vector_store_name["vs_name"],
+                "text_field": "content",
+                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+            }
+            knowledge_embedding_client = KnowledgeEmbedding(
+                file_path="",
+                model_name=LLM_MODEL_CONFIG["text2vec"],
+                local_persist=False,
+                vector_store_config=vector_store_config,
+            )
+            query = state.messages[-2][1]
+            docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
+            prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
+
+            state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1

         # Make requests
         payload = {
             "model": model_name,
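Note: the rewritten knowledge branch searches with the configurable `VECTOR_SEARCH_TOP_K` instead of a hard-coded `k=1` and hands prompt assembly to `KnownLedgeBaseQA.build_knowledge_prompt`, keeping the raw question rather than the stuffed prompt in the visible history. Schematically, with duck-typed stand-ins (not repo code):

```python
# Duck-typed sketch of the new retrieval flow (stand-ins, not repo code).
class FakeEmbeddingClient:
    def similar_search(self, query, top_k):
        return [f"doc about {query} #{i}" for i in range(top_k)]

def knowledge_prompt(query, client, top_k=3):
    docs = client.similar_search(query, top_k)  # was hard-coded to k=1 before this change
    return "context:\n" + "\n".join(docs) + f"\nquestion: {query}"

print(knowledge_prompt("table indexes", FakeEmbeddingClient()))
```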
@@ -346,7 +371,56 @@ def http_bot(
             enable_btn,
         )
         return
+    try:
+        # Stream output
+        response = requests.post(
+            urljoin(CFG.MODEL_SERVER, "generate_stream"),
+            headers=headers,
+            json=payload,
+            stream=True,
+            timeout=20,
+        )
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+
+                """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
+                """
+                if data["error_code"] == 0:
+                    if "vicuna" in CFG.LLM_MODEL:
+                        output = data["text"][skip_echo_len:].strip()
+                    else:
+                        output = data["text"].strip()
+
+                    output = post_process_code(output)
+                    state.messages[-1][-1] = output + "▌"
+                    yield (state, state.to_gradio_chatbot()) + (
+                        disable_btn,
+                    ) * 5
+                else:
+                    output = (
+                        data["text"] + f" (error_code: {data['error_code']})"
+                    )
+                    state.messages[-1][-1] = output
+                    yield (state, state.to_gradio_chatbot()) + (
+                        disable_btn,
+                        disable_btn,
+                        disable_btn,
+                        enable_btn,
+                        enable_btn,
+                    )
+                    return
+
+    except requests.exceptions.RequestException as e:
+        state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
+        yield (state, state.to_gradio_chatbot()) + (
+            disable_btn,
+            disable_btn,
+            disable_btn,
+            enable_btn,
+            enable_btn,
+        )
+        return
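Note: the server streams NUL-delimited JSON chunks, which `requests` can split directly via `iter_lines(delimiter=b"\0")`. A standalone consumer sketch (the endpoint URL and payload fields are placeholders):

```python
import json
import requests

# Placeholder endpoint/payload; mirrors the NUL-delimited streaming read above.
payload = {"model": "vicuna-13b", "prompt": "hello", "temperature": 0.7}
response = requests.post(
    "http://localhost:8000/generate_stream",
    json=payload,
    stream=True,
    timeout=20,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if not chunk:
        continue  # skip empty frames between NUL bytes
    data = json.loads(chunk.decode())
    if data.get("error_code", 0) == 0:
        print(data["text"], end="", flush=True)
```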
@@ -358,24 +432,24 @@ def http_bot(

     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

     # 记录运行日志 (record run log)
     finish_tstamp = time.time()
     logger.info(f"{output}")

     with open(get_conv_log_filename(), "a") as fout:
         data = {
             "tstamp": round(finish_tstamp, 4),
             "type": "chat",
             "model": model_name,
             "start": round(start_tstamp, 4),
             "finish": round(start_tstamp, 4),
             "state": state.dict(),
             "ip": request.client.host,
         }
         fout.write(json.dumps(data) + "\n")


 block_css = (
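Note: each finished exchange is appended to the conversation log as one JSON object per line (JSONL). A sketch of reading it back (path is a placeholder; field names as in the hunk above):

```python
import json

def iter_chat_records(path):
    # One JSON object per line, as written by http_bot above.
    with open(path) as f:
        for line in f:
            if line.strip():
                yield json.loads(line)

for rec in iter_chat_records("conv.log"):
    # "finish" is written from start_tstamp in this commit, so duration reads as 0 here.
    print(rec["model"], rec["finish"] - rec["start"])
```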
@@ -396,14 +470,14 @@ block_css = (


 def change_sql_mode(sql_mode):
-    if sql_mode in ["直接执行结果"]:
+    if sql_mode in [get_lang_text("sql_generate_mode_direct")]:
         return gr.update(visible=True)
     else:
         return gr.update(visible=False)


 def change_mode(mode):
-    if mode in ["默认知识库对话", "LLM原生对话"]:
+    if mode in [default_knowledge_base_dialogue, llm_native_dialogue]:
         return gr.update(visible=False)
     else:
         return gr.update(visible=True)
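Note: both callbacks return `gr.update(visible=...)`, so a single Radio drives another component's visibility. Minimal standalone wiring in the same Gradio 3.x style as this file (labels are placeholders):

```python
import gradio as gr

def toggle(choice):
    # Show the settings panel only for the "Direct execution" mode.
    return gr.update(visible=(choice == "Direct execution"))

with gr.Blocks() as demo:
    mode = gr.Radio(["Direct execution", "No execution"], value="No execution")
    panel = gr.Markdown("Auto-execution settings...", visible=False)
    mode.change(fn=toggle, inputs=mode, outputs=panel)

# demo.launch()  # uncomment to run locally
```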
@@ -413,27 +487,16 @@ def change_tab():
     autogpt = True


-def change_func(xx):
-    print("123")
-    print(str(xx))
-
-
 def build_single_model_ui():
-    notice_markdown = """
-    # DB-GPT
-
-    [DB-GPT](https://github.com/csunny/DB-GPT) 是一个开源的以数据库为基础的GPT实验项目,使用本地化的GPT大模型与您的数据和环境进行交互,无数据泄露风险,100% 私密,100% 安全。
-    """
-    learn_more_markdown = """
-    ### Licence
-    The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B
-    """
+    notice_markdown = get_lang_text("db_gpt_introduction")
+    learn_more_markdown = get_lang_text("learn_more_markdown")

     state = gr.State()

     gr.Markdown(notice_markdown, elem_id="notice_markdown")

-    with gr.Accordion("参数", open=False, visible=False) as parameter_row:
+    with gr.Accordion(
+        get_lang_text("model_control_param"), open=False, visible=False
+    ) as parameter_row:
         temperature = gr.Slider(
             minimum=0.0,
             maximum=1.0,
@@ -449,7 +512,7 @@ def build_single_model_ui():
             value=512,
             step=64,
             interactive=True,
-            label="最大输出Token数",
+            label=get_lang_text("max_input_token_size"),
         )

     tabs = gr.Tabs()
@@ -462,24 +525,30 @@ def build_single_model_ui():
     tabs.select(on_select, None, selected)

     with tabs:
-        tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
-        tab_sql.select(on_select, None, None)
+        tab_sql = gr.TabItem(get_lang_text("sql_generate_diagnostics"), elem_id="SQL")
         with tab_sql:
             print("tab_sql in...")
             # TODO A selector to choose database
             with gr.Row(elem_id="db_selector"):
                 db_selector = gr.Dropdown(
-                    label="请选择数据库",
+                    label=get_lang_text("please_choose_database"),
                     choices=dbs,
                     value=dbs[0] if len(models) > 0 else "",
                     interactive=True,
                     show_label=True,
                 ).style(container=False)

-            sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
-            sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
+            sql_mode = gr.Radio(
+                [
+                    get_lang_text("sql_generate_mode_direct"),
+                    get_lang_text("sql_generate_mode_none"),
+                ],
+                show_label=False,
+                value=get_lang_text("sql_generate_mode_none"),
+            )
+            sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
             sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)

+        tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
         tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
         # tab_plugin.select(change_func)
         with tab_plugin:
@@ -502,37 +571,50 @@ def build_single_model_ui():
             plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
             plugin_selector.select(plugin_change, None, plugin_selected)

-        tab_qa = gr.TabItem("知识问答", elem_id="QA")
+        tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
         with tab_qa:
             print("tab_qa in...")
             mode = gr.Radio(
-                ["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话"
+                [
+                    llm_native_dialogue,
+                    default_knowledge_base_dialogue,
+                    add_knowledge_base_dialogue,
+                ],
+                show_label=False,
+                value=llm_native_dialogue,
             )
-            vs_setting = gr.Accordion("配置知识库", open=False)
+            vs_setting = gr.Accordion(
+                get_lang_text("configure_knowledge_base"), open=False
+            )
             mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
             with vs_setting:
-                vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
-                vs_add = gr.Button("添加为新知识库")
+                vs_name = gr.Textbox(
+                    label=get_lang_text("new_klg_name"), lines=1, interactive=True
+                )
+                vs_add = gr.Button(get_lang_text("add_as_new_klg"))
                 with gr.Column() as doc2vec:
-                    gr.Markdown("向知识库中添加文件")
-                    with gr.Tab("上传文件"):
+                    gr.Markdown(get_lang_text("add_file_to_klg"))
+                    with gr.Tab(get_lang_text("upload_file")):
                         files = gr.File(
-                            label="添加文件",
+                            label=get_lang_text("add_file"),
                             file_types=[".txt", ".md", ".docx", ".pdf"],
                             file_count="multiple",
                             allow_flagged_uploads=True,
                             show_label=False,
                         )

-                        load_file_button = gr.Button("上传并加载到知识库")
-                    with gr.Tab("上传文件夹"):
+                        load_file_button = gr.Button(
+                            get_lang_text("upload_and_load_to_klg")
+                        )
+                    with gr.Tab(get_lang_text("upload_folder")):
                         folder_files = gr.File(
-                            label="添加文件夹",
+                            label=get_lang_text("add_folder"),
                             accept_multiple_files=True,
                             file_count="directory",
                             show_label=False,
                         )
-                        load_folder_button = gr.Button("上传并加载到知识库")
+                        load_folder_button = gr.Button(
+                            get_lang_text("upload_and_load_to_klg")
+                        )

     with gr.Blocks():
         chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
@@ -544,11 +626,11 @@ def build_single_model_ui():
                 visible=False,
             ).style(container=False)
         with gr.Column(scale=2, min_width=50):
-            send_btn = gr.Button(value="发送", visible=False)
+            send_btn = gr.Button(value=get_lang_text("send"), visible=False)

     with gr.Row(visible=False) as button_row:
-        regenerate_btn = gr.Button(value="重新生成", interactive=False)
-        clear_btn = gr.Button(value="清理", interactive=False)
+        regenerate_btn = gr.Button(value=get_lang_text("regenerate"), interactive=False)
+        clear_btn = gr.Button(value=get_lang_text("clear_box"), interactive=False)

     gr.Markdown(learn_more_markdown)
     btn_list = [regenerate_btn, clear_btn]
@@ -594,10 +676,10 @@ def build_single_model_ui():

 def build_webdemo():
     with gr.Blocks(
-        title="数据库智能助手",
+        title=get_lang_text("database_smart_assistant"),
         # theme=gr.themes.Base(),
         theme=gr.themes.Default(),
         css=block_css,
     ) as demo:
         url_params = gr.JSON(visible=False)
         (