chat with plugin bug fix

This commit is contained in:
yhjun1026 2023-06-01 14:32:55 +08:00
parent 1c75dda0a0
commit bacc31658e
3 changed files with 7 additions and 43 deletions

View File

@ -47,44 +47,10 @@ class ChatWithDbAutoExecute(BaseChat):
"input": self.current_user_input, "input": self.current_user_input,
"top_k": str(self.top_k), "top_k": str(self.top_k),
"dialect": self.database.dialect, "dialect": self.database.dialect,
# "table_info": self.database.table_simple_info(self.db_connect) "table_info": self.database.table_simple_info(self.db_connect)
"table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) # "table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
} }
return input_values return input_values
def do_with_prompt_response(self, prompt_response): def do_with_prompt_response(self, prompt_response):
return self.database.run(self.db_connect, prompt_response.sql) return self.database.run(self.db_connect, prompt_response.sql)
if __name__ == "__main__":
db = CFG.local_db
connect = db.get_session("gpt-user")
results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time
FROM `gpt-user`.users
WHERE user_name='test1';
""")
print(str(db.get_session_db(connect)))
print(str(results))
results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time
FROM `gpt-user`.users
WHERE user_name='test2';
""")
print(str(db.get_session_db(connect)))
print(str(results))
results = db.run(connect, """INSERT INTO `gpt-user`.users
(user_name, phone, email, city, create_time, last_login_time)
VALUES('test4', '23', NULL, '成都', '2023-05-09 09:09:09', NULL);
""")
print(str(db.get_session_db(connect)))
print(str(results))
results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time
FROM `gpt-user`.users
WHERE user_name='test3';
""")
print(str(db.get_session_db(connect)))
print(str(results))

View File

@ -254,7 +254,6 @@ def http_bot(
"db_name": db_selector, "db_name": db_selector,
"user_input": state.last_user_input "user_input": state.last_user_input
} }
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatWithDbQA == scene: elif ChatScene.ChatWithDbQA == scene:
chat_param = { chat_param = {
"temperature": temperature, "temperature": temperature,
@ -263,7 +262,6 @@ def http_bot(
"db_name": db_selector, "db_name": db_selector,
"user_input": state.last_user_input, "user_input": state.last_user_input,
} }
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatExecution == scene: elif ChatScene.ChatExecution == scene:
chat_param = { chat_param = {
"temperature": temperature, "temperature": temperature,
@ -272,7 +270,6 @@ def http_bot(
"plugin_selector": plugin_selector, "plugin_selector": plugin_selector,
"user_input": state.last_user_input, "user_input": state.last_user_input,
} }
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatNormal == scene: elif ChatScene.ChatNormal == scene:
chat_param = { chat_param = {
"temperature": temperature, "temperature": temperature,
@ -280,7 +277,6 @@ def http_bot(
"chat_session_id": state.conv_id, "chat_session_id": state.conv_id,
"user_input": state.last_user_input, "user_input": state.last_user_input,
} }
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatKnowledge == scene: elif ChatScene.ChatKnowledge == scene:
chat_param = { chat_param = {
"temperature": temperature, "temperature": temperature,
@ -288,7 +284,6 @@ def http_bot(
"chat_session_id": state.conv_id, "chat_session_id": state.conv_id,
"user_input": state.last_user_input, "user_input": state.last_user_input,
} }
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatNewKnowledge == scene: elif ChatScene.ChatNewKnowledge == scene:
chat_param = { chat_param = {
"temperature": temperature, "temperature": temperature,
@ -297,7 +292,6 @@ def http_bot(
"user_input": state.last_user_input, "user_input": state.last_user_input,
"knowledge_name": knowledge_name "knowledge_name": knowledge_name
} }
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatUrlKnowledge == scene: elif ChatScene.ChatUrlKnowledge == scene:
chat_param = { chat_param = {
"temperature": temperature, "temperature": temperature,
@ -306,8 +300,11 @@ def http_bot(
"user_input": state.last_user_input, "user_input": state.last_user_input,
"url": url_input "url": url_input
} }
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param) else:
state.messages[-1][-1] = f"ERROR: Can't support scene!{scene}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
if not chat.prompt_template.stream_out: if not chat.prompt_template.stream_out:
logger.info("not stream out, wait model response!") logger.info("not stream out, wait model response!")
state.messages[-1][-1] = chat.nostream_call() state.messages[-1][-1] = chat.nostream_call()

View File

@ -22,6 +22,7 @@ class MysqlSummary(DBSummary):
self.db = CFG.local_db self.db = CFG.local_db
self.db.get_session(name) self.db.get_session(name)
self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format( self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
users=self.db.get_users(), users=self.db.get_users(),
grant=self.db.get_grants(), grant=self.db.get_grants(),