fix sql not found error for chat data (#2152)

Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
This commit is contained in:
GITHUBear 2024-12-03 21:49:53 +08:00 committed by GitHub
parent 4fa60037dd
commit af2d042aa7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 14 deletions

View File

@ -18,12 +18,14 @@ class SqlAction(NamedTuple):
sql: str sql: str
thoughts: Dict thoughts: Dict
display: str display: str
direct_response: str
def to_dict(self) -> Dict[str, Dict]: def to_dict(self) -> Dict[str, Dict]:
return { return {
"sql": self.sql, "sql": self.sql,
"thoughts": self.thoughts, "thoughts": self.thoughts,
"display": self.display, "display": self.display,
"direct_response": self.direct_response,
} }
@ -48,7 +50,7 @@ class DbChatOutputParser(BaseOutputParser):
logger.info(f"clean prompt response: {clean_str}") logger.info(f"clean prompt response: {clean_str}")
# Compatible with community pure sql output model # Compatible with community pure sql output model
if self.is_sql_statement(clean_str): if self.is_sql_statement(clean_str):
return SqlAction(clean_str, "", "") return SqlAction(clean_str, "", "", "")
else: else:
try: try:
response = json.loads(clean_str, strict=False) response = json.loads(clean_str, strict=False)
@ -59,10 +61,12 @@ class DbChatOutputParser(BaseOutputParser):
thoughts = response[key] thoughts = response[key]
if key.strip() == "display_type": if key.strip() == "display_type":
display = response[key] display = response[key]
return SqlAction(sql, thoughts, display) if key.strip() == "direct_response":
resp = response[key]
return SqlAction(sql, thoughts, display, resp)
except Exception as e: except Exception as e:
logger.error(f"json load failed:{clean_str}") logger.error(f"json load failed:{clean_str}")
return SqlAction("", clean_str, "") return SqlAction("", clean_str, "", "")
def parse_view_response(self, speak, data, prompt_response) -> str: def parse_view_response(self, speak, data, prompt_response) -> str:
param = {} param = {}
@ -70,17 +74,25 @@ class DbChatOutputParser(BaseOutputParser):
err_msg = None err_msg = None
success = False success = False
try: try:
if not prompt_response.sql or len(prompt_response.sql) <= 0: if (
not prompt_response.direct_response
or len(prompt_response.direct_response) <= 0
) and (not prompt_response.sql or len(prompt_response.sql) <= 0):
raise AppActionException("Can not find sql in response", speak) raise AppActionException("Can not find sql in response", speak)
df = data(prompt_response.sql) if prompt_response.sql:
param["type"] = prompt_response.display df = data(prompt_response.sql)
param["sql"] = prompt_response.sql param["type"] = prompt_response.display
param["data"] = json.loads( param["sql"] = prompt_response.sql
df.to_json(orient="records", date_format="iso", date_unit="s") param["data"] = json.loads(
) df.to_json(orient="records", date_format="iso", date_unit="s")
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False) )
success = True view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
success = True
elif prompt_response.direct_response:
speak = prompt_response.direct_response
view_json_str = ""
success = True
except Exception as e: except Exception as e:
logger.error("parse_view_response error!" + str(e)) logger.error("parse_view_response error!" + str(e))
err_param = { err_param = {
@ -93,8 +105,11 @@ class DbChatOutputParser(BaseOutputParser):
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False) view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
# api_call_element.text = view_json_str # api_call_element.text = view_json_str
api_call_element.set("content", view_json_str) if len(view_json_str) != 0:
result = ET.tostring(api_call_element, encoding="utf-8") api_call_element.set("content", view_json_str)
result = ET.tostring(api_call_element, encoding="utf-8")
else:
result = b""
if not success: if not success:
view_content = ( view_content = (
f'{speak} \\n <span style="color:red">ERROR!</span>' f'{speak} \\n <span style="color:red">ERROR!</span>'

View File

@ -71,6 +71,7 @@ PROMPT_SCENE_DEFINE = (
RESPONSE_FORMAT_SIMPLE = { RESPONSE_FORMAT_SIMPLE = {
"thoughts": "thoughts summary to say to user", "thoughts": "thoughts summary to say to user",
"direct_response": "If the context is sufficient to answer user, reply directly without sql",
"sql": "SQL Query to run", "sql": "SQL Query to run",
"display_type": "Data display method", "display_type": "Data display method",
} }