chat with plugin bug fix

This commit is contained in:
yhjun1026 2023-06-01 14:32:55 +08:00
parent 1c75dda0a0
commit bacc31658e
3 changed files with 7 additions and 43 deletions

View File

@ -47,44 +47,10 @@ class ChatWithDbAutoExecute(BaseChat):
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
# "table_info": self.database.table_simple_info(self.db_connect)
"table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
"table_info": self.database.table_simple_info(self.db_connect)
# "table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
}
return input_values
def do_with_prompt_response(self, prompt_response):
return self.database.run(self.db_connect, prompt_response.sql)
if __name__ == "__main__":
db = CFG.local_db
connect = db.get_session("gpt-user")
results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time
FROM `gpt-user`.users
WHERE user_name='test1';
""")
print(str(db.get_session_db(connect)))
print(str(results))
results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time
FROM `gpt-user`.users
WHERE user_name='test2';
""")
print(str(db.get_session_db(connect)))
print(str(results))
results = db.run(connect, """INSERT INTO `gpt-user`.users
(user_name, phone, email, city, create_time, last_login_time)
VALUES('test4', '23', NULL, '成都', '2023-05-09 09:09:09', NULL);
""")
print(str(db.get_session_db(connect)))
print(str(results))
results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time
FROM `gpt-user`.users
WHERE user_name='test3';
""")
print(str(db.get_session_db(connect)))
print(str(results))

View File

@ -254,7 +254,6 @@ def http_bot(
"db_name": db_selector,
"user_input": state.last_user_input
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatWithDbQA == scene:
chat_param = {
"temperature": temperature,
@ -263,7 +262,6 @@ def http_bot(
"db_name": db_selector,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatExecution == scene:
chat_param = {
"temperature": temperature,
@ -272,7 +270,6 @@ def http_bot(
"plugin_selector": plugin_selector,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatNormal == scene:
chat_param = {
"temperature": temperature,
@ -280,7 +277,6 @@ def http_bot(
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatKnowledge == scene:
chat_param = {
"temperature": temperature,
@ -288,7 +284,6 @@ def http_bot(
"chat_session_id": state.conv_id,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatNewKnowledge == scene:
chat_param = {
"temperature": temperature,
@ -297,7 +292,6 @@ def http_bot(
"user_input": state.last_user_input,
"knowledge_name": knowledge_name
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
elif ChatScene.ChatUrlKnowledge == scene:
chat_param = {
"temperature": temperature,
@ -306,8 +300,11 @@ def http_bot(
"user_input": state.last_user_input,
"url": url_input
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
else:
state.messages[-1][-1] = f"ERROR: Can't support scene!{scene}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
if not chat.prompt_template.stream_out:
logger.info("not stream out, wait model response!")
state.messages[-1][-1] = chat.nostream_call()

View File

@ -22,6 +22,7 @@ class MysqlSummary(DBSummary):
self.db = CFG.local_db
self.db.get_session(name)
self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
users=self.db.get_users(),
grant=self.db.get_grants(),