From 76854aece26b17d2f98605723ab747e60bbeb1d7 Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Fri, 20 Oct 2023 14:21:07 +0800 Subject: [PATCH] feat(ChatAgent): ChatAgent doucument add ChatAgent doucument --- pilot/scene/chat_db/auto_execute/chat.py | 13 ++--- pilot/scene/chat_db/auto_execute/prompt.py | 57 ++++++---------------- 2 files changed, 18 insertions(+), 52 deletions(-) diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index 01e04be63..f92df7a3a 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -5,7 +5,6 @@ from pilot.scene.base import ChatScene from pilot.common.sql_database import Database from pilot.configs.config import Config from pilot.scene.chat_db.auto_execute.prompt import prompt -from pilot.base_modules.agent.commands.command_mange import ApiCall CFG = Config() @@ -38,7 +37,6 @@ class ChatWithDbAutoExecute(BaseChat): self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) self.top_k: int = 200 - self.api_call = ApiCall(display_registry=CFG.command_disply) def generate_input_values(self): """ @@ -71,11 +69,6 @@ class ChatWithDbAutoExecute(BaseChat): } return input_values - def stream_plugin_call(self, text): - text = text.replace("\n", " ") - print(f"stream_plugin_call:{text}") - return self.api_call.run_display_sql(text, self.database.run_to_df) - # - # def do_action(self, prompt_response): - # print(f"do_action:{prompt_response}") - # return self.database.run(prompt_response.sql) + def do_action(self, prompt_response): + print(f"do_action:{prompt_response}") + return self.database.run(prompt_response.sql) diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py index d9b67af39..abc889cec 100644 --- a/pilot/scene/chat_db/auto_execute/prompt.py +++ b/pilot/scene/chat_db/auto_execute/prompt.py @@ -8,51 +8,24 @@ from pilot.scene.chat_db.auto_execute.example import sql_data_example CFG = Config() +PROMPT_SCENE_DEFINE = "You are a SQL expert. " -_PROMPT_SCENE_DEFINE_EN = "You are a database expert. " -_PROMPT_SCENE_DEFINE_ZH = "你是一个数据库专家. " - -_DEFAULT_TEMPLATE_EN = """ +_DEFAULT_TEMPLATE = """ Given an input question, create a syntactically correct {dialect} sql. -Table structure information: - {table_info} -Constraint: -1. You can only use the table provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql query." It is prohibited to fabricate information at will. -2. Do not query columns that do not exist. Pay attention to which column is in which table. -3. Replace the corresponding sql into the sql field in the returned result -4. Unless the user specifies in the question a specific number of examples he wishes to obtain, always limit the query to a maximum of {top_k} results. -5. Please output the Sql content in the following format to execute the corresponding SQL to display the data:response_tableSQL Query to run -Please make sure to respond as following format: - thoughts summary to say to user.response_tableSQL Query to run - + +Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. +Use as few tables as possible when querying. +Only use the following tables schema to generate sql: +{table_info} +Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. + Question: {input} + +Respond in JSON format as following format: +{response} +Ensure the response is correct json and can be parsed by Python json.loads """ -_DEFAULT_TEMPLATE_ZH = """ -给定一个输入问题,创建一个语法正确的 {dialect} sql。 -已知表结构信息: - {table_info} - -约束: -1. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。 -2. 不要查询不存在的列,注意哪一列位于哪张表中。 -3.将对应的sql替换到返回结果中的sql字段中 -4.除非用户在问题中指定了他希望获得的具体示例数量,否则始终将查询限制为最多 {top_k} 个结果。 - -请务必按照以下格式回复: - 对用户说的想法摘要。response_table要运行的 SQL - -问题:{input} -""" - -_DEFAULT_TEMPLATE = ( - _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH -) - -PROMPT_SCENE_DEFINE = ( - _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH -) - RESPONSE_FORMAT_SIMPLE = { "thoughts": "thoughts summary to say to user", "sql": "SQL Query to run", @@ -60,7 +33,7 @@ RESPONSE_FORMAT_SIMPLE = { PROMPT_SEP = SeparatorStyle.SINGLE.value -PROMPT_NEED_NEED_STREAM_OUT = True +PROMPT_NEED_NEED_STREAM_OUT = False # Temperature is a configuration hyperparameter that controls the randomness of language model output. # A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output. @@ -70,7 +43,7 @@ PROMPT_TEMPERATURE = 0.5 prompt = PromptTemplate( template_scene=ChatScene.ChatWithDbExecute.value(), input_variables=["input", "table_info", "dialect", "top_k", "response"], - # response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4), + response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4), template_define=PROMPT_SCENE_DEFINE, template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_NEED_STREAM_OUT,