mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-21 17:54:58 +00:00

add plugin mode

This commit is contained in:
parent dd5fc529e2
commit 5150cfcf55
pilot/commands/built_in/__init__.py (new empty file)
@@ -1,29 +0,0 @@
-from typing import Optional
-
-from pilot.configs.config import Config
-from pilot.prompts.generator import PromptGenerator
-from pilot.prompts.prompt import build_default_prompt_generator
-
-
-class CommandsLoad:
-    """
-    Load Plugins Commands Info , help build system prompt!
-    """
-
-    def __init__(self) -> None:
-        self.command_registry = None
-
-    def getCommandInfos(
-        self, prompt_generator: Optional[PromptGenerator] = None
-    ) -> str:
-        cfg = Config()
-        if prompt_generator is None:
-            prompt_generator = build_default_prompt_generator()
-        for plugin in cfg.plugins:
-            if not plugin.can_handle_post_prompt():
-                continue
-            prompt_generator = plugin.post_prompt(prompt_generator)
-        self.prompt_generator = prompt_generator
-        command_infos = ""
-        command_infos += f"\n\n{prompt_generator.commands()}"
-        return command_infos
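For reference, the `CommandsLoad` helper deleted above folded each plugin's `post_prompt` hook into a `PromptGenerator` and rendered the resulting command list for the system prompt. A condensed sketch of that flow, with the interfaces inferred from the deleted code itself rather than from this commit's new code:

# Sketch only; the PromptGenerator and plugin hook shapes are inferred
# from the deleted CommandsLoad code above.
prompt_generator = build_default_prompt_generator()
for plugin in Config().plugins:
    if plugin.can_handle_post_prompt():
        # Each plugin may register its own commands on the generator.
        prompt_generator = plugin.post_prompt(prompt_generator)
command_infos = f"\n\n{prompt_generator.commands()}"  # appended to the system prompt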
@@ -263,6 +263,14 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回
 # """
 default_conversation = conv_one_shot
 
+chat_mode_title = {
+    "sql_generate_diagnostics": get_lang_text("sql_analysis_and_diagnosis"),
+    "chat_use_plugin": get_lang_text("chat_use_plugin"),
+    "knowledge_qa": get_lang_text("knowledge_qa"),
+}
+
+
 conversation_sql_mode = {
     "auto_execute_ai_response": get_lang_text("sql_generate_mode_direct"),
     "dont_execute_ai_response": get_lang_text("sql_generate_mode_none"),
@@ -274,7 +282,7 @@ conversation_types = {
         "knowledge_qa_type_default_knowledge_base_dialogue"
     ),
     "custome": get_lang_text("knowledge_qa_type_add_knowledge_base_dialogue"),
-    "auto_execute_plugin": get_lang_text("dialogue_use_plugin"),
+    "url": get_lang_text("knowledge_qa_type_url_knowledge_dialogue"),
 }
 
 conv_templates = {
@@ -14,17 +14,22 @@ lang_dicts = {
         "knowledge_qa_type_llm_native_dialogue": "LLM原生对话",
         "knowledge_qa_type_default_knowledge_base_dialogue": "默认知识库对话",
         "knowledge_qa_type_add_knowledge_base_dialogue": "新增知识库对话",
-        "dialogue_use_plugin": "对话使用插件",
+        "knowledge_qa_type_url_knowledge_dialogue": "URL网页知识对话",
         "create_knowledge_base": "新建知识库",
         "sql_schema_info": "数据库{}的Schema信息如下: {}\n",
         "current_dialogue_mode": "当前对话模式",
         "database_smart_assistant": "数据库智能助手",
         "sql_vs_setting": "自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力",
         "knowledge_qa": "知识问答",
+        "chat_use_plugin": "插件模式",
+        "dialogue_use_plugin": "对话使用插件",
+        "select_plugin": "选择插件",
         "configure_knowledge_base": "配置知识库",
         "new_klg_name": "新知识库名称",
+        "url_input_label": "输入网页地址",
         "add_as_new_klg": "添加为新知识库",
         "add_file_to_klg": "向知识库中添加文件",
         "upload_file": "上传文件",
         "add_file": "添加文件",
         "upload_and_load_to_klg": "上传并加载到知识库",
@@ -47,14 +52,18 @@ lang_dicts = {
         "knowledge_qa_type_llm_native_dialogue": "LLM native dialogue",
         "knowledge_qa_type_default_knowledge_base_dialogue": "Default documents",
         "knowledge_qa_type_add_knowledge_base_dialogue": "Added documents",
+        "knowledge_qa_type_url_knowledge_dialogue": "Chat with url",
         "dialogue_use_plugin": "Dialogue Extension",
         "create_knowledge_base": "Create Knowledge Base",
         "sql_schema_info": "the schema information of database {}: {}\n",
         "current_dialogue_mode": "Current dialogue mode",
         "database_smart_assistant": "Database smart assistant",
         "sql_vs_setting": "In the automatic execution mode, DB-GPT can have the ability to execute SQL, read data from the network, automatically store and learn",
+        "chat_use_plugin": "Plugin Mode",
+        "select_plugin": "Select Plugin",
         "knowledge_qa": "Documents QA",
         "configure_knowledge_base": "Configure Documents",
+        "url_input_label": "Please input url",
         "new_klg_name": "New document name",
         "add_as_new_klg": "Add as new documents",
         "add_file_to_klg": "Add file to documents",
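Both locales now carry the same keys for the plugin tab and the URL dialogue. For orientation, a minimal sketch of how `get_lang_text` presumably resolves these keys; the language-selection mechanism is an assumption, since it is not shown in this diff:

# Assumption: lang_dicts maps a language code to the key/value tables above,
# and the active language comes from configuration.
LANGUAGE = "en"  # hypothetical setting; "zh" would select the Chinese table

def get_lang_text(key: str) -> str:
    return lang_dicts[LANGUAGE][key]

get_lang_text("chat_use_plugin")  # -> "Plugin Mode" (en) / "插件模式" (zh)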
@@ -18,11 +18,14 @@ import re
 
 from pydantic import BaseModel, Extra, Field, root_validator
 from pilot.configs.model_config import LOGDIR
-from pilot.prompts.base import PromptValue
+from pilot.configs.config import Config
 
 T = TypeVar("T")
 logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+
+CFG = Config()
 
 
 class BaseOutputParser(ABC):
     """Class to parse the output of an LLM call.
@@ -33,9 +36,39 @@ class BaseOutputParser(ABC):
         self.sep = sep
         self.is_stream_out = is_stream_out
 
+    def __post_process_code(self, code):
+        sep = "\n```"
+        if sep in code:
+            blocks = code.split(sep)
+            if len(blocks) % 2 == 1:
+                for i in range(1, len(blocks), 2):
+                    blocks[i] = blocks[i].replace("\\_", "_")
+            code = sep.join(blocks)
+        return code
+
     # TODO: bind this to the model later
     def _parse_model_stream_resp(self, response, sep: str):
-        pass
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                # TODO: multi-model output handling; rewrite this for multiple
+                # models using an adapter pattern.
+                if data["error_code"] == 0:
+                    if "vicuna" in CFG.LLM_MODEL:
+                        output = data["text"].strip()
+                    else:
+                        output = data["text"].strip()
+
+                    output = self.__post_process_code(output)
+                    yield output
+                else:
+                    output = (
+                        data["text"] + f" (error_code: {data['error_code']})"
+                    )
+                    yield output
 
     def _parse_model_nostream_resp(self, response, sep: str):
         text = response.text.strip()
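`_parse_model_stream_resp` is now a generator: each NUL-delimited JSON chunk from the model server yields one progressively longer snapshot of the answer. A minimal consumer sketch; the endpoint URL, payload, and the concrete parser subclass are assumptions for illustration:

import requests

# Hypothetical call site; 'parser' stands in for a concrete BaseOutputParser
# subclass configured with is_stream_out=True.
response = requests.post(
    "http://127.0.0.1:8000/generate_stream",  # illustrative model-server URL
    json={"prompt": "...", "stop": "###"},
    stream=True,
    timeout=120,
)
for partial_text in parser._parse_model_stream_resp(response, sep="###"):
    print(partial_text)  # each value supersedes the previous partial answer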
@@ -64,7 +97,7 @@ class BaseOutputParser(ABC):
         else:
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])
 
-    def parse_model_server_out(self, response) -> str:
+    def parse_model_server_out(self, response):
         """
         parse the model server http response
         Args:
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 import datetime
 import traceback
+import json
 from pydantic import BaseModel, Field, root_validator, validator, Extra
 from typing import (
     Any,
@@ -41,6 +42,7 @@ headers = {"User-Agent": "dbgpt Client"}
 CFG = Config()
 
 
+
 class BaseChat(ABC):
     chat_scene: str = None
     llm_model: Any = None
@@ -89,8 +91,7 @@ class BaseChat(ABC):
     def do_with_prompt_response(self, prompt_response):
         pass
 
-    def call(self):
-
+    def call(self, show_fn, state):
         input_values = self.generate_input_values()
 
         ### Chat sequence advance
@@ -164,6 +165,7 @@ class BaseChat(ABC):
                             prompt_define_response, result
                         )
                     )
+                    show_fn(state, self.current_ai_response())
                 else:
                     response = requests.post(
                         urljoin(CFG.MODEL_SERVER, "generate_stream"),
@@ -171,9 +173,14 @@ class BaseChat(ABC):
                         json=payload,
                         timeout=120,
                     )
-                    # TODO
+                    show_fn(state, "▌")
+                    ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
+                    show_info = ""
+                    for resp_text_chunk in ai_response_text:
+                        show_info = resp_text_chunk
+                        show_fn(state, resp_text_chunk + "▌")
+
+                    self.current_message.add_ai_message(show_info)
+
         except Exception as e:
             print(traceback.format_exc())
@@ -181,9 +188,11 @@ class BaseChat(ABC):
             self.current_message.add_view_message(
                 f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
             )
+            show_fn(state, self.current_ai_response())
         ### Store the conversation history
         self.memory.append(self.current_message)
 
 
     def generate_llm_text(self) -> str:
         text = self.prompt_template.template_define + self.prompt_template.sep
         ### Process the history messages first
@@ -229,8 +238,6 @@ class BaseChat(ABC):
         return text
 
-    def chat_show(self):
-        pass
-
     # Kept temporarily for frontend compatibility
     def current_ai_response(self) -> str:
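With this change, `BaseChat.call` no longer returns the answer: it streams every partial result through a caller-supplied `show_fn(state, message)` callback and records the final text itself via `add_ai_message`. A minimal sketch of the callback contract; the names mirror the Gradio call site added later in this commit:

# Sketch of the contract BaseChat.call now expects from its caller.
def show_fn(state, message):
    # Overwrite the last chatbot turn with the newest partial answer;
    # the trailing "▌" cursor is appended by call() itself while streaming.
    state.messages[-1][-1] = f"{message}"

chat.call(show_fn=show_fn, state=state)  # 'chat' is some BaseChat implementation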
pilot/scene/chat_knowledge/custom/__init__.py (new empty file)
pilot/scene/chat_knowledge/default/__init__.py (new empty file)
pilot/scene/chat_knowledge/url/__init__.py (new empty file)
@@ -37,6 +37,7 @@ from pilot.conversation import (
     conv_templates,
     conversation_sql_mode,
     conversation_types,
+    chat_mode_title,
     default_conversation,
 )
 from pilot.common.plugins import scan_plugins
@@ -95,6 +96,11 @@ default_knowledge_base_dialogue = get_lang_text(
 add_knowledge_base_dialogue = get_lang_text(
     "knowledge_qa_type_add_knowledge_base_dialogue"
 )
+
+url_knowledge_dialogue = get_lang_text(
+    "knowledge_qa_type_url_knowledge_dialogue"
+)
+
 knowledge_qa_type_list = [
     llm_native_dialogue,
     default_knowledge_base_dialogue,
@@ -115,7 +121,7 @@ def gen_sqlgen_conversation(dbname):
     db_connect = CFG.local_db.get_session(dbname)
     schemas = CFG.local_db.table_simple_info(db_connect)
     for s in schemas:
-        message += s["schema_info"] + ";"
+        message += s + ";"
     return get_lang_text("sql_schema_info").format(dbname, message)
@@ -211,9 +217,9 @@ def post_process_code(code):
 
 
 def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene:
-    if "插件模式" == selected:
+    if chat_mode_title['chat_use_plugin'] == selected:
         return ChatScene.ChatExecution
-    elif "知识问答" == selected:
+    elif chat_mode_title['knowledge_qa'] == selected:
         if mode == conversation_types["default_knownledge"]:
             return ChatScene.ChatKnowledge
         elif mode == conversation_types["custome"]:
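The scene dispatch now compares against the localized `chat_mode_title` entries instead of hard-coded Chinese strings, so the same check works in both locales. A small usage sketch; the argument values are illustrative only:

# Hypothetical call: selecting the plugin tab routes to the execution scene.
selected = chat_mode_title["chat_use_plugin"]  # e.g. "Plugin Mode" in the en locale
scene = get_chat_mode(selected, mode=None, sql_mode=None, db_selector=None)
assert scene == ChatScene.ChatExecution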
@@ -226,37 +232,50 @@ def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene:
 
 
 def http_bot(
-    state, selected, plugin_selector, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
+    state, selected, plugin_selector, mode, sql_mode, db_selector, url_input, temperature, max_new_tokens, request: gr.Request
 ):
     logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}")
     start_tstamp = time.time()
-    scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
-    print(f"当前对话模式:{scene.value}")
+    scene: ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector)
+    print(f"now chat scene:{scene.value}")
     model_name = CFG.LLM_MODEL
 
+    def chatbot_callback(state, message):
+        print(f"chatbot_callback:{message}")
+        state.messages[-1][-1] = f"{message}"
+        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
     if ChatScene.ChatWithDb == scene:
-        logger.info("基于DB对话走新的模式!")
+        logger.info("chat with db mode use new architecture design!")
         chat_param = {
             "chat_session_id": state.conv_id,
             "db_name": db_selector,
             "user_input": state.last_user_input,
         }
         chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
-        chat.call()
-        state.messages[-1][-1] = f"{chat.current_ai_response()}"
-        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+        chat.call(show_fn=chatbot_callback, state=state)
     elif ChatScene.ChatExecution == scene:
-        logger.info("插件模式对话走新的模式!")
+        logger.info("plugin mode use new architecture design!")
         chat_param = {
             "chat_session_id": state.conv_id,
             "plugin_selector": plugin_selector,
             "user_input": state.last_user_input,
         }
         chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
-        chat.call()
-        state.messages[-1][-1] = f"{chat.current_ai_response()}"
-        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+        chat.call(chatbot_callback, state)
+
+        # def generate_numbers():
+        #     for i in range(10):
+        #         time.sleep(0.5)
+        #         yield f"Message:{i}"
+        #
+        # def showMessage(message):
+        #     return message
+        #
+        # for n in generate_numbers():
+        #     state.messages[-1][-1] = n + "▌"
+        #     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
     else:
 
         dbname = db_selector
@@ -284,30 +303,45 @@ def http_bot(
 
         new_state.conv_id = uuid.uuid4().hex
         state = new_state
+    else:
+        ### Follow-up turns of the conversation
+        query = state.messages[-2][1]
+        # The first turn needs the guiding prompt added
+        if mode == conversation_types["custome"]:
+            template_name = "conv_one_shot"
+            new_state = conv_templates[template_name].copy()
+            # Context hints are added to the prompt for knowledge-grounded dialogue;
+            # should the hint go only in the first turn, or be added every turn?
+            # If the user's questions vary widely between turns, add it every turn.
+            if db_selector:
+                new_state.append_message(
+                    new_state.roles[0], gen_sqlgen_conversation(dbname) + query
+                )
+                new_state.append_message(new_state.roles[1], None)
+            else:
+                new_state.append_message(new_state.roles[0], query)
+                new_state.append_message(new_state.roles[1], None)
+            state = new_state
 
     prompt = state.get_prompt()
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1
     if mode == conversation_types["default_knownledge"] and not db_selector:
+        vector_store_config = {
+            "vector_store_name": "default",
+            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+        }
+        knowledge_embedding_client = KnowledgeEmbedding(
+            file_path="",
+            model_name=LLM_MODEL_CONFIG["text2vec"],
+            local_persist=False,
+            vector_store_config=vector_store_config,
+        )
         query = state.messages[-2][1]
-        knqa = KnownLedgeBaseQA()
-        state.messages[-2][1] = knqa.get_similar_answer(query)
-        prompt = state.get_prompt()
+        docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
+        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
 
         state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
     if mode == conversation_types["custome"] and not db_selector:
-        persist_dir = os.path.join(
-            KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb"
-        )
-        print("向量数据库持久化地址: ", persist_dir)
-        knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
-            model_name=LLM_MODEL_CONFIG["sentence-transforms"],
-            vector_store_config={
-                "vector_store_name": vector_store_name["vs_name"],
-                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
-            },
-        )
         print("vector store name: ", vector_store_name["vs_name"])
         vector_store_config = {
             "vector_store_name": vector_store_name["vs_name"],
@@ -327,6 +361,27 @@ def http_bot(
         state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
+    if mode == conversation_types["url"] and url_input:
+        print("url: ", url_input)
+        vector_store_config = {
+            "vector_store_name": url_input,
+            "text_field": "content",
+            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+        }
+        knowledge_embedding_client = KnowledgeEmbedding(
+            file_path=url_input,
+            model_name=LLM_MODEL_CONFIG["text2vec"],
+            local_persist=False,
+            vector_store_config=vector_store_config,
+        )
+
+        query = state.messages[-2][1]
+        docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
+        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
+
+        state.messages[-2][1] = query
+        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+
     # Make requests
     payload = {
         "model": model_name,
@@ -343,34 +398,6 @@ def http_bot(
     state.messages[-1][-1] = "▌"
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
 
-    try:
-        # Stream output
-        response = requests.post(
-            urljoin(CFG.MODEL_SERVER, "generate_stream"),
-            headers=headers,
-            json=payload,
-            stream=True,
-            timeout=20,
-        )
-        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-            if chunk:
-                data = json.loads(chunk.decode())
-                if data["error_code"] == 0:
-                    output = data["text"][skip_echo_len:].strip()
-                    output = post_process_code(output)
-                    state.messages[-1][-1] = output + "▌"
-                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
-                else:
-                    output = data["text"] + f" (error_code: {data['error_code']})"
-                    state.messages[-1][-1] = output
-                    yield (state, state.to_gradio_chatbot()) + (
-                        disable_btn,
-                        disable_btn,
-                        disable_btn,
-                        enable_btn,
-                        enable_btn,
-                    )
-                    return
     try:
         # Stream output
         response = requests.post(
@@ -421,16 +448,6 @@ def http_bot(
                         enable_btn,
                     )
                     return
-        except requests.exceptions.RequestException as e:
-            state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-            yield (state, state.to_gradio_chatbot()) + (
-                disable_btn,
-                disable_btn,
-                disable_btn,
-                enable_btn,
-                enable_btn,
-            )
-            return
 
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
@@ -477,15 +494,12 @@ def change_sql_mode(sql_mode):
 
 
 def change_mode(mode):
-    if mode in [default_knowledge_base_dialogue, llm_native_dialogue]:
-        return gr.update(visible=False)
-    else:
+    if mode in [add_knowledge_base_dialogue]:
         return gr.update(visible=True)
+    else:
+        return gr.update(visible=False)
 
 
-def change_tab():
-    autogpt = True
-
 
 def build_single_model_ui():
     notice_markdown = get_lang_text("db_gpt_introduction")
@@ -548,15 +562,14 @@ def build_single_model_ui():
                 sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
                 sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
 
-            tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
-            tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
+            tab_plugin = gr.TabItem(get_lang_text("chat_use_plugin"), elem_id="PLUGIN")
             # tab_plugin.select(change_func)
             with tab_plugin:
                 print("tab_plugin in...")
                 with gr.Row(elem_id="plugin_selector"):
                     # TODO
                     plugin_selector = gr.Dropdown(
-                        label="请选择插件",
+                        label=get_lang_text("select_plugin"),
                         choices=list(plugins_select_info().keys()),
                         value="",
                         interactive=True,
@@ -578,6 +591,7 @@ def build_single_model_ui():
                         llm_native_dialogue,
                         default_knowledge_base_dialogue,
                         add_knowledge_base_dialogue,
+                        url_knowledge_dialogue,
                     ],
                     show_label=False,
                     value=llm_native_dialogue,
@@ -586,6 +600,16 @@ def build_single_model_ui():
                     get_lang_text("configure_knowledge_base"), open=False
                 )
                 mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
+
+                url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True)
+
+                def show_url_input(evt: gr.SelectData):
+                    if evt.value == url_knowledge_dialogue:
+                        return gr.update(visible=True)
+                    else:
+                        return gr.update(visible=False)
+
+                mode.select(fn=show_url_input, inputs=None, outputs=url_input)
 
                 with vs_setting:
                     vs_name = gr.Textbox(
                         label=get_lang_text("new_klg_name"), lines=1, interactive=True
@@ -636,7 +660,7 @@ def build_single_model_ui():
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
-        [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -645,7 +669,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
@@ -653,7 +677,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     vs_add.click(
@@ -760,8 +784,8 @@ if __name__ == "__main__":
 
     # Load the executable commands provided by plugins
     command_categories = [
-        "pilot.commands.audio_text",
-        "pilot.commands.image_gen",
+        "pilot.commands.built_in.audio_text",
+        "pilot.commands.built_in.image_gen",
    ]
     # Exclude disabled commands
     command_categories = [
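The built-in commands now live under the new `pilot.commands.built_in` package (whose `__init__.py` is added above), so their module paths gain a `built_in` segment. A sketch of how such category lists are typically consumed by an Auto-GPT-style command registry; the `CommandRegistry.import_commands` API is an assumption, not shown in this diff:

# Hypothetical registration loop; CommandRegistry/import_commands are assumed
# to follow the Auto-GPT convention of importing each listed module and
# collecting its decorated command functions.
command_registry = CommandRegistry()
for command_category in command_categories:
    command_registry.import_commands(command_category)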
pilot/source_embedding/external/__init__.py (new empty file, vendored)
@@ -11,6 +11,7 @@ from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
 from pilot.source_embedding.csv_embedding import CSVEmbedding
 from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
 from pilot.source_embedding.pdf_embedding import PDFEmbedding
+from pilot.source_embedding.url_embedding import URLEmbedding
 from pilot.vector_store.connector import VectorStoreConnector
 
 CFG = Config()
@@ -61,6 +62,12 @@ class KnowledgeEmbedding:
                 model_name=self.model_name,
                 vector_store_config=self.vector_store_config,
             )
+        elif self.file_type == "url":
+            embedding = URLEmbedding(
+                file_path=self.file_path,
+                model_name=self.model_name,
+                vector_store_config=self.vector_store_config,
+            )
 
         return embedding
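This mirrors the call site added to `http_bot`: passing a URL as `file_path` (with `file_type == "url"`) routes the content through `URLEmbedding` into the vector store. A usage sketch assembled from the pieces in this commit; the example URL and query are illustrative, and how `file_type` is derived from the path is not shown here:

# Sketch of the new URL knowledge path, assembled from this commit's call sites.
client = KnowledgeEmbedding(
    file_path="https://example.com/page",      # illustrative URL
    model_name=LLM_MODEL_CONFIG["text2vec"],
    local_persist=False,
    vector_store_config={
        "vector_store_name": "https://example.com/page",
        "text_field": "content",
        "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
    },
)
docs = client.similar_search("what does the page say?", VECTOR_SEARCH_TOP_K)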