diff --git a/pilot/prompts/prompt_generator.py b/pilot/prompts/prompt_generator.py deleted file mode 100644 index 1ec62d5c9..000000000 --- a/pilot/prompts/prompt_generator.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Any, Callable, Dict, List, Optional - - -class PromptGenerator: - """ - generating custom prompt strings based on constraints; - Compatible with AutoGpt Plugin; - """ - - def __init__(self) -> None: - """ - Initialize the PromptGenerator object with empty lists of constraints, - commands, resources, and performance evaluations. - """ - self.constraints = [] - self.commands = [] - self.resources = [] - self.performance_evaluation = [] - self.goals = [] - self.command_registry = None - self.name = "Bob" - self.role = "AI" - self.response_format = None - - def add_command( - self, - command_label: str, - command_name: str, - args=None, - function: Optional[Callable] = None, - ) -> None: - """ - Add a command to the commands list with a label, name, and optional arguments. - GB-GPT and Auto-GPT plugin registration command. - Args: - command_label (str): The label of the command. - command_name (str): The name of the command. - args (dict, optional): A dictionary containing argument names and their - values. Defaults to None. - function (callable, optional): A callable function to be called when - the command is executed. Defaults to None. - """ - if args is None: - args = {} - command_args = {arg_key: arg_value for arg_key, arg_value in args.items()} - command = { - "label": command_label, - "name": command_name, - "args": command_args, - "function": function, - } - self.commands.append(command) diff --git a/pilot/scene/chat_db/chat.py b/pilot/scene/chat_db/chat.py index 37382b7d2..745e9804d 100644 --- a/pilot/scene/chat_db/chat.py +++ b/pilot/scene/chat_db/chat.py @@ -42,6 +42,7 @@ from pilot.scene.chat_db.out_parser import DbChatOutputParser CFG = Config() + class ChatWithDb(BaseChat): chat_scene: str = ChatScene.ChatWithDb.value @@ -49,7 +50,7 @@ class ChatWithDb(BaseChat): def __init__(self, chat_session_id, db_name, user_input): """ """ - super().__init__(ChatScene.ChatWithDb, chat_session_id, user_input) + super().__init__(chat_mode=ChatScene.ChatWithDb, chat_session_id=chat_session_id, current_user_input=user_input) if not db_name: raise ValueError(f"{ChatScene.ChatWithDb.value} mode should chose db!") self.db_name = db_name @@ -60,17 +61,16 @@ class ChatWithDb(BaseChat): def generate_input_values(self): input_values = { - "input": self.current_user_input, - "top_k": str(self.top_k), - "dialect": self.database.dialect, - "table_info": self.database.table_simple_info(self.db_connect) - } + "input": self.current_user_input, + "top_k": str(self.top_k), + "dialect": self.database.dialect, + "table_info": self.database.table_simple_info(self.db_connect) + } return input_values def do_with_prompt_response(self, prompt_response): return self.database.run(self.db_connect, prompt_response.sql) - # def call(self) -> str: # input_values = { # "input": self.current_user_input, @@ -176,9 +176,6 @@ class ChatWithDb(BaseChat): return ret - - - @property def chat_type(self) -> str: return ChatScene.ChatExecution.value diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index a5abadad0..210b2ad77 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -11,6 +11,8 @@ from pilot.configs.config import Config from pilot.commands.command import execute_command from pilot.prompts.generator import PluginPromptGenerator +from pilot.scene.chat_execution.prompt import chat_plugin_prompt + CFG = Config() class ChatWithPlugin(BaseChat): @@ -18,15 +20,19 @@ class ChatWithPlugin(BaseChat): plugins_prompt_generator:PluginPromptGenerator select_plugin: str = None - def __init__(self, chat_mode, chat_session_id, current_user_input, select_plugin:str=None): - super().__init__(chat_mode, chat_session_id, current_user_input) + def __init__(self, chat_session_id, user_input, plugin_selector:str=None): + super().__init__(chat_mode=ChatScene.ChatExecution, chat_session_id=chat_session_id, current_user_input=user_input) self.plugins_prompt_generator = PluginPromptGenerator() - self.plugins_prompt_generator.command_registry = self.command_registry + self.plugins_prompt_generator.command_registry = CFG.command_registry # 加载插件中可用命令 - self.select_plugin = select_plugin + self.select_plugin = plugin_selector if self.select_plugin: for plugin in CFG.plugins: - if plugin. + if plugin._name == plugin_selector : + if not plugin.can_handle_post_prompt(): + continue + self.plugins_prompt_generator = plugin.post_prompt(self.plugins_prompt_generator) + else: for plugin in CFG.plugins: if not plugin.can_handle_post_prompt(): @@ -39,7 +45,7 @@ class ChatWithPlugin(BaseChat): def generate_input_values(self): input_values = { "input": self.current_user_input, - "constraints": self.__list_to_prompt_str(self.plugins_prompt_generator.constraints), + "constraints": self.__list_to_prompt_str(list(self.plugins_prompt_generator.constraints)), "commands_infos": self.plugins_prompt_generator.generate_commands_string() } return input_values @@ -48,101 +54,12 @@ class ChatWithPlugin(BaseChat): ## plugin command run return execute_command(str(prompt_response), self.plugins_prompt_generator) - - # def call(self): - # input_values = { - # "input": self.current_user_input, - # "constraints": self.__list_to_prompt_str(self.plugins_prompt_generator.constraints), - # "commands_infos": self.__get_comnands_promp_info() - # } - # - # ### Chat sequence advance - # self.current_message.chat_order = len(self.history_message) + 1 - # self.current_message.add_user_message(self.current_user_input) - # self.current_message.start_date = datetime.datetime.now() - # # TODO - # self.current_message.tokens = 0 - # - # current_prompt = self.prompt_template.format(**input_values) - # - # ### 构建当前对话, 是否安第一次对话prompt构造? 是否考虑切换库 - # if self.history_message: - # ## TODO 带历史对话记录的场景需要确定切换库后怎么处理 - # logger.info( - # f"There are already {len(self.history_message)} rounds of conversations!" - # ) - # - # self.current_message.add_system_message(current_prompt) - # - # payload = { - # "model": self.llm_model, - # "prompt": self.generate_llm_text(), - # "temperature": float(self.temperature), - # "max_new_tokens": int(self.max_new_tokens), - # "stop": self.prompt_template.sep, - # } - # logger.info(f"Requert: \n{payload}") - # ai_response_text = "" - # try: - # ### 走非流式的模型服务接口 - # - # response = requests.post( - # urljoin(CFG.MODEL_SERVER, "generate"), - # headers=headers, - # json=payload, - # timeout=120, - # ) - # ai_response_text = ( - # self.prompt_template.output_parser.parse_model_server_out(response) - # ) - # self.current_message.add_ai_message(ai_response_text) - # prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) - # - # - # ## plugin command run - # result = execute_command(prompt_define_response, self.plugins_prompt_generator) - # - # if hasattr(prompt_define_response, "thoughts"): - # if prompt_define_response.thoughts.get("speak"): - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response.thoughts.get("speak"), result - # ) - # ) - # elif prompt_define_response.thoughts.get("reasoning"): - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response.thoughts.get("reasoning"), result - # ) - # ) - # else: - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response.thoughts, result - # ) - # ) - # else: - # self.current_message.add_view_message( - # self.prompt_template.output_parser.parse_view_response( - # prompt_define_response, result - # ) - # ) - # - # except Exception as e: - # print(traceback.format_exc()) - # logger.error("model response parase faild!" + str(e)) - # self.current_message.add_view_message( - # f"""ERROR!{str(e)}\n {ai_response_text} """ - # ) - # ### 对话记录存储 - # self.memory.append(self.current_message) - def chat_show(self): super().chat_show() - def __list_to_prompt_str(list: List) -> str: - if not list: + def __list_to_prompt_str(self, list: List) -> str: + if list: separator = '\n' return separator.join(list) else: diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py index e3469d7c2..44f564afe 100644 --- a/pilot/scene/chat_execution/prompt.py +++ b/pilot/scene/chat_execution/prompt.py @@ -52,7 +52,7 @@ PROMPT_NEED_NEED_STREAM_OUT = False chat_plugin_prompt = PromptTemplate( template_scene=ChatScene.ChatExecution.value, - input_variables=["input", "table_info", "dialect", "top_k", "response"], + input_variables=["input", "constraints", "commands_infos", "response"], response_format=json.dumps(RESPONSE_FORMAT, indent=4), template_define=PROMPT_SCENE_DEFINE, template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE, diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 8fefdbfff..13bf38e5b 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -103,6 +103,11 @@ def gen_sqlgen_conversation(dbname): return f"数据库{dbname}的Schema信息如下: {message}\n" +def plugins_select_info(): + plugins_infos: dict = {} + for plugin in CFG.plugins: + plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name}) + return plugins_infos get_window_url_params = """ @@ -188,26 +193,27 @@ def post_process_code(code): return code -def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene: - if mode == conversation_types["default_knownledge"] and not db_selector: - return ChatScene.ChatKnowledge - elif mode == conversation_types["custome"] and not db_selector: - return ChatScene.ChatNewKnowledge - elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector: - return ChatScene.ChatWithDb - - elif mode == conversation_types["auto_execute_plugin"] and not db_selector: +def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene: + if "插件模式" == selected: return ChatScene.ChatExecution + elif "知识问答" == selected: + if mode == conversation_types["default_knownledge"]: + return ChatScene.ChatKnowledge + elif mode == conversation_types["custome"]: + return ChatScene.ChatNewKnowledge else: - return ChatScene.ChatNormal + if sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector: + return ChatScene.ChatWithDb + + return ChatScene.ChatNormal def http_bot( - state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request + state, selected, plugin_selector, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request ): - logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}") + logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}") start_tstamp = time.time() - scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector) + scene: ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector) print(f"当前对话模式:{scene.value}") model_name = CFG.LLM_MODEL @@ -216,6 +222,17 @@ def http_bot( chat_param = { "chat_session_id": state.conv_id, "db_name": db_selector, + "current_user_input": state.last_user_input, + } + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) + chat.call() + state.messages[-1][-1] = f"{chat.current_ai_response()}" + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + elif ChatScene.ChatExecution == scene: + logger.info("插件模式对话走新的模式!") + chat_param = { + "chat_session_id": state.conv_id, + "plugin_selector": plugin_selector, "user_input": state.last_user_input, } chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) @@ -362,8 +379,8 @@ def http_bot( block_css = ( - code_highlight_css - + """ + code_highlight_css + + """ pre { white-space: pre-wrap; /* Since CSS 2.1 */ white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ @@ -396,6 +413,11 @@ def change_tab(): autogpt = True +def change_func(xx): + print("123") + print(str(xx)) + + def build_single_model_ui(): notice_markdown = """ # DB-GPT @@ -430,11 +452,18 @@ def build_single_model_ui(): label="最大输出Token数", ) - tabs = gr.Tabs() + def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData + print(f"You selected {evt.value} at {evt.index} from {evt.target}") + return evt.value + + selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected") + tabs.select(on_select, None, selected) + with tabs: tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL") + tab_sql.select(on_select, None, None) with tab_sql: print("tab_sql in...") # TODO A selector to choose database @@ -452,18 +481,26 @@ def build_single_model_ui(): sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting) tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN") + # tab_plugin.select(change_func) with tab_plugin: print("tab_plugin in...") with gr.Row(elem_id="plugin_selector"): # TODO plugin_selector = gr.Dropdown( label="请选择插件", - choices=[""" [datadance-ddl-excutor]->use datadance deal the ddl task """, """[file-writer]-file read and write """, """ [image-excutor]-> image build"""], - value="datadance-ddl-excutor", + choices=list(plugins_select_info().keys()), + value="", interactive=True, show_label=True, + type="value" ).style(container=False) + def plugin_change(evt: gr.SelectData): # SelectData is a subclass of EventData + print(f"You selected {evt.value} at {evt.index} from {evt.target}") + return plugins_select_info().get(evt.value) + + plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected") + plugin_selector.select(plugin_change, None, plugin_selected) tab_qa = gr.TabItem("知识问答", elem_id="QA") with tab_qa: @@ -517,7 +554,7 @@ def build_single_model_ui(): btn_list = [regenerate_btn, clear_btn] regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( http_bot, - [state, mode, sql_mode, db_selector, temperature, max_output_tokens], + [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens], [state, chatbot] + btn_list, ) clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) @@ -526,7 +563,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, mode, sql_mode, db_selector, temperature, max_output_tokens], + [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens], [state, chatbot] + btn_list, ) @@ -534,7 +571,7 @@ def build_single_model_ui(): add_text, [state, textbox], [state, chatbot, textbox] + btn_list ).then( http_bot, - [state, mode, sql_mode, db_selector, temperature, max_output_tokens], + [state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens], [state, chatbot] + btn_list, ) vs_add.click( @@ -557,10 +594,10 @@ def build_single_model_ui(): def build_webdemo(): with gr.Blocks( - title="数据库智能助手", - # theme=gr.themes.Base(), - theme=gr.themes.Default(), - css=block_css, + title="数据库智能助手", + # theme=gr.themes.Base(), + theme=gr.themes.Default(), + css=block_css, ) as demo: url_params = gr.JSON(visible=False) (