commit 9511cdb10a (parent 20edf6daaa)

    add plugin mode
@@ -1,52 +0,0 @@
-from typing import Any, Callable, Dict, List, Optional
-
-
-class PromptGenerator:
-    """
-    Generates custom prompt strings based on constraints;
-    compatible with the Auto-GPT plugin interface.
-    """
-
-    def __init__(self) -> None:
-        """
-        Initialize the PromptGenerator object with empty lists of constraints,
-        commands, resources, and performance evaluations.
-        """
-        self.constraints = []
-        self.commands = []
-        self.resources = []
-        self.performance_evaluation = []
-        self.goals = []
-        self.command_registry = None
-        self.name = "Bob"
-        self.role = "AI"
-        self.response_format = None
-
-    def add_command(
-        self,
-        command_label: str,
-        command_name: str,
-        args=None,
-        function: Optional[Callable] = None,
-    ) -> None:
-        """
-        Add a command to the commands list with a label, name, and optional arguments.
-        DB-GPT and Auto-GPT plugin registration command.
-
-        Args:
-            command_label (str): The label of the command.
-            command_name (str): The name of the command.
-            args (dict, optional): A dictionary containing argument names and their
-              values. Defaults to None.
-            function (callable, optional): A callable function to be called when
-              the command is executed. Defaults to None.
-        """
-        if args is None:
-            args = {}
-        command_args = {arg_key: arg_value for arg_key, arg_value in args.items()}
-        command = {
-            "label": command_label,
-            "name": command_name,
-            "args": command_args,
-            "function": function,
-        }
-        self.commands.append(command)
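For orientation, the class deleted above stored each registered command as a plain dict on self.commands; PluginPromptGenerator (imported in the next file) takes over that role. A minimal standalone sketch of the removed behavior — only the names shown in the hunk come from the code, the usage at the bottom is invented:

    from typing import Callable, Optional

    class PromptGenerator:
        """Sketch of the removed class: each command is a plain dict."""

        def __init__(self) -> None:
            self.commands = []

        def add_command(self, command_label: str, command_name: str,
                        args=None, function: Optional[Callable] = None) -> None:
            args = args or {}
            self.commands.append({
                "label": command_label,
                "name": command_name,
                "args": dict(args),  # shallow copy, as the original dict comprehension did
                "function": function,
            })

    gen = PromptGenerator()
    gen.add_command("Write file", "write_file", {"path": "<path>", "text": "<text>"})
    print(gen.commands[0]["name"])  # write_file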
@@ -42,6 +42,7 @@ from pilot.scene.chat_db.out_parser import DbChatOutputParser

 CFG = Config()


 class ChatWithDb(BaseChat):
     chat_scene: str = ChatScene.ChatWithDb.value

@@ -49,7 +50,7 @@ class ChatWithDb(BaseChat):

     def __init__(self, chat_session_id, db_name, user_input):
         """ """
-        super().__init__(ChatScene.ChatWithDb, chat_session_id, user_input)
+        super().__init__(chat_mode=ChatScene.ChatWithDb, chat_session_id=chat_session_id, current_user_input=user_input)
         if not db_name:
             raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!")
         self.db_name = db_name

@@ -70,7 +71,6 @@ class ChatWithDb(BaseChat):
     def do_with_prompt_response(self, prompt_response):
         return self.database.run(self.db_connect, prompt_response.sql)

-
     # def call(self) -> str:
     #     input_values = {
     #         "input": self.current_user_input,

@@ -176,9 +176,6 @@ class ChatWithDb(BaseChat):

         return ret

-
-
-
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatExecution.value
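The only functional change here is switching the super().__init__ call to keyword arguments, which keeps the call site valid if BaseChat's parameter order ever changes. A hypothetical stub (BaseChat's real signature is not shown in this diff) illustrating the point:

    # Hypothetical stub: BaseChat's real __init__ is defined elsewhere in the repo.
    class BaseChat:
        def __init__(self, chat_mode, chat_session_id, current_user_input):
            self.chat_mode = chat_mode
            self.chat_session_id = chat_session_id
            self.current_user_input = current_user_input

    # Keyword arguments keep this call correct even if the positional order
    # of BaseChat.__init__ is later rearranged:
    chat = BaseChat(
        chat_mode="chat_with_db_execute",  # stand-in for ChatScene.ChatWithDb
        chat_session_id="conv-42",
        current_user_input="list all users",
    )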
@@ -11,6 +11,8 @@ from pilot.configs.config import Config
 from pilot.commands.command import execute_command
 from pilot.prompts.generator import PluginPromptGenerator

+from pilot.scene.chat_execution.prompt import chat_plugin_prompt
+
 CFG = Config()

 class ChatWithPlugin(BaseChat):

@@ -18,15 +20,19 @@ class ChatWithPlugin(BaseChat):
     plugins_prompt_generator: PluginPromptGenerator
     select_plugin: str = None

-    def __init__(self, chat_mode, chat_session_id, current_user_input, select_plugin: str = None):
-        super().__init__(chat_mode, chat_session_id, current_user_input)
+    def __init__(self, chat_session_id, user_input, plugin_selector: str = None):
+        super().__init__(chat_mode=ChatScene.ChatExecution, chat_session_id=chat_session_id, current_user_input=user_input)
         self.plugins_prompt_generator = PluginPromptGenerator()
-        self.plugins_prompt_generator.command_registry = self.command_registry
+        self.plugins_prompt_generator.command_registry = CFG.command_registry
         # load the commands made available by plugins
-        self.select_plugin = select_plugin
+        self.select_plugin = plugin_selector
         if self.select_plugin:
             for plugin in CFG.plugins:
-                if plugin.
+                if plugin._name == plugin_selector:
                     if not plugin.can_handle_post_prompt():
                         continue
                     self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)

         else:
             for plugin in CFG.plugins:
                 if not plugin.can_handle_post_prompt():
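The rewritten constructor resolves the command registry from CFG and, when a plugin_selector is given, lets only that plugin extend the prompt generator. The same logic factored as a standalone helper for readability — can_handle_post_prompt() and post_prompt() are the Auto-GPT plugin hooks the diff relies on; the helper itself is illustrative:

    def load_plugin_commands(plugins, generator, selected_name=None):
        """Illustrative helper mirroring the constructor logic above."""
        for plugin in plugins:
            if selected_name is not None and plugin._name != selected_name:
                continue  # a specific plugin was requested; skip all others
            if not plugin.can_handle_post_prompt():
                continue  # plugin contributes no prompt commands
            # each capable plugin may add commands/constraints to the generator
            generator = plugin.post_prompt(generator)
        return generator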
@@ -39,7 +45,7 @@ class ChatWithPlugin(BaseChat):
     def generate_input_values(self):
         input_values = {
             "input": self.current_user_input,
-            "constraints": self.__list_to_prompt_str(self.plugins_prompt_generator.constraints),
+            "constraints": self.__list_to_prompt_str(list(self.plugins_prompt_generator.constraints)),
             "commands_infos": self.plugins_prompt_generator.generate_commands_string()
         }
         return input_values
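generate_input_values() now produces exactly the runtime variables the reworked chat_plugin_prompt template declares (see the prompt.py hunk below). An illustrative value — the commands_infos formatting is an assumption, since generate_commands_string() is not shown in this diff:

    # Illustrative shape only; generate_commands_string()'s exact format is assumed.
    input_values = {
        "input": "resize the logo to 128x128",
        "constraints": "Only use the commands listed below\nReply in strict JSON",
        "commands_infos": '1. "image-excutor": image build, args: {"size": "<size>"}',
    }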
@@ -48,101 +54,12 @@ class ChatWithPlugin(BaseChat):
         ## plugin command run
         return execute_command(str(prompt_response), self.plugins_prompt_generator)

-    # def call(self):
-    #     input_values = {
-    #         "input": self.current_user_input,
-    #         "constraints": self.__list_to_prompt_str(self.plugins_prompt_generator.constraints),
-    #         "commands_infos": self.__get_comnands_promp_info()
-    #     }
-    #
-    #     ### Advance the chat sequence
-    #     self.current_message.chat_order = len(self.history_message) + 1
-    #     self.current_message.add_user_message(self.current_user_input)
-    #     self.current_message.start_date = datetime.datetime.now()
-    #     # TODO
-    #     self.current_message.tokens = 0
-    #
-    #     current_prompt = self.prompt_template.format(**input_values)
-    #
-    #     ### Build the current turn: construct the prompt as for a first turn? Account for database switching?
-    #     if self.history_message:
-    #         ## TODO decide how to handle a database switch when conversation history is present
-    #         logger.info(
-    #             f"There are already {len(self.history_message)} rounds of conversations!"
-    #         )
-    #
-    #     self.current_message.add_system_message(current_prompt)
-    #
-    #     payload = {
-    #         "model": self.llm_model,
-    #         "prompt": self.generate_llm_text(),
-    #         "temperature": float(self.temperature),
-    #         "max_new_tokens": int(self.max_new_tokens),
-    #         "stop": self.prompt_template.sep,
-    #     }
-    #     logger.info(f"Requert: \n{payload}")
-    #     ai_response_text = ""
-    #     try:
-    #         ### Call the non-streaming model-service endpoint
-    #
-    #         response = requests.post(
-    #             urljoin(CFG.MODEL_SERVER, "generate"),
-    #             headers=headers,
-    #             json=payload,
-    #             timeout=120,
-    #         )
-    #         ai_response_text = (
-    #             self.prompt_template.output_parser.parse_model_server_out(response)
-    #         )
-    #         self.current_message.add_ai_message(ai_response_text)
-    #         prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
-    #
-    #         ## plugin command run
-    #         result = execute_command(prompt_define_response, self.plugins_prompt_generator)
-    #
-    #         if hasattr(prompt_define_response, "thoughts"):
-    #             if prompt_define_response.thoughts.get("speak"):
-    #                 self.current_message.add_view_message(
-    #                     self.prompt_template.output_parser.parse_view_response(
-    #                         prompt_define_response.thoughts.get("speak"), result
-    #                     )
-    #                 )
-    #             elif prompt_define_response.thoughts.get("reasoning"):
-    #                 self.current_message.add_view_message(
-    #                     self.prompt_template.output_parser.parse_view_response(
-    #                         prompt_define_response.thoughts.get("reasoning"), result
-    #                     )
-    #                 )
-    #             else:
-    #                 self.current_message.add_view_message(
-    #                     self.prompt_template.output_parser.parse_view_response(
-    #                         prompt_define_response.thoughts, result
-    #                     )
-    #                 )
-    #         else:
-    #             self.current_message.add_view_message(
-    #                 self.prompt_template.output_parser.parse_view_response(
-    #                     prompt_define_response, result
-    #                 )
-    #             )
-    #
-    #     except Exception as e:
-    #         print(traceback.format_exc())
-    #         logger.error("model response parase faild!" + str(e))
-    #         self.current_message.add_view_message(
-    #             f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
-    #         )
-    #     ### Persist the conversation record
-    #     self.memory.append(self.current_message)

     def chat_show(self):
         super().chat_show()

-    def __list_to_prompt_str(list: List) -> str:
-        if not list:
+    def __list_to_prompt_str(self, list: List) -> str:
+        if list:
             separator = '\n'
             return separator.join(list)
         else:
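The hunk display cuts off after the trailing else:, so the fixed helper's fallback is not visible; presumably it returns an empty string, matching the old `if not list` guard. A sketch of the repaired method as a plain function — the empty-string return is an assumption, and the parameter is renamed to avoid shadowing the built-in list:

    from typing import List

    def list_to_prompt_str(items: List[str]) -> str:
        if items:
            separator = '\n'
            return separator.join(items)
        else:
            return ""  # assumed fallback; the hunk is truncated here

    print(list_to_prompt_str(["Only use listed commands", "Reply in JSON"]))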
@@ -52,7 +52,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False

 chat_plugin_prompt = PromptTemplate(
     template_scene=ChatScene.ChatExecution.value,
-    input_variables=["input", "table_info", "dialect", "top_k", "response"],
+    input_variables=["input", "constraints", "commands_infos", "response"],
     response_format=json.dumps(RESPONSE_FORMAT, indent=4),
     template_define=PROMPT_SCENE_DEFINE,
     template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
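The template's input_variables now line up with ChatWithPlugin.generate_input_values(), with response filled from response_format at render time. A toy stand-in for the final formatting step — the template text here is invented; only the variable names come from the diff:

    # Invented template text; only the four variable names come from the diff.
    template = (
        "Goal: {input}\n"
        "Constraints:\n{constraints}\n"
        "Commands:\n{commands_infos}\n"
        "You must respond in this format:\n{response}"
    )
    print(template.format(
        input="resize the logo to 128x128",
        constraints="Only use the commands listed below",
        commands_infos='1. "image-excutor": image build',
        response='{"command": {"name": "...", "args": {}}}',
    ))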
@@ -103,6 +103,11 @@ def gen_sqlgen_conversation(dbname):
     return f"数据库{dbname}的Schema信息如下: {message}\n"


+def plugins_select_info():
+    plugins_infos: dict = {}
+    for plugin in CFG.plugins:
+        plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
+    return plugins_infos
+
+
 get_window_url_params = """
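plugins_select_info() builds a display-label → plugin-name map that feeds the new plugin Dropdown below. With two hypothetical plugins (named after entries in the old hardcoded choices list) it would return something like:

    plugins_infos = {
        "【file-writer】=>file read and write": "file-writer",
        "【image-excutor】=>image build": "image-excutor",
    }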
@@ -188,26 +193,27 @@ def post_process_code(code):
     return code


-def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
-    if mode == conversation_types["default_knownledge"] and not db_selector:
-        return ChatScene.ChatKnowledge
-    elif mode == conversation_types["custome"] and not db_selector:
-        return ChatScene.ChatNewKnowledge
-    elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
-        return ChatScene.ChatWithDb
-    elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
-        return ChatScene.ChatExecution
-    else:
-        return ChatScene.ChatNormal
+def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene:
+    if "插件模式" == selected:
+        return ChatScene.ChatExecution
+    elif "知识问答" == selected:
+        if mode == conversation_types["default_knownledge"]:
+            return ChatScene.ChatKnowledge
+        elif mode == conversation_types["custome"]:
+            return ChatScene.ChatNewKnowledge
+    else:
+        if sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
+            return ChatScene.ChatWithDb
+        else:
+            return ChatScene.ChatNormal
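Dispatch is now driven primarily by the selected tab rather than by mode/db_selector combinations, and the old auto_execute_plugin branch disappears. Note that, as shown, the 知识问答 branch returns None when mode matches neither conversation type, whereas the old version fell back to ChatNormal. Illustrative calls (tab labels per the UI code below; the mode/sql_mode values assume those dicts map keys to themselves, which this diff does not show):

    get_chat_mode("插件模式", mode=None, sql_mode=None, db_selector=None)
    # -> ChatScene.ChatExecution, regardless of the remaining arguments

    get_chat_mode("知识问答", mode="default_knownledge", sql_mode=None, db_selector=None)
    # -> ChatScene.ChatKnowledge

    get_chat_mode("SQL生成与诊断", mode=None,
                  sql_mode="auto_execute_ai_response", db_selector="mydb")
    # -> ChatScene.ChatWithDb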
 def http_bot(
-    state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
+    state, selected, plugin_selector, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
 ):
-    logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
+    logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}")
     start_tstamp = time.time()
-    scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
+    scene: ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector)
     print(f"当前对话模式:{scene.value}")
     model_name = CFG.LLM_MODEL
@@ -216,6 +222,17 @@ def http_bot(
         chat_param = {
             "chat_session_id": state.conv_id,
             "db_name": db_selector,
             "current_user_input": state.last_user_input,
         }
         chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
         chat.call()
         state.messages[-1][-1] = f"{chat.current_ai_response()}"
         yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+    elif ChatScene.ChatExecution == scene:
+        logger.info("插件模式对话走新的模式!")
+        chat_param = {
+            "chat_session_id": state.conv_id,
+            "plugin_selector": plugin_selector,
+            "user_input": state.last_user_input,
+        }
+        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
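The plugin scene reuses the same factory dispatch as the DB scene, which is why ChatWithPlugin.__init__ was renamed to accept exactly the keys placed in chat_param (chat_session_id, user_input, plugin_selector). A minimal sketch of that pattern — the real CHAT_FACTORY lives elsewhere in the repo; this registry is a stand-in:

    class ChatFactory:
        """Stand-in for the repo's CHAT_FACTORY; real registration happens elsewhere."""

        def __init__(self):
            self._impls = {}

        def register(self, scene_value, chat_cls):
            self._impls[scene_value] = chat_cls

        def get_implementation(self, scene_value, **kwargs):
            # kwargs must match the target chat class's __init__ keywords exactly
            return self._impls[scene_value](**kwargs)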
@@ -396,6 +413,11 @@ def change_tab():
     autogpt = True


+def change_func(xx):
+    print("123")
+    print(str(xx))
+
+
 def build_single_model_ui():
     notice_markdown = """
 # DB-GPT
@@ -430,11 +452,18 @@ def build_single_model_ui():
             label="最大输出Token数",
         )

     tabs = gr.Tabs()

+    def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
+        print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+        return evt.value
+
+    selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
+    tabs.select(on_select, None, selected)

     with tabs:
         tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
+        tab_sql.select(on_select, None, None)
         with tab_sql:
             print("tab_sql in...")
             # TODO A selector to choose database
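The hidden selected Textbox is a common Gradio trick: Tabs has no queryable value of its own, so the select event writes the active tab's label into an invisible component that later event handlers can read as an input. A self-contained sketch, assuming a Gradio 3.x release new enough to support SelectData (3.29+):

    import gradio as gr  # assumes a 3.x release with SelectData (>= 3.29)

    with gr.Blocks() as demo:
        tabs = gr.Tabs()

        def on_select(evt: gr.SelectData):
            return evt.value  # the label of the tab that was clicked

        selected = gr.Textbox(visible=False)
        tabs.select(on_select, None, selected)

        with tabs:
            with gr.TabItem("SQL生成与诊断"):
                gr.Markdown("SQL tab")
            with gr.TabItem("插件模式"):
                gr.Markdown("plugin tab")

    demo.launch()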
@@ -452,18 +481,26 @@ def build_single_model_ui():
         sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)

         tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
+        # tab_plugin.select(change_func)
         with tab_plugin:
             print("tab_plugin in...")
             with gr.Row(elem_id="plugin_selector"):
                 # TODO
                 plugin_selector = gr.Dropdown(
                     label="请选择插件",
-                    choices=[""" [datadance-ddl-excutor]->use datadance deal the ddl task """, """[file-writer]-file read and write """, """ [image-excutor]-> image build"""],
-                    value="datadance-ddl-excutor",
+                    choices=list(plugins_select_info().keys()),
+                    value="",
                     interactive=True,
                     show_label=True,
                     type="value"
                 ).style(container=False)

+                def plugin_change(evt: gr.SelectData):  # SelectData is a subclass of EventData
+                    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+                    return plugins_select_info().get(evt.value)
+
+                plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
+                plugin_selector.select(plugin_change, None, plugin_selected)

         tab_qa = gr.TabItem("知识问答", elem_id="QA")
         with tab_qa:
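Because the Dropdown's value is the long display label, plugin_change() maps it back to the internal plugin name before it lands in the hidden plugin_selected Textbox that http_bot reads. The mapping step in isolation, with stand-in data:

    label_to_name = {  # stand-in for plugins_select_info()
        "【file-writer】=>file read and write": "file-writer",
        "【image-excutor】=>image build": "image-excutor",
    }

    def plugin_name_from_label(selected_label: str) -> str:
        return label_to_name.get(selected_label, "")

    assert plugin_name_from_label("【file-writer】=>file read and write") == "file-writer"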
@@ -517,7 +554,7 @@ def build_single_model_ui():
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
-        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)

@@ -526,7 +563,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )

@@ -534,7 +571,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     vs_add.click(