ci: make ci happy: lint the code, delete unused imports

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Author: yihong0618
Date: 2023-05-24 18:42:55 +08:00
parent 562d5a98cc
commit b098a48898
75 changed files with 1110 additions and 824 deletions (only this file's diff is shown below)


@@ -2,58 +2,55 @@
 # -*- coding: utf-8 -*-
 import argparse
 import datetime
 import json
 import os
 import shutil
-import uuid
-import json
 import sys
 import time
-import gradio as gr
-import datetime
-import requests
 import uuid
 from urllib.parse import urljoin
+import gradio as gr
+import requests
 from langchain import PromptTemplate
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
-from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K
-from pilot.server.vectordb_qa import KnownLedgeBaseQA
-from pilot.connections.mysql import MySQLOperator
-from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
-from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
-from pilot.configs.model_config import LOGDIR, DATASETS_DIR
-from pilot.plugins import scan_plugins
-from pilot.configs.config import Config
+from pilot.commands.command import execute_ai_response_json
 from pilot.commands.command_mange import CommandRegistry
+from pilot.commands.exception_not_commands import NotCommands
+from pilot.configs.config import Config
+from pilot.configs.model_config import (
+    DATASETS_DIR,
+    KNOWLEDGE_UPLOAD_ROOT_PATH,
+    LLM_MODEL_CONFIG,
+    LOGDIR,
+    VECTOR_SEARCH_TOP_K,
+)
+from pilot.connections.mysql import MySQLOperator
+from pilot.conversation import (
+    SeparatorStyle,
+    conv_qa_prompt_template,
+    conv_templates,
+    conversation_sql_mode,
+    conversation_types,
+    default_conversation,
+)
+from pilot.plugins import scan_plugins
 from pilot.prompts.auto_mode_prompt import AutoModePrompt
 from pilot.prompts.generator import PromptGenerator
-from pilot.commands.exception_not_commands import NotCommands
-from pilot.conversation import (
-    default_conversation,
-    conv_templates,
-    conversation_types,
-    conversation_sql_mode,
-    SeparatorStyle, conv_qa_prompt_template
-)
-from pilot.utils import (
-    build_logger,
-    server_error_msg,
-)
 from pilot.server.gradio_css import code_highlight_css
 from pilot.server.gradio_patch import Chatbot as grChatbot
-from pilot.commands.command import execute_ai_response_json
+from pilot.server.vectordb_qa import KnownLedgeBaseQA
+from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
+from pilot.utils import build_logger, server_error_msg
+from pilot.vector_store.extract_tovec import (
+    get_vector_storelist,
+    knownledge_tovec_st,
+    load_knownledge_from_doc,
+)
 logger = build_logger("webserver", LOGDIR + "webserver.log")
 headers = {"User-Agent": "dbgpt Client"}
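
For reference, the import reshuffle above is consistent with a standard isort + black pass; a minimal sketch of such a pass follows. The tooling and configuration are assumptions, not taken from the repository:

# Hypothetical reproduction of the lint step: isort sorts and groups
# imports (dropping exact duplicates such as the doubled `import json`
# above), then black normalizes quotes, wrapping, and trailing commas.
import black
import isort

def lint_source(src: str) -> str:
    src = isort.code(src)  # stdlib / third-party / first-party grouping
    return black.format_str(src, mode=black.Mode())  # default 88-column style

print(lint_source("import uuid\nimport json\nimport json\n"))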
@@ -70,19 +67,19 @@ autogpt = False
 vector_store_client = None
 vector_store_name = {"vs_name": ""}
-priority = {
-    "vicuna-13b": "aaa"
-}
+priority = {"vicuna-13b": "aaa"}
 # Load plugins
-CFG= Config()
+CFG = Config()
 DB_SETTINGS = {
     "user": CFG.LOCAL_DB_USER,
-    "password": CFG.LOCAL_DB_PASSWORD,
+    "password": CFG.LOCAL_DB_PASSWORD,
     "host": CFG.LOCAL_DB_HOST,
-    "port": CFG.LOCAL_DB_PORT
+    "port": CFG.LOCAL_DB_PORT,
 }
 def get_simlar(q):
     docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
     docs = docsearch.similarity_search_with_score(q, k=1)
@@ -92,9 +89,7 @@ def get_simlar(q):
 def gen_sqlgen_conversation(dbname):
-    mo = MySQLOperator(
-        **DB_SETTINGS
-    )
+    mo = MySQLOperator(**DB_SETTINGS)
     message = ""
@@ -132,13 +127,15 @@ def load_demo(url_params, request: gr.Request):
     gr.Dropdown.update(choices=dbs)
     state = default_conversation.copy()
-    return (state,
-            dropdown_update,
-            gr.Chatbot.update(visible=True),
-            gr.Textbox.update(visible=True),
-            gr.Button.update(visible=True),
-            gr.Row.update(visible=True),
-            gr.Accordion.update(visible=True))
+    return (
+        state,
+        dropdown_update,
+        gr.Chatbot.update(visible=True),
+        gr.Textbox.update(visible=True),
+        gr.Button.update(visible=True),
+        gr.Row.update(visible=True),
+        gr.Accordion.update(visible=True),
+    )
 def get_conv_log_filename():
@@ -185,7 +182,9 @@ def post_process_code(code):
     return code
-def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
+def http_bot(
+    state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
+):
     if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
         print("AUTO DB-GPT模式.")
     if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
@@ -212,12 +211,13 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
         # The first round of dialogue needs the prompt added
         if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
             # The first round in autogpt mode needs a purpose-built prompt
-            system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query],
-                                                               db_schemes=gen_sqlgen_conversation(dbname))
+            system_prompt = auto_prompt.construct_first_prompt(
+                fisrt_message=[query], db_schemes=gen_sqlgen_conversation(dbname)
+            )
             logger.info("[TEST]:" + system_prompt)
             template_name = "auto_dbgpt_one_shot"
             new_state = conv_templates[template_name].copy()
-            new_state.append_message(role='USER', message=system_prompt)
+            new_state.append_message(role="USER", message=system_prompt)
             # new_state.append_message(new_state.roles[0], query)
             new_state.append_message(new_state.roles[1], None)
         else:
@@ -226,7 +226,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
             # Add a context hint to the prompt to answer from existing knowledge; should the hint go only in the first round, or be added every round?
             # If the user's questions span widely, the hint should be added every round.
             if db_selector:
-                new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
+                new_state.append_message(
+                    new_state.roles[0], gen_sqlgen_conversation(dbname) + query
+                )
                 new_state.append_message(new_state.roles[1], None)
             else:
                 new_state.append_message(new_state.roles[0], query)
@@ -244,7 +246,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
             # Add a context hint to the prompt to answer from existing knowledge; should the hint go only in the first round, or be added every round?
             # If the user's questions span widely, the hint should be added every round.
             if db_selector:
-                new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
+                new_state.append_message(
+                    new_state.roles[0], gen_sqlgen_conversation(dbname) + query
+                )
                 new_state.append_message(new_state.roles[1], None)
             else:
                 new_state.append_message(new_state.roles[0], query)
@@ -268,17 +272,22 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     if mode == conversation_types["custome"] and not db_selector:
         print("vector store name: ", vector_store_name["vs_name"])
-        vector_store_config = {"vector_store_name": vector_store_name["vs_name"], "text_field": "content",
-                               "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}
-        knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"],
-                                                        local_persist=False,
-                                                        vector_store_config=vector_store_config)
+        vector_store_config = {
+            "vector_store_name": vector_store_name["vs_name"],
+            "text_field": "content",
+            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+        }
+        knowledge_embedding_client = KnowledgeEmbedding(
+            file_path="",
+            model_name=LLM_MODEL_CONFIG["text2vec"],
+            local_persist=False,
+            vector_store_config=vector_store_config,
+        )
         query = state.messages[-2][1]
         docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
         context = [d.page_content for d in docs]
         prompt_template = PromptTemplate(
-            template=conv_qa_prompt_template,
-            input_variables=["context", "question"]
+            template=conv_qa_prompt_template, input_variables=["context", "question"]
         )
         result = prompt_template.format(context="\n".join(context), question=query)
         state.messages[-2][1] = result
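
The PromptTemplate call reflowed above fills a two-variable template; a self-contained sketch of the same call pattern, using a stand-in template string since conv_qa_prompt_template is defined in pilot.conversation:

from langchain import PromptTemplate

# Stand-in for pilot.conversation.conv_qa_prompt_template (assumed shape)
conv_qa_prompt_template = "Context:\n{context}\n\nQuestion: {question}\nAnswer:"

prompt_template = PromptTemplate(
    template=conv_qa_prompt_template, input_variables=["context", "question"]
)
print(prompt_template.format(context="retrieved snippets", question="What is DB-GPT?"))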
@@ -290,7 +299,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
         context = context[:2000]
         prompt_template = PromptTemplate(
             template=conv_qa_prompt_template,
-            input_variables=["context", "question"]
+            input_variables=["context", "question"],
         )
         result = prompt_template.format(context="\n".join(context), question=query)
         state.messages[-2][1] = result
@@ -311,8 +320,12 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     logger.info(f"Requert: \n{payload}")
     if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-        response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"),
-                                 headers=headers, json=payload, timeout=120)
+        response = requests.post(
+            urljoin(CFG.MODEL_SERVER, "generate"),
+            headers=headers,
+            json=payload,
+            timeout=120,
+        )
         print(response.json())
         print(str(response))
@@ -321,17 +334,17 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
             text = text.rstrip()
             respObj = json.loads(text)
-            xx = respObj['response']
-            xx = xx.strip(b'\x00'.decode())
+            xx = respObj["response"]
+            xx = xx.strip(b"\x00".decode())
             respObj_ex = json.loads(xx)
-            if respObj_ex['error_code'] == 0:
+            if respObj_ex["error_code"] == 0:
                 ai_response = None
-                all_text = respObj_ex['text']
+                all_text = respObj_ex["text"]
                 ### Parse the returned text to extract the AI reply
                 tmpResp = all_text.split(state.sep)
                 last_index = -1
                 for i in range(len(tmpResp)):
-                    if tmpResp[i].find('ASSISTANT:') != -1:
+                    if tmpResp[i].find("ASSISTANT:") != -1:
                         last_index = i
                 ai_response = tmpResp[last_index]
                 ai_response = ai_response.replace("ASSISTANT:", "")
@@ -343,14 +356,20 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
state.messages[-1][-1] = "ASSISTANT未能正确回复回复结果为:\n" + all_text
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
else:
plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response)
plugin_resp = execute_ai_response_json(
auto_prompt.prompt_generator, ai_response
)
cfg.set_last_plugin_return(plugin_resp)
print(plugin_resp)
state.messages[-1][-1] = "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp
state.messages[-1][-1] = (
"Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp
)
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
except NotCommands as e:
print("命令执行:" + e.message)
state.messages[-1][-1] = "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
state.messages[-1][-1] = (
"命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
)
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
else:
# 流式输出
@@ -359,8 +378,13 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
         try:
             # Stream output
-            response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                                     headers=headers, json=payload, stream=True, timeout=20)
+            response = requests.post(
+                urljoin(CFG.MODEL_SERVER, "generate_stream"),
+                headers=headers,
+                json=payload,
+                stream=True,
+                timeout=20,
+            )
             for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                 if chunk:
                     data = json.loads(chunk.decode())
@@ -368,7 +392,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
if data["error_code"] == 0:
if "vicuna" in CFG.LLM_MODEL:
output = data["text"][skip_echo_len:].strip()
else:
@@ -381,12 +404,23 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
state.messages[-1][-1] = state.messages[-1][-1][:-1]
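
The streaming branch above consumes null-delimited JSON frames from the model server; a minimal standalone sketch of that pattern, with the base URL and payload shape assumed for illustration:

import json
import requests

def stream_generate(base_url: str, payload: dict):
    # The server separates JSON frames with b"\0", so iter_lines splits
    # on that delimiter instead of newlines.
    response = requests.post(
        base_url + "/generate_stream", json=payload, stream=True, timeout=20
    )
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                yield data["text"]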
@@ -410,8 +444,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
 block_css = (
-        code_highlight_css
-        + """
+    code_highlight_css
+    + """
 pre {
     white-space: pre-wrap; /* Since CSS 2.1 */
     white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@@ -487,7 +521,8 @@ def build_single_model_ui():
                 choices=dbs,
                 value=dbs[0] if len(models) > 0 else "",
                 interactive=True,
-                show_label=True).style(container=False)
+                show_label=True,
+            ).style(container=False)
             sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
             sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
@@ -495,7 +530,9 @@ def build_single_model_ui():
         tab_qa = gr.TabItem("知识问答", elem_id="QA")
         with tab_qa:
-            mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
+            mode = gr.Radio(
+                ["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话"
+            )
             vs_setting = gr.Accordion("配置知识库", open=False)
             mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
             with vs_setting:
@@ -504,19 +541,22 @@ def build_single_model_ui():
                 with gr.Column() as doc2vec:
                     gr.Markdown("向知识库中添加文件")
                     with gr.Tab("上传文件"):
-                        files = gr.File(label="添加文件",
-                                        file_types=[".txt", ".md", ".docx", ".pdf"],
-                                        file_count="multiple",
-                                        allow_flagged_uploads=True,
-                                        show_label=False
-                                        )
+                        files = gr.File(
+                            label="添加文件",
+                            file_types=[".txt", ".md", ".docx", ".pdf"],
+                            file_count="multiple",
+                            allow_flagged_uploads=True,
+                            show_label=False,
+                        )
                         load_file_button = gr.Button("上传并加载到知识库")
                     with gr.Tab("上传文件夹"):
-                        folder_files = gr.File(label="添加文件夹",
-                                               accept_multiple_files=True,
-                                               file_count="directory",
-                                               show_label=False)
+                        folder_files = gr.File(
+                            label="添加文件夹",
+                            accept_multiple_files=True,
+                            file_count="directory",
+                            show_label=False,
+                        )
                         load_folder_button = gr.Button("上传并加载到知识库")
     with gr.Blocks():
@@ -557,28 +597,32 @@ def build_single_model_ui():
     ).then(
         http_bot,
         [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
-        [state, chatbot] + btn_list
+        [state, chatbot] + btn_list,
     )
+    vs_add.click(
+        fn=save_vs_name, show_progress=True, inputs=[vs_name], outputs=[vs_name]
+    )
+    load_file_button.click(
+        fn=knowledge_embedding_store,
+        show_progress=True,
+        inputs=[vs_name, files],
+        outputs=[vs_name],
+    )
+    load_folder_button.click(
+        fn=knowledge_embedding_store,
+        show_progress=True,
+        inputs=[vs_name, folder_files],
+        outputs=[vs_name],
+    )
 
-    vs_add.click(fn=save_vs_name, show_progress=True,
-                 inputs=[vs_name],
-                 outputs=[vs_name])
-    load_file_button.click(fn=knowledge_embedding_store,
-                           show_progress=True,
-                           inputs=[vs_name, files],
-                           outputs=[vs_name])
-    load_folder_button.click(fn=knowledge_embedding_store,
-                             show_progress=True,
-                             inputs=[vs_name, folder_files],
-                             outputs=[vs_name])
     return state, chatbot, textbox, send_btn, button_row, parameter_row
 def build_webdemo():
     with gr.Blocks(
-            title="数据库智能助手",
-            # theme=gr.themes.Base(),
-            theme=gr.themes.Default(),
-            css=block_css,
+        title="数据库智能助手",
+        # theme=gr.themes.Base(),
+        theme=gr.themes.Default(),
+        css=block_css,
     ) as demo:
         url_params = gr.JSON(visible=False)
         (
@@ -613,26 +657,31 @@ def save_vs_name(vs_name):
     vector_store_name["vs_name"] = vs_name
     return vs_name
 def knowledge_embedding_store(vs_id, files):
     # vs_path = os.path.join(VS_ROOT_PATH, vs_id)
     if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
         os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id))
     for file in files:
         filename = os.path.split(file.name)[-1]
-        shutil.move(file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename))
+        shutil.move(
+            file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
+        )
     knowledge_embedding_client = KnowledgeEmbedding(
         file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
         model_name=LLM_MODEL_CONFIG["text2vec"],
         local_persist=False,
         vector_store_config={
             "vector_store_name": vector_store_name["vs_name"],
-            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
+            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+        },
+    )
     knowledge_embedding_client.knowledge_embedding()
     logger.info("knowledge embedding success")
     return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -671,5 +720,8 @@ if __name__ == "__main__":
     demo.queue(
         concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
     ).launch(
-        server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
+        server_name=args.host,
+        server_port=args.port,
+        share=args.share,
+        max_threads=200,
     )
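
Most hunks in this file come down to one black rule: once a call carries a trailing comma, black keeps it exploded one argument per line; without one, it collapses the call to a single line if it fits. A small illustration with hypothetical values:

def launch(server_name, server_port, share, max_threads):
    print(f"serving on {server_name}:{server_port}")

# The trailing comma after max_threads keeps this call in the exploded
# form, mirroring the demo.queue(...).launch(...) hunk above.
launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False,
    max_threads=200,
)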