diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index ca7caab31..8c9e15d19 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -206,6 +206,10 @@ class BaseOutputParser(ABC): if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"): logger.info("illegal json processing:\n" + cleaned_output) cleaned_output = self.__extract_json(cleaned_output) + + if not cleaned_output or len(cleaned_output) <=0: + return model_out_text + cleaned_output = ( cleaned_output.strip() .replace("\\n", " ") diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index 008b9b0ad..361ccabc8 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -72,7 +72,7 @@ class ChatWithDbAutoExecute(BaseChat): ) input_values = { - # "input": self.current_user_input, + "user_input": self.current_user_input, "top_k": str(self.top_k), "dialect": self.database.dialect, "table_info": table_infos, diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py index fb35a1d28..e7b0c5ac4 100644 --- a/pilot/scene/chat_db/auto_execute/out_parser.py +++ b/pilot/scene/chat_db/auto_execute/out_parser.py @@ -27,7 +27,7 @@ class DbChatOutputParser(BaseOutputParser): def __init__(self, sep: str, is_stream_out: bool): super().__init__(sep=sep, is_stream_out=is_stream_out) - def is_sql_statement(statement): + def is_sql_statement(self, statement): parsed = sqlparse.parse(statement) if not parsed: return False @@ -43,13 +43,17 @@ class DbChatOutputParser(BaseOutputParser): if self.is_sql_statement(clean_str): return SqlAction(clean_str, "") else: - response = json.loads(clean_str) - for key in sorted(response): - if key.strip() == "sql": - sql = response[key] - if key.strip() == "thoughts": - thoughts = response[key] - return SqlAction(sql, thoughts) + try: + response = json.loads(clean_str) + for key in sorted(response): + if key.strip() == "sql": + sql = response[key] + if key.strip() == "thoughts": + thoughts = response[key] + return SqlAction(sql, thoughts) + except Exception as e: + logging.error("json load faild") + return SqlAction("", clean_str) def parse_view_response(self, speak, data, prompt_response) -> str: @@ -57,6 +61,9 @@ class DbChatOutputParser(BaseOutputParser): api_call_element = ET.Element("chart-view") err_msg = None try: + if not prompt_response.sql or len(prompt_response.sql) <=0: + return f""" [Unresolvable return]\n{speak}""" + df = data(prompt_response.sql) param["type"] = "response_table" param["sql"] = prompt_response.sql diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py index ab6c6ecdc..2ff5c8bb9 100644 --- a/pilot/scene/chat_db/auto_execute/prompt.py +++ b/pilot/scene/chat_db/auto_execute/prompt.py @@ -22,6 +22,8 @@ Constraint: 3.Use as few tables as possible when querying. 4.Please check the correctness of the SQL and ensure that the query performance is optimized under correct conditions. +User Question: + {user_input} Please think step by step and respond according to the following JSON format: {response} Ensure the response is correct json and can be parsed by Python json.loads. @@ -37,7 +39,8 @@ _DEFAULT_TEMPLATE_ZH = """ 2. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。 3. 请注意生成SQL时不要弄错表和列的关系 4. 请检查SQL的正确性,并保证正确的情况下优化查询性能 - +用户问题: + {user_input} 请一步步思考并按照以下JSON格式回复: {response} 确保返回正确的json并且可以被Python json.loads方法解析. diff --git a/pilot/server/static/404.html b/pilot/server/static/404.html index 8bc75c362..bcb398366 100644 --- a/pilot/server/static/404.html +++ b/pilot/server/static/404.html @@ -1 +1 @@ -