add plugin mode

This commit is contained in:
yhjun1026 2023-05-30 10:45:12 +08:00
parent 20edf6daaa
commit 9511cdb10a
5 changed files with 84 additions and 185 deletions

View File

@ -1,52 +0,0 @@
from typing import Any, Callable, Dict, List, Optional
class PromptGenerator:
"""
generating custom prompt strings based on constraints
Compatible with AutoGpt Plugin;
"""
def __init__(self) -> None:
"""
Initialize the PromptGenerator object with empty lists of constraints,
commands, resources, and performance evaluations.
"""
self.constraints = []
self.commands = []
self.resources = []
self.performance_evaluation = []
self.goals = []
self.command_registry = None
self.name = "Bob"
self.role = "AI"
self.response_format = None
def add_command(
self,
command_label: str,
command_name: str,
args=None,
function: Optional[Callable] = None,
) -> None:
"""
Add a command to the commands list with a label, name, and optional arguments.
GB-GPT and Auto-GPT plugin registration command.
Args:
command_label (str): The label of the command.
command_name (str): The name of the command.
args (dict, optional): A dictionary containing argument names and their
values. Defaults to None.
function (callable, optional): A callable function to be called when
the command is executed. Defaults to None.
"""
if args is None:
args = {}
command_args = {arg_key: arg_value for arg_key, arg_value in args.items()}
command = {
"label": command_label,
"name": command_name,
"args": command_args,
"function": function,
}
self.commands.append(command)

View File

@ -42,6 +42,7 @@ from pilot.scene.chat_db.out_parser import DbChatOutputParser
CFG = Config()
class ChatWithDb(BaseChat):
chat_scene: str = ChatScene.ChatWithDb.value
@ -49,7 +50,7 @@ class ChatWithDb(BaseChat):
def __init__(self, chat_session_id, db_name, user_input):
""" """
super().__init__(ChatScene.ChatWithDb, chat_session_id, user_input)
super().__init__(chat_mode=ChatScene.ChatWithDb, chat_session_id=chat_session_id, current_user_input=user_input)
if not db_name:
raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!")
self.db_name = db_name
@ -60,17 +61,16 @@ class ChatWithDb(BaseChat):
def generate_input_values(self):
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect)
}
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect)
}
return input_values
def do_with_prompt_response(self, prompt_response):
return self.database.run(self.db_connect, prompt_response.sql)
# def call(self) -> str:
# input_values = {
# "input": self.current_user_input,
@ -176,9 +176,6 @@ class ChatWithDb(BaseChat):
return ret
@property
def chat_type(self) -> str:
return ChatScene.ChatExecution.value

View File

@ -11,6 +11,8 @@ from pilot.configs.config import Config
from pilot.commands.command import execute_command
from pilot.prompts.generator import PluginPromptGenerator
from pilot.scene.chat_execution.prompt import chat_plugin_prompt
CFG = Config()
class ChatWithPlugin(BaseChat):
@ -18,15 +20,19 @@ class ChatWithPlugin(BaseChat):
plugins_prompt_generator:PluginPromptGenerator
select_plugin: str = None
def __init__(self, chat_mode, chat_session_id, current_user_input, select_plugin:str=None):
super().__init__(chat_mode, chat_session_id, current_user_input)
def __init__(self, chat_session_id, user_input, plugin_selector:str=None):
super().__init__(chat_mode=ChatScene.ChatExecution, chat_session_id=chat_session_id, current_user_input=user_input)
self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = self.command_registry
self.plugins_prompt_generator.command_registry = CFG.command_registry
# 加载插件中可用命令
self.select_plugin = select_plugin
self.select_plugin = plugin_selector
if self.select_plugin:
for plugin in CFG.plugins:
if plugin.
if plugin._name == plugin_selector :
if not plugin.can_handle_post_prompt():
continue
self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator)
else:
for plugin in CFG.plugins:
if not plugin.can_handle_post_prompt():
@ -39,7 +45,7 @@ class ChatWithPlugin(BaseChat):
def generate_input_values(self):
input_values = {
"input": self.current_user_input,
"constraints": self.__list_to_prompt_str(self.plugins_prompt_generator.constraints),
"constraints": self.__list_to_prompt_str(list(self.plugins_prompt_generator.constraints)),
"commands_infos": self.plugins_prompt_generator.generate_commands_string()
}
return input_values
@ -48,101 +54,12 @@ class ChatWithPlugin(BaseChat):
## plugin command run
return execute_command(str(prompt_response), self.plugins_prompt_generator)
# def call(self):
# input_values = {
# "input": self.current_user_input,
# "constraints": self.__list_to_prompt_str(self.plugins_prompt_generator.constraints),
# "commands_infos": self.__get_comnands_promp_info()
# }
#
# ### Chat sequence advance
# self.current_message.chat_order = len(self.history_message) + 1
# self.current_message.add_user_message(self.current_user_input)
# self.current_message.start_date = datetime.datetime.now()
# # TODO
# self.current_message.tokens = 0
#
# current_prompt = self.prompt_template.format(**input_values)
#
# ### 构建当前对话, 是否安第一次对话prompt构造 是否考虑切换库
# if self.history_message:
# ## TODO 带历史对话记录的场景需要确定切换库后怎么处理
# logger.info(
# f"There are already {len(self.history_message)} rounds of conversations!"
# )
#
# self.current_message.add_system_message(current_prompt)
#
# payload = {
# "model": self.llm_model,
# "prompt": self.generate_llm_text(),
# "temperature": float(self.temperature),
# "max_new_tokens": int(self.max_new_tokens),
# "stop": self.prompt_template.sep,
# }
# logger.info(f"Requert: \n{payload}")
# ai_response_text = ""
# try:
# ### 走非流式的模型服务接口
#
# response = requests.post(
# urljoin(CFG.MODEL_SERVER, "generate"),
# headers=headers,
# json=payload,
# timeout=120,
# )
# ai_response_text = (
# self.prompt_template.output_parser.parse_model_server_out(response)
# )
# self.current_message.add_ai_message(ai_response_text)
# prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
#
#
# ## plugin command run
# result = execute_command(prompt_define_response, self.plugins_prompt_generator)
#
# if hasattr(prompt_define_response, "thoughts"):
# if prompt_define_response.thoughts.get("speak"):
# self.current_message.add_view_message(
# self.prompt_template.output_parser.parse_view_response(
# prompt_define_response.thoughts.get("speak"), result
# )
# )
# elif prompt_define_response.thoughts.get("reasoning"):
# self.current_message.add_view_message(
# self.prompt_template.output_parser.parse_view_response(
# prompt_define_response.thoughts.get("reasoning"), result
# )
# )
# else:
# self.current_message.add_view_message(
# self.prompt_template.output_parser.parse_view_response(
# prompt_define_response.thoughts, result
# )
# )
# else:
# self.current_message.add_view_message(
# self.prompt_template.output_parser.parse_view_response(
# prompt_define_response, result
# )
# )
#
# except Exception as e:
# print(traceback.format_exc())
# logger.error("model response parase faild" + str(e))
# self.current_message.add_view_message(
# f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
# )
# ### 对话记录存储
# self.memory.append(self.current_message)
def chat_show(self):
super().chat_show()
def __list_to_prompt_str(list: List) -> str:
if not list:
def __list_to_prompt_str(self, list: List) -> str:
if list:
separator = '\n'
return separator.join(list)
else:

View File

@ -52,7 +52,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False
chat_plugin_prompt = PromptTemplate(
template_scene=ChatScene.ChatExecution.value,
input_variables=["input", "table_info", "dialect", "top_k", "response"],
input_variables=["input", "constraints", "commands_infos", "response"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,

View File

@ -103,6 +103,11 @@ def gen_sqlgen_conversation(dbname):
return f"数据库{dbname}的Schema信息如下: {message}\n"
def plugins_select_info():
plugins_infos: dict = {}
for plugin in CFG.plugins:
plugins_infos.update({f"{plugin._name}】=>{plugin._description}": plugin._name})
return plugins_infos
get_window_url_params = """
@ -188,26 +193,27 @@ def post_process_code(code):
return code
def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
if mode == conversation_types["default_knownledge"] and not db_selector:
return ChatScene.ChatKnowledge
elif mode == conversation_types["custome"] and not db_selector:
return ChatScene.ChatNewKnowledge
elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
return ChatScene.ChatWithDb
elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene:
if "插件模式" == selected:
return ChatScene.ChatExecution
elif "知识问答" == selected:
if mode == conversation_types["default_knownledge"]:
return ChatScene.ChatKnowledge
elif mode == conversation_types["custome"]:
return ChatScene.ChatNewKnowledge
else:
return ChatScene.ChatNormal
if sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
return ChatScene.ChatWithDb
return ChatScene.ChatNormal
def http_bot(
state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
state, selected, plugin_selector, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
):
logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}")
start_tstamp = time.time()
scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
scene: ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector)
print(f"当前对话模式:{scene.value}")
model_name = CFG.LLM_MODEL
@ -216,6 +222,17 @@ def http_bot(
chat_param = {
"chat_session_id": state.conv_id,
"db_name": db_selector,
"current_user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat.call()
state.messages[-1][-1] = f"{chat.current_ai_response()}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
elif ChatScene.ChatExecution == scene:
logger.info("插件模式对话走新的模式!")
chat_param = {
"chat_session_id": state.conv_id,
"plugin_selector": plugin_selector,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
@ -362,8 +379,8 @@ def http_bot(
block_css = (
code_highlight_css
+ """
code_highlight_css
+ """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@ -396,6 +413,11 @@ def change_tab():
autogpt = True
def change_func(xx):
print("123")
print(str(xx))
def build_single_model_ui():
notice_markdown = """
# DB-GPT
@ -430,11 +452,18 @@ def build_single_model_ui():
label="最大输出Token数",
)
tabs = gr.Tabs()
def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
return evt.value
selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
tabs.select(on_select, None, selected)
with tabs:
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
tab_sql.select(on_select, None, None)
with tab_sql:
print("tab_sql in...")
# TODO A selector to choose database
@ -452,18 +481,26 @@ def build_single_model_ui():
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
# tab_plugin.select(change_func)
with tab_plugin:
print("tab_plugin in...")
with gr.Row(elem_id="plugin_selector"):
# TODO
plugin_selector = gr.Dropdown(
label="请选择插件",
choices=[""" [datadance-ddl-excutor]->use datadance deal the ddl task """, """[file-writer]-file read and write """, """ [image-excutor]-> image build"""],
value="datadance-ddl-excutor",
choices=list(plugins_select_info().keys()),
value="",
interactive=True,
show_label=True,
type="value"
).style(container=False)
def plugin_change(evt: gr.SelectData): # SelectData is a subclass of EventData
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
return plugins_select_info().get(evt.value)
plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
plugin_selector.select(plugin_change, None, plugin_selected)
tab_qa = gr.TabItem("知识问答", elem_id="QA")
with tab_qa:
@ -517,7 +554,7 @@ def build_single_model_ui():
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@ -526,7 +563,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
@ -534,7 +571,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
vs_add.click(
@ -557,10 +594,10 @@ def build_single_model_ui():
def build_webdemo():
with gr.Blocks(
title="数据库智能助手",
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
title="数据库智能助手",
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
) as demo:
url_params = gr.JSON(visible=False)
(