chat with plugin bug fix

yhjun1026 2023-05-31 18:09:20 +08:00
parent 3a46dfd3c2
commit ced9b581fc
9 changed files with 27 additions and 48 deletions

View File

@@ -179,31 +179,22 @@ class BaseChat(ABC):
             result = self.do_with_prompt_response(prompt_define_response)
             if hasattr(prompt_define_response, "thoughts"):
-                if hasattr(prompt_define_response.thoughts, "speak"):
-                    self.current_message.add_view_message(
-                        self.prompt_template.output_parser.parse_view_response(
-                            prompt_define_response.thoughts.get("speak"), result
-                        )
-                    )
-                elif hasattr(prompt_define_response.thoughts, "reasoning"):
-                    self.current_message.add_view_message(
-                        self.prompt_template.output_parser.parse_view_response(
-                            prompt_define_response.thoughts.get("reasoning"), result
-                        )
-                    )
+                if isinstance(prompt_define_response.thoughts, dict):
+                    if "speak" in prompt_define_response.thoughts:
+                        speak_to_user = prompt_define_response.thoughts.get("speak")
+                    else:
+                        speak_to_user = str(prompt_define_response.thoughts)
                 else:
-                    self.current_message.add_view_message(
-                        self.prompt_template.output_parser.parse_view_response(
-                            prompt_define_response.thoughts, result
-                        )
-                    )
+                    if hasattr(prompt_define_response.thoughts, "speak"):
+                        speak_to_user = prompt_define_response.thoughts.get("speak")
+                    elif hasattr(prompt_define_response.thoughts, "reasoning"):
+                        speak_to_user = prompt_define_response.thoughts.get("reasoning")
+                    else:
+                        speak_to_user = prompt_define_response.thoughts
             else:
-                self.current_message.add_view_message(
-                    self.prompt_template.output_parser.parse_view_response(
-                        prompt_define_response, result
-                    )
-                )
+                speak_to_user = prompt_define_response
+            view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result)
+            self.current_message.add_view_message(view_message)
         except Exception as e:
             print(traceback.format_exc())
             logger.error("model response parase faild" + str(e))
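
For context, the rewritten block now resolves a single speak_to_user value and builds one view message, instead of repeating the add_view_message call in every branch. A minimal sketch of the dict-shaped thoughts path, using a made-up payload rather than a real model response:

    # Sketch only: how a dict-shaped "thoughts" payload resolves to speak_to_user.
    thoughts = {"speak": "I will call the search plugin.", "reasoning": "..."}  # made-up payload

    if isinstance(thoughts, dict):
        speak_to_user = thoughts.get("speak") if "speak" in thoughts else str(thoughts)

    print(speak_to_user)  # -> I will call the search plugin.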

View File

@@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
-
     def get_format_instructions(self) -> str:
         pass
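
The delegating parse_view_response override is dropped here (and from the identical parsers in the other chat scenes below), presumably so these parsers simply inherit the base implementation now that BaseChat calls parse_view_response(speak_to_user, result) with two arguments. A hedged illustration with hypothetical classes, not repo code:

    # Hypothetical classes: a stale one-argument override breaks two-argument callers.
    class Base:
        def parse_view_response(self, speak, data):
            return f"{speak}: {data}"

    class StaleOverride(Base):
        def parse_view_response(self, ai_text):
            return super().parse_view_response(ai_text)

    print(Base().parse_view_response("hello", {"rows": 1}))      # fine
    # StaleOverride().parse_view_response("hello", {"rows": 1})  # TypeError: takes 2 positional arguments but 3 were given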

View File

@@ -55,7 +55,7 @@ class ChatWithPlugin(BaseChat):
     def do_with_prompt_response(self, prompt_response):
         ## plugin command run
-        return execute_command(str(prompt_response), self.plugins_prompt_generator)
+        return execute_command(str(prompt_response.command.get('name')), prompt_response.command.get('args',{}), self.plugins_prompt_generator)

     def chat_show(self):
         super().chat_show()
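
The plugin response is no longer stringified as a whole; its command name and args dict are unpacked and forwarded separately. A rough sketch with stand-in objects (execute_command's real signature lives in the plugin layer and is not shown in this commit):

    # Stand-ins only, to show the unpacking shape; not the repo's real classes.
    class FakePluginAction:
        command = {"name": "google_search", "args": {"query": "db-gpt"}}

    def fake_execute_command(name, args, prompt_generator=None):
        return f"run {name} with {args}"

    resp = FakePluginAction()
    print(fake_execute_command(str(resp.command.get("name")), resp.command.get("args", {})))
    # -> run google_search with {'query': 'db-gpt'}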

View File

@@ -23,8 +23,11 @@ class PluginChatOutputParser(BaseOutputParser):
         command, thoughts = response["command"], response["thoughts"]
         return PluginAction(command, thoughts)

-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
+    def parse_view_response(self, speak, data) -> str:
+        ### tool out data to table view
+        print(f"parse_view_response:{speak},{str(data)}" )
+        view_text = f"##### {speak}" + "\n" + str(data)
+        return view_text

     def get_format_instructions(self) -> str:
         pass
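
The new parse_view_response renders the model's speak text as a small markdown heading followed by the raw tool output. A quick illustration of the string it produces, with invented sample inputs:

    # Invented sample inputs; mirrors the view_text construction above.
    speak = "Here is the tool result"
    data = [("Alice", 30), ("Bob", 25)]
    view_text = f"##### {speak}" + "\n" + str(data)
    print(view_text)
    # ##### Here is the tool result
    # [('Alice', 30), ('Bob', 25)]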

View File

@@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
-
     def get_format_instructions(self) -> str:
         pass

View File

@@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
-
     def get_format_instructions(self) -> str:
         pass

View File

@@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
-
     def get_format_instructions(self) -> str:
         pass

View File

@@ -246,7 +246,7 @@ def http_bot(
     state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name
 ):
-    logger.info(f"User message send!{state.conv_id},{selected}")
+    logger.info(f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
     if chat_mode_title['knowledge_qa'] == selected:
         scene: ChatScene = get_chat_mode(selected, mode)
     elif chat_mode_title['chat_use_plugin'] == selected:
@@ -417,7 +417,6 @@ def build_single_model_ui():
                 value=dbs[0] if len(models) > 0 else "",
                 interactive=True,
                 show_label=True,
-                name="db_selector"
             ).style(container=False)

             sql_mode = gr.Radio(
@@ -426,8 +425,7 @@
                     get_lang_text("sql_generate_mode_none"),
                 ],
                 show_label=False,
-                value=get_lang_text("sql_generate_mode_none"),
-                name="sql_mode"
+                value=get_lang_text("sql_generate_mode_none")
             )
             sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
             sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
@@ -444,12 +442,12 @@
                 value="",
                 interactive=True,
                 show_label=True,
-                type="value",
-                name="plugin_selector"
+                type="value"
             ).style(container=False)

             def plugin_change(evt: gr.SelectData):  # SelectData is a subclass of EventData
                 print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+                print(f"user plugin:{plugins_select_info().get(evt.value)}")
                 return plugins_select_info().get(evt.value)

             plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
@@ -466,14 +464,13 @@
                 ],
                 show_label=False,
                 value=llm_native_dialogue,
-                name="mode"
             )
             vs_setting = gr.Accordion(
                 get_lang_text("configure_knowledge_base"), open=False, visible=False
             )
             mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
-            url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False, name="url_input")
+            url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False)
             def show_url_input(evt:gr.SelectData):
                 if evt.value == url_knowledge_dialogue:
                     return gr.update(visible=True)
@@ -484,7 +481,7 @@
             with vs_setting:
                 vs_name = gr.Textbox(
-                    label=get_lang_text("new_klg_name"), lines=1, interactive=True, name = "vs_name"
+                    label=get_lang_text("new_klg_name"), lines=1, interactive=True
                 )
                 vs_add = gr.Button(get_lang_text("add_as_new_klg"))
                 with gr.Column() as doc2vec:
@@ -530,7 +527,7 @@
     gr.Markdown(learn_more_markdown)

-    params = [plugin_selector, mode, sql_mode, db_selector, url_input, vs_name]
+    params = [plugin_selected, mode, sql_mode, db_selector, url_input, vs_name]
     btn_list = [regenerate_btn, clear_btn]
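
Beyond dropping the unsupported name= keyword from the Gradio components, the last hunk switches params from plugin_selector (the Dropdown) to plugin_selected (the hidden Textbox that plugin_change fills). A heavily hedged sketch of that dropdown-to-hidden-textbox pattern; the component names and wiring below are placeholders, not the actual webserver.py layout:

    import gradio as gr

    plugins_info = {"Weather": "weather-plugin-id"}  # placeholder registry

    def plugin_change(evt: gr.SelectData):
        # write the resolved plugin into the hidden textbox
        return plugins_info.get(evt.value)

    with gr.Blocks() as demo:
        plugin_selector = gr.Dropdown(label="Plugins", choices=list(plugins_info), type="value")
        plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
        # downstream handlers read plugin_selected, not the Dropdown itself
        plugin_selector.select(plugin_change, None, plugin_selected)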

Binary file not shown.