fix sql not found error for chat data (#2152)

Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
This commit is contained in:
GITHUBear 2024-12-03 21:49:53 +08:00 committed by GitHub
parent 4fa60037dd
commit af2d042aa7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 14 deletions

View File

@ -18,12 +18,14 @@ class SqlAction(NamedTuple):
sql: str
thoughts: Dict
display: str
direct_response: str
def to_dict(self) -> Dict[str, Dict]:
return {
"sql": self.sql,
"thoughts": self.thoughts,
"display": self.display,
"direct_response": self.direct_response,
}
@ -48,7 +50,7 @@ class DbChatOutputParser(BaseOutputParser):
logger.info(f"clean prompt response: {clean_str}")
# Compatible with community pure sql output model
if self.is_sql_statement(clean_str):
return SqlAction(clean_str, "", "")
return SqlAction(clean_str, "", "", "")
else:
try:
response = json.loads(clean_str, strict=False)
@ -59,10 +61,12 @@ class DbChatOutputParser(BaseOutputParser):
thoughts = response[key]
if key.strip() == "display_type":
display = response[key]
return SqlAction(sql, thoughts, display)
if key.strip() == "direct_response":
resp = response[key]
return SqlAction(sql, thoughts, display, resp)
except Exception as e:
logger.error(f"json load failed:{clean_str}")
return SqlAction("", clean_str, "")
return SqlAction("", clean_str, "", "")
def parse_view_response(self, speak, data, prompt_response) -> str:
param = {}
@ -70,17 +74,25 @@ class DbChatOutputParser(BaseOutputParser):
err_msg = None
success = False
try:
if not prompt_response.sql or len(prompt_response.sql) <= 0:
if (
not prompt_response.direct_response
or len(prompt_response.direct_response) <= 0
) and (not prompt_response.sql or len(prompt_response.sql) <= 0):
raise AppActionException("Can not find sql in response", speak)
df = data(prompt_response.sql)
param["type"] = prompt_response.display
param["sql"] = prompt_response.sql
param["data"] = json.loads(
df.to_json(orient="records", date_format="iso", date_unit="s")
)
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
success = True
if prompt_response.sql:
df = data(prompt_response.sql)
param["type"] = prompt_response.display
param["sql"] = prompt_response.sql
param["data"] = json.loads(
df.to_json(orient="records", date_format="iso", date_unit="s")
)
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
success = True
elif prompt_response.direct_response:
speak = prompt_response.direct_response
view_json_str = ""
success = True
except Exception as e:
logger.error("parse_view_response error!" + str(e))
err_param = {
@ -93,8 +105,11 @@ class DbChatOutputParser(BaseOutputParser):
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
# api_call_element.text = view_json_str
api_call_element.set("content", view_json_str)
result = ET.tostring(api_call_element, encoding="utf-8")
if len(view_json_str) != 0:
api_call_element.set("content", view_json_str)
result = ET.tostring(api_call_element, encoding="utf-8")
else:
result = b""
if not success:
view_content = (
f'{speak} \\n <span style="color:red">ERROR!</span>'

View File

@ -71,6 +71,7 @@ PROMPT_SCENE_DEFINE = (
RESPONSE_FORMAT_SIMPLE = {
"thoughts": "thoughts summary to say to user",
"direct_response": "If the context is sufficient to answer user, reply directly without sql",
"sql": "SQL Query to run",
"display_type": "Data display method",
}