From bacc31658e14e1844ce81eda1e6ed69d40147763 Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Thu, 1 Jun 2023 14:32:55 +0800 Subject: [PATCH] chat with plugin bug fix --- pilot/scene/chat_db/auto_execute/chat.py | 38 ++---------------------- pilot/server/webserver.py | 11 +++---- pilot/summary/mysql_db_summary.py | 1 + 3 files changed, 7 insertions(+), 43 deletions(-) diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index 0ef8bc701..2882fb1cc 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -47,44 +47,10 @@ class ChatWithDbAutoExecute(BaseChat): "input": self.current_user_input, "top_k": str(self.top_k), "dialect": self.database.dialect, - # "table_info": self.database.table_simple_info(self.db_connect) - "table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) + "table_info": self.database.table_simple_info(self.db_connect) + # "table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) } return input_values def do_with_prompt_response(self, prompt_response): return self.database.run(self.db_connect, prompt_response.sql) - - - -if __name__ == "__main__": - db = CFG.local_db - connect = db.get_session("gpt-user") - - results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time - FROM `gpt-user`.users - WHERE user_name='test1'; - """) - - print(str(db.get_session_db(connect))) - print(str(results)) - results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time - FROM `gpt-user`.users - WHERE user_name='test2'; - """) - print(str(db.get_session_db(connect))) - print(str(results)) - - results = db.run(connect, """INSERT INTO `gpt-user`.users - (user_name, phone, email, city, create_time, last_login_time) - VALUES('test4', '23', NULL, '成都', '2023-05-09 09:09:09', NULL); - """) - print(str(db.get_session_db(connect))) - print(str(results)) - - results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time - FROM `gpt-user`.users - WHERE user_name='test3'; - """) - print(str(db.get_session_db(connect))) - print(str(results)) \ No newline at end of file diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 132944fb6..e51ef792c 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -254,7 +254,6 @@ def http_bot( "db_name": db_selector, "user_input": state.last_user_input } - chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) elif ChatScene.ChatWithDbQA == scene: chat_param = { "temperature": temperature, @@ -263,7 +262,6 @@ def http_bot( "db_name": db_selector, "user_input": state.last_user_input, } - chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) elif ChatScene.ChatExecution == scene: chat_param = { "temperature": temperature, @@ -272,7 +270,6 @@ def http_bot( "plugin_selector": plugin_selector, "user_input": state.last_user_input, } - chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) elif ChatScene.ChatNormal == scene: chat_param = { "temperature": temperature, @@ -280,7 +277,6 @@ def http_bot( "chat_session_id": state.conv_id, "user_input": state.last_user_input, } - chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) elif ChatScene.ChatKnowledge == scene: chat_param = { "temperature": temperature, @@ -288,7 +284,6 @@ def http_bot( "chat_session_id": state.conv_id, "user_input": state.last_user_input, } - chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) elif ChatScene.ChatNewKnowledge == scene: chat_param = { "temperature": temperature, @@ -297,7 +292,6 @@ def http_bot( "user_input": state.last_user_input, "knowledge_name": knowledge_name } - chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) elif ChatScene.ChatUrlKnowledge == scene: chat_param = { "temperature": temperature, @@ -306,8 +300,11 @@ def http_bot( "user_input": state.last_user_input, "url": url_input } - chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) + else: + state.messages[-1][-1] = f"ERROR: Can't support scene!{scene}" + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) if not chat.prompt_template.stream_out: logger.info("not stream out, wait model response!") state.messages[-1][-1] = chat.nostream_call() diff --git a/pilot/summary/mysql_db_summary.py b/pilot/summary/mysql_db_summary.py index e14aad9a3..3ed9b9171 100644 --- a/pilot/summary/mysql_db_summary.py +++ b/pilot/summary/mysql_db_summary.py @@ -22,6 +22,7 @@ class MysqlSummary(DBSummary): self.db = CFG.local_db self.db.get_session(name) + self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format( users=self.db.get_users(), grant=self.db.get_grants(),