diff --git a/dbgpt/app/scene/chat_db/auto_execute/out_parser.py b/dbgpt/app/scene/chat_db/auto_execute/out_parser.py index 7aad46bd8..3dcae4554 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/out_parser.py +++ b/dbgpt/app/scene/chat_db/auto_execute/out_parser.py @@ -18,12 +18,14 @@ class SqlAction(NamedTuple): sql: str thoughts: Dict display: str + direct_response: str def to_dict(self) -> Dict[str, Dict]: return { "sql": self.sql, "thoughts": self.thoughts, "display": self.display, + "direct_response": self.direct_response, } @@ -48,7 +50,7 @@ class DbChatOutputParser(BaseOutputParser): logger.info(f"clean prompt response: {clean_str}") # Compatible with community pure sql output model if self.is_sql_statement(clean_str): - return SqlAction(clean_str, "", "") + return SqlAction(clean_str, "", "", "") else: try: response = json.loads(clean_str, strict=False) @@ -59,10 +61,12 @@ class DbChatOutputParser(BaseOutputParser): thoughts = response[key] if key.strip() == "display_type": display = response[key] - return SqlAction(sql, thoughts, display) + if key.strip() == "direct_response": + resp = response[key] + return SqlAction(sql, thoughts, display, resp) except Exception as e: logger.error(f"json load failed:{clean_str}") - return SqlAction("", clean_str, "") + return SqlAction("", clean_str, "", "") def parse_view_response(self, speak, data, prompt_response) -> str: param = {} @@ -70,17 +74,25 @@ class DbChatOutputParser(BaseOutputParser): err_msg = None success = False try: - if not prompt_response.sql or len(prompt_response.sql) <= 0: + if ( + not prompt_response.direct_response + or len(prompt_response.direct_response) <= 0 + ) and (not prompt_response.sql or len(prompt_response.sql) <= 0): raise AppActionException("Can not find sql in response", speak) - df = data(prompt_response.sql) - param["type"] = prompt_response.display - param["sql"] = prompt_response.sql - param["data"] = json.loads( - df.to_json(orient="records", date_format="iso", date_unit="s") - ) - view_json_str = json.dumps(param, default=serialize, ensure_ascii=False) - success = True + if prompt_response.sql: + df = data(prompt_response.sql) + param["type"] = prompt_response.display + param["sql"] = prompt_response.sql + param["data"] = json.loads( + df.to_json(orient="records", date_format="iso", date_unit="s") + ) + view_json_str = json.dumps(param, default=serialize, ensure_ascii=False) + success = True + elif prompt_response.direct_response: + speak = prompt_response.direct_response + view_json_str = "" + success = True except Exception as e: logger.error("parse_view_response error!" + str(e)) err_param = { @@ -93,8 +105,11 @@ class DbChatOutputParser(BaseOutputParser): view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False) # api_call_element.text = view_json_str - api_call_element.set("content", view_json_str) - result = ET.tostring(api_call_element, encoding="utf-8") + if len(view_json_str) != 0: + api_call_element.set("content", view_json_str) + result = ET.tostring(api_call_element, encoding="utf-8") + else: + result = b"" if not success: view_content = ( f'{speak} \\n ERROR!' diff --git a/dbgpt/app/scene/chat_db/auto_execute/prompt.py b/dbgpt/app/scene/chat_db/auto_execute/prompt.py index 52b153215..4eed89b30 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/prompt.py +++ b/dbgpt/app/scene/chat_db/auto_execute/prompt.py @@ -71,6 +71,7 @@ PROMPT_SCENE_DEFINE = ( RESPONSE_FORMAT_SIMPLE = { "thoughts": "thoughts summary to say to user", + "direct_response": "If the context is sufficient to answer user, reply directly without sql", "sql": "SQL Query to run", "display_type": "Data display method", }