Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-03 10:05:13 +00:00
Commit: update:merge dev
@@ -18,9 +18,10 @@ import requests
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)

-from pilot.commands.command import execute_ai_response_json
 from pilot.commands.command_mange import CommandRegistry
-from pilot.commands.exception_not_commands import NotCommands
+from pilot.scene.base_chat import BaseChat

 from pilot.configs.config import Config
 from pilot.configs.model_config import (
     DATASETS_DIR,
@@ -29,7 +30,6 @@ from pilot.configs.model_config import (
     LOGDIR,
     VECTOR_SEARCH_TOP_K,
 )
-from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.connections.mysql import MySQLOperator
 from pilot.conversation import (
     SeparatorStyle,
@@ -41,15 +41,22 @@ from pilot.conversation import (
 )
 from pilot.plugins import scan_plugins
 from pilot.prompts.auto_mode_prompt import AutoModePrompt
+from pilot.prompts.generator import PromptGenerator
 from pilot.server.gradio_css import code_highlight_css
 from pilot.server.gradio_patch import Chatbot as grChatbot
+from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 from pilot.utils import build_logger, server_error_msg
 from pilot.vector_store.extract_tovec import (
     get_vector_storelist,
     knownledge_tovec_st,
+    load_knownledge_from_doc,
 )

+from pilot.commands.command import execute_ai_response_json
+from pilot.scene.base import ChatScene
+from pilot.scene.chat_factory import ChatFactory

 logger = build_logger("webserver", LOGDIR + "webserver.log")
 headers = {"User-Agent": "dbgpt Client"}

@@ -69,6 +76,7 @@ priority = {"vicuna-13b": "aaa"}

 # Load plugins
 CFG = Config()
+CHAT_FACTORY = ChatFactory()

 DB_SETTINGS = {
     "user": CFG.LOCAL_DB_USER,
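The `CHAT_FACTORY` singleton registered here is what `http_bot` uses further down to turn a scene value into a concrete `BaseChat` implementation via `CHAT_FACTORY.get_implementation(scene.value, **chat_param)`. The diff only shows that call site, so the following is a minimal sketch of one way such a factory could be wired; everything except the `get_implementation(scene_value, **chat_param)` signature is an assumption for illustration, not DB-GPT's actual code:

```python
# Hypothetical sketch of a scene -> chat-class factory. Only the
# get_implementation(scene_value, **chat_param) call is confirmed by this diff.
class BaseChat:
    def __init__(self, **chat_param):
        self.chat_param = chat_param

    def call(self):
        raise NotImplementedError

    def current_ai_response(self) -> str:
        raise NotImplementedError


class ChatFactory:
    _registry: dict = {}  # maps a scene value to a BaseChat subclass

    @classmethod
    def register(cls, scene_value: str):
        # Illustrative registration decorator; the real project may wire this differently.
        def decorator(chat_cls):
            cls._registry[scene_value] = chat_cls
            return chat_cls

        return decorator

    def get_implementation(self, scene_value: str, **chat_param) -> BaseChat:
        try:
            return self._registry[scene_value](**chat_param)
        except KeyError:
            raise ValueError(f"no chat implementation registered for {scene_value!r}")
```

The payoff of this shape shows up in the `http_bot` hunk below: the webserver no longer needs to know which chat class handles which scene.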
@@ -125,6 +133,10 @@ def load_demo(url_params, request: gr.Request):
         gr.Dropdown.update(choices=dbs)

     state = default_conversation.copy()

+    unique_id = uuid.uuid1()
+    state.conv_id = str(unique_id)
+
     return (
         state,
         dropdown_update,
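A side note on the two UUID flavors this commit touches: `load_demo` seeds `state.conv_id` with `uuid.uuid1()` (time- and node-based), while the first-round branch in `http_bot` below uses `uuid.uuid4().hex` (random, rendered without dashes). Both yield unique ids; a quick self-contained illustration:

```python
import uuid

conv_id_v1 = str(uuid.uuid1())  # time/node based, dashed form
conv_id_v4 = uuid.uuid4().hex   # random, 32 hex characters, no dashes
print(conv_id_v1)
print(conv_id_v4)
```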
@@ -166,6 +178,8 @@ def add_text(state, text, request: gr.Request):
     state.append_message(state.roles[0], text)
     state.append_message(state.roles[1], None)
     state.skip_next = False
+    ### TODO
+    state.last_user_input = text
     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5

@@ -180,255 +194,271 @@ def post_process_code(code):
     return code


+def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
+    if mode == conversation_types["default_knownledge"] and not db_selector:
+        return ChatScene.ChatKnowledge
+    elif mode == conversation_types["custome"] and not db_selector:
+        return ChatScene.ChatNewKnowledge
+    elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
+        return ChatScene.ChatWithDb
+
+    elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
+        return ChatScene.ChatExecution
+    else:
+        return ChatScene.ChatNormal


 def http_bot(
     state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
 ):
-    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-        print("AUTO DB-GPT模式.")
-    if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
-        print("标准DB-GPT模式.")
-    print("是否是AUTO-GPT模式.", autogpt)
-
+    logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
     start_tstamp = time.time()
+    scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
+    print(f"当前对话模式:{scene.value}")
     model_name = CFG.LLM_MODEL

-    dbname = db_selector
-    # TODO: this request needs to splice in the existing knowledge base so answers are grounded in it; the prompt still needs further tuning
-    if state.skip_next:
-        # This generate call is skipped due to invalid inputs
-        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-        return
-
-    cfg = Config()
-    auto_prompt = AutoModePrompt()
-    auto_prompt.command_registry = cfg.command_registry
-
-    # TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
-    if len(state.messages) == state.offset + 2:
-        query = state.messages[-2][1]
-        # The first round of conversation needs the hint prompt added
-        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-            # The first round in AutoGPT mode needs a dedicated prompt
-            system_prompt = auto_prompt.construct_first_prompt(
-                fisrt_message=[query], db_schemes=gen_sqlgen_conversation(dbname)
-            )
-            logger.info("[TEST]:" + system_prompt)
-            template_name = "auto_dbgpt_one_shot"
-            new_state = conv_templates[template_name].copy()
-            new_state.append_message(role="USER", message=system_prompt)
-            # new_state.append_message(new_state.roles[0], query)
-            new_state.append_message(new_state.roles[1], None)
-        else:
-            template_name = "conv_one_shot"
-            new_state = conv_templates[template_name].copy()
-            # Add a context hint to the prompt so the chat uses existing knowledge; should the hint go only into the first round, or into every round?
-            # If the user's questions jump around a lot, the hint should be added in every round.
-            if db_selector:
-                new_state.append_message(
-                    new_state.roles[0], gen_sqlgen_conversation(dbname) + query
-                )
-                new_state.append_message(new_state.roles[1], None)
-            else:
-                new_state.append_message(new_state.roles[0], query)
-                new_state.append_message(new_state.roles[1], None)
-
-        new_state.conv_id = uuid.uuid4().hex
-        state = new_state
-    else:
-        ### Follow-up conversation
-        query = state.messages[-2][1]
-        # The first round of conversation needs the hint prompt added
-        if mode == conversation_types["custome"]:
-            template_name = "conv_one_shot"
-            new_state = conv_templates[template_name].copy()
-            # Add a context hint to the prompt so the chat uses existing knowledge; should the hint go only into the first round, or into every round?
-            # If the user's questions jump around a lot, the hint should be added in every round.
-            if db_selector:
-                new_state.append_message(
-                    new_state.roles[0], gen_sqlgen_conversation(dbname) + query
-                )
-                new_state.append_message(new_state.roles[1], None)
-            else:
-                new_state.append_message(new_state.roles[0], query)
-                new_state.append_message(new_state.roles[1], None)
-            state = new_state
-        elif sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-            ## Get the return value of the last plugin call
-            follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
-            state.messages[0][0] = ""
-            state.messages[0][1] = ""
-            state.messages[-2][1] = follow_up_prompt
-    prompt = state.get_prompt()
-    skip_echo_len = len(prompt.replace("</s>", " ")) + 1
-    if mode == conversation_types["default_knownledge"] and not db_selector:
-        vector_store_config = {
-            "vector_store_name": "default",
-            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+    if ChatScene.ChatWithDb == scene:
+        logger.info("基于DB对话走新的模式!")
+        chat_param = {
+            "chat_session_id": state.conv_id,
+            "db_name": db_selector,
+            "user_input": state.last_user_input,
         }
-        knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
-            model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
-            vector_store_config=vector_store_config,
-        )
-        query = state.messages[-2][1]
-        docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
-        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
-        state.messages[-2][1] = query
-        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
-
-    if mode == conversation_types["custome"] and not db_selector:
-        print("vector store name: ", vector_store_name["vs_name"])
-        vector_store_config = {
-            "vector_store_name": vector_store_name["vs_name"],
-            "text_field": "content",
-            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
-        }
-        knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
-            model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
-            vector_store_config=vector_store_config,
-        )
-        query = state.messages[-2][1]
-        docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
-        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
-
-        state.messages[-2][1] = query
-        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
-
-    # Make requests
-    payload = {
-        "model": model_name,
-        "prompt": prompt,
-        "temperature": float(temperature),
-        "max_new_tokens": int(max_new_tokens),
-        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
-    }
-    logger.info(f"Requert: \n{payload}")
-
-    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-        response = requests.post(
-            urljoin(CFG.MODEL_SERVER, "generate"),
-            headers=headers,
-            json=payload,
-            timeout=120,
-        )
-
-        print(response.json())
-        print(str(response))
-        try:
-            text = response.text.strip()
-            text = text.rstrip()
-            respObj = json.loads(text)
-
-            xx = respObj["response"]
-            xx = xx.strip(b"\x00".decode())
-            respObj_ex = json.loads(xx)
-            if respObj_ex["error_code"] == 0:
-                ai_response = None
-                all_text = respObj_ex["text"]
-                ### Parse the returned text and extract the AI reply part
-                tmpResp = all_text.split(state.sep)
-                last_index = -1
-                for i in range(len(tmpResp)):
-                    if tmpResp[i].find("ASSISTANT:") != -1:
-                        last_index = i
-                ai_response = tmpResp[last_index]
-                ai_response = ai_response.replace("ASSISTANT:", "")
-                ai_response = ai_response.replace("\n", "")
-                ai_response = ai_response.replace("\_", "_")
-
-                print(ai_response)
-                if ai_response == None:
-                    state.messages[-1][-1] = "ASSISTANT未能正确回复,回复结果为:\n" + all_text
-                    yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-                else:
-                    plugin_resp = execute_ai_response_json(
-                        auto_prompt.prompt_generator, ai_response
-                    )
-                    cfg.set_last_plugin_return(plugin_resp)
-                    print(plugin_resp)
-                    state.messages[-1][-1] = (
-                        "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp
-                    )
-                    yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-        except NotCommands as e:
-            print("命令执行:" + e.message)
-            state.messages[-1][-1] = (
-                "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
-            )
-            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-    else:
-        # Streaming output
-        state.messages[-1][-1] = "▌"
-        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
-
-        try:
-            # Stream output
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                headers=headers,
-                json=payload,
-                stream=True,
-                timeout=20,
-            )
-            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-                if chunk:
-                    data = json.loads(chunk.decode())
-
-                    """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
-                    """
-                    if data["error_code"] == 0:
-                        if "vicuna" in CFG.LLM_MODEL:
-                            output = data["text"][skip_echo_len:].strip()
-                        else:
-                            output = data["text"].strip()
-
-                        output = post_process_code(output)
-                        state.messages[-1][-1] = output + "▌"
-                        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
-                    else:
-                        output = data["text"] + f" (error_code: {data['error_code']})"
-                        state.messages[-1][-1] = output
-                        yield (state, state.to_gradio_chatbot()) + (
-                            disable_btn,
-                            disable_btn,
-                            disable_btn,
-                            enable_btn,
-                            enable_btn,
-                        )
-                        return
-
-        except requests.exceptions.RequestException as e:
-            state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-            yield (state, state.to_gradio_chatbot()) + (
-                disable_btn,
-                disable_btn,
-                disable_btn,
-                enable_btn,
-                enable_btn,
-            )
-            return
-
-        state.messages[-1][-1] = state.messages[-1][-1][:-1]
+        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
+        chat.call()
+        state.messages[-1][-1] = f"{chat.current_ai_response()}"
         yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

-    # Record the run log
-    finish_tstamp = time.time()
-    logger.info(f"{output}")
-
-    with open(get_conv_log_filename(), "a") as fout:
-        data = {
-            "tstamp": round(finish_tstamp, 4),
-            "type": "chat",
-            "model": model_name,
-            "start": round(start_tstamp, 4),
-            "finish": round(start_tstamp, 4),
-            "state": state.dict(),
-            "ip": request.client.host,
-        }
-        fout.write(json.dumps(data) + "\n")
+    else:
+        dbname = db_selector
+        # TODO: this request needs to splice in the existing knowledge base so answers are grounded in it; the prompt still needs further tuning
+        if state.skip_next:
+            # This generate call is skipped due to invalid inputs
+            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+            return
+
+        cfg = Config()
+        auto_prompt = AutoModePrompt()
+        auto_prompt.command_registry = cfg.command_registry
+
+        # TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
+        if len(state.messages) == state.offset + 2:
+            query = state.messages[-2][1]
+            # The first round of conversation needs the hint prompt added
+            if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+                # The first round in AutoGPT mode needs a dedicated prompt
+                system_prompt = auto_prompt.construct_first_prompt(
+                    fisrt_message=[query], db_schemes=gen_sqlgen_conversation(dbname)
+                )
+                logger.info("[TEST]:" + system_prompt)
+                template_name = "auto_dbgpt_one_shot"
+                new_state = conv_templates[template_name].copy()
+                new_state.append_message(role="USER", message=system_prompt)
+                # new_state.append_message(new_state.roles[0], query)
+                new_state.append_message(new_state.roles[1], None)
+            else:
+                template_name = "conv_one_shot"
+                new_state = conv_templates[template_name].copy()
+                # Add a context hint to the prompt so the chat uses existing knowledge; should the hint go only into the first round, or into every round?
+                # If the user's questions jump around a lot, the hint should be added in every round.
+                if db_selector:
+                    new_state.append_message(
+                        new_state.roles[0], gen_sqlgen_conversation(dbname) + query
+                    )
+                    new_state.append_message(new_state.roles[1], None)
+                else:
+                    new_state.append_message(new_state.roles[0], query)
+                    new_state.append_message(new_state.roles[1], None)
+
+            new_state.conv_id = uuid.uuid4().hex
+            state = new_state
+        else:
+            ### Follow-up conversation
+            query = state.messages[-2][1]
+            # The first round of conversation needs the hint prompt added
+            if mode == conversation_types["custome"]:
+                template_name = "conv_one_shot"
+                new_state = conv_templates[template_name].copy()
+                # Add a context hint to the prompt so the chat uses existing knowledge; should the hint go only into the first round, or into every round?
+                # If the user's questions jump around a lot, the hint should be added in every round.
+                if db_selector:
+                    new_state.append_message(
+                        new_state.roles[0], gen_sqlgen_conversation(dbname) + query
+                    )
+                    new_state.append_message(new_state.roles[1], None)
+                else:
+                    new_state.append_message(new_state.roles[0], query)
+                    new_state.append_message(new_state.roles[1], None)
+                state = new_state
+            elif sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+                ## Get the return value of the last plugin call
+                follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
+                state.messages[0][0] = ""
+                state.messages[0][1] = ""
+                state.messages[-2][1] = follow_up_prompt
+        prompt = state.get_prompt()
+        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+        if mode == conversation_types["default_knownledge"] and not db_selector:
+            vector_store_config = {
+                "vector_store_name": "default",
+                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+            }
+            knowledge_embedding_client = KnowledgeEmbedding(
+                file_path="",
+                model_name=LLM_MODEL_CONFIG["text2vec"],
+                local_persist=False,
+                vector_store_config=vector_store_config,
+            )
+            query = state.messages[-2][1]
+            docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
+            prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
+            state.messages[-2][1] = query
+            skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+
+        if mode == conversation_types["custome"] and not db_selector:
+            print("vector store name: ", vector_store_name["vs_name"])
+            vector_store_config = {
+                "vector_store_name": vector_store_name["vs_name"],
+                "text_field": "content",
+                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+            }
+            knowledge_embedding_client = KnowledgeEmbedding(
+                file_path="",
+                model_name=LLM_MODEL_CONFIG["text2vec"],
+                local_persist=False,
+                vector_store_config=vector_store_config,
+            )
+            query = state.messages[-2][1]
+            docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
+            prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)

+            state.messages[-2][1] = query
+            skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+
+        # Make requests
+        payload = {
+            "model": model_name,
+            "prompt": prompt,
+            "temperature": float(temperature),
+            "max_new_tokens": int(max_new_tokens),
+            "stop": state.sep
+            if state.sep_style == SeparatorStyle.SINGLE
+            else state.sep2,
+        }
+        logger.info(f"Requert: \n{payload}")
+
+        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+            response = requests.post(
+                urljoin(CFG.MODEL_SERVER, "generate"),
+                headers=headers,
+                json=payload,
+                timeout=120,
+            )
+
+            print(response.json())
+            print(str(response))
+            try:
+                text = response.text.strip()
+                text = text.rstrip()
+                respObj = json.loads(text)
+
+                xx = respObj["response"]
+                xx = xx.strip(b"\x00".decode())
+                respObj_ex = json.loads(xx)
+                if respObj_ex["error_code"] == 0:
+                    ai_response = None
+                    all_text = respObj_ex["text"]
+                    ### Parse the returned text and extract the AI reply part
+                    tmpResp = all_text.split(state.sep)
+                    last_index = -1
+                    for i in range(len(tmpResp)):
+                        if tmpResp[i].find("ASSISTANT:") != -1:
+                            last_index = i
+                    ai_response = tmpResp[last_index]
+                    ai_response = ai_response.replace("ASSISTANT:", "")
+                    ai_response = ai_response.replace("\n", "")
+                    ai_response = ai_response.replace("\_", "_")
+
+                    print(ai_response)
+                    if ai_response == None:
+                        state.messages[-1][-1] = "ASSISTANT未能正确回复,回复结果为:\n" + all_text
+                        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+                    else:
+                        plugin_resp = execute_ai_response_json(
+                            auto_prompt.prompt_generator, ai_response
+                        )
+                        cfg.set_last_plugin_return(plugin_resp)
+                        print(plugin_resp)
+                        state.messages[-1][-1] = (
+                            "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp
+                        )
+                        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+            except NotCommands as e:
+                print("命令执行:" + e.message)
+                state.messages[-1][-1] = (
+                    "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
+                )
+                yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+        else:
+            # Streaming output
+            state.messages[-1][-1] = "▌"
+            yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+            try:
+                # Stream output
+                response = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate_stream"),
+                    headers=headers,
+                    json=payload,
+                    stream=True,
+                    timeout=20,
+                )
+                for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+                    if chunk:
+                        data = json.loads(chunk.decode())
+
+                        """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
+                        """
+                        if data["error_code"] == 0:
+                            if "vicuna" in CFG.LLM_MODEL:
+                                output = data["text"][skip_echo_len:].strip()
+                            else:
+                                output = data["text"].strip()
+
+                            output = post_process_code(output)
+                            state.messages[-1][-1] = output + "▌"
+                            yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+                        else:
+                            output = data["text"] + f" (error_code: {data['error_code']})"
+                            state.messages[-1][-1] = output
+                            yield (state, state.to_gradio_chatbot()) + (
+                                disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+                            return
+
+            except requests.exceptions.RequestException as e:
+                state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
+                yield (state, state.to_gradio_chatbot()) + (
+                    disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+                return
+
+            state.messages[-1][-1] = state.messages[-1][-1][:-1]
+            yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+        # Record the run log
+        finish_tstamp = time.time()
+        logger.info(f"{output}")
+
+        with open(get_conv_log_filename(), "a") as fout:
+            data = {
+                "tstamp": round(finish_tstamp, 4),
+                "type": "chat",
+                "model": model_name,
+                "start": round(start_tstamp, 4),
+                "finish": round(start_tstamp, 4),
+                "state": state.dict(),
+                "ip": request.client.host,
+            }
+            fout.write(json.dumps(data) + "\n")


 block_css = (
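Taken together, the pieces above form the new happy path for a database chat: `add_text` stashes the raw question in `state.last_user_input`, `load_demo` seeds `state.conv_id`, `get_chat_mode` resolves the scene, and the factory-built chat object performs the prompt building, model call, and execution that the deleted inline code used to do. A condensed restatement of that flow, using only names that appear in this diff and assuming the webserver module's globals are in scope:

```python
def dispatch_chat_with_db(state, mode, sql_mode, db_selector):
    # Condensed restatement of the new ChatWithDb branch in http_bot above;
    # not standalone code -- it assumes the webserver module's globals.
    scene = get_chat_mode(mode, sql_mode, db_selector)  # -> ChatScene.ChatWithDb
    if ChatScene.ChatWithDb == scene:
        chat_param = {
            "chat_session_id": state.conv_id,      # seeded in load_demo via uuid.uuid1()
            "db_name": db_selector,
            "user_input": state.last_user_input,   # captured in add_text
        }
        chat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
        chat.call()                                # prompt build + LLM call + execution
        state.messages[-1][-1] = f"{chat.current_ai_response()}"
```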
@@ -685,7 +715,8 @@ if __name__ == "__main__":
     # Initialize configuration
     cfg = Config()

-    # dbs = get_database_list()
+    dbs = cfg.local_db.get_database_list()

     cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

     # Load the executable commands provided by plugins