mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 20:26:15 +00:00
fix sql not found error for chat data (#2152)
Signed-off-by: shanhaikang.shk <shanhaikang.shk@oceanbase.com>
This commit is contained in:
parent
4fa60037dd
commit
af2d042aa7
@ -18,12 +18,14 @@ class SqlAction(NamedTuple):
|
|||||||
sql: str
|
sql: str
|
||||||
thoughts: Dict
|
thoughts: Dict
|
||||||
display: str
|
display: str
|
||||||
|
direct_response: str
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Dict]:
|
def to_dict(self) -> Dict[str, Dict]:
|
||||||
return {
|
return {
|
||||||
"sql": self.sql,
|
"sql": self.sql,
|
||||||
"thoughts": self.thoughts,
|
"thoughts": self.thoughts,
|
||||||
"display": self.display,
|
"display": self.display,
|
||||||
|
"direct_response": self.direct_response,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -48,7 +50,7 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
logger.info(f"clean prompt response: {clean_str}")
|
logger.info(f"clean prompt response: {clean_str}")
|
||||||
# Compatible with community pure sql output model
|
# Compatible with community pure sql output model
|
||||||
if self.is_sql_statement(clean_str):
|
if self.is_sql_statement(clean_str):
|
||||||
return SqlAction(clean_str, "", "")
|
return SqlAction(clean_str, "", "", "")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
response = json.loads(clean_str, strict=False)
|
response = json.loads(clean_str, strict=False)
|
||||||
@ -59,10 +61,12 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
thoughts = response[key]
|
thoughts = response[key]
|
||||||
if key.strip() == "display_type":
|
if key.strip() == "display_type":
|
||||||
display = response[key]
|
display = response[key]
|
||||||
return SqlAction(sql, thoughts, display)
|
if key.strip() == "direct_response":
|
||||||
|
resp = response[key]
|
||||||
|
return SqlAction(sql, thoughts, display, resp)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"json load failed:{clean_str}")
|
logger.error(f"json load failed:{clean_str}")
|
||||||
return SqlAction("", clean_str, "")
|
return SqlAction("", clean_str, "", "")
|
||||||
|
|
||||||
def parse_view_response(self, speak, data, prompt_response) -> str:
|
def parse_view_response(self, speak, data, prompt_response) -> str:
|
||||||
param = {}
|
param = {}
|
||||||
@ -70,17 +74,25 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
err_msg = None
|
err_msg = None
|
||||||
success = False
|
success = False
|
||||||
try:
|
try:
|
||||||
if not prompt_response.sql or len(prompt_response.sql) <= 0:
|
if (
|
||||||
|
not prompt_response.direct_response
|
||||||
|
or len(prompt_response.direct_response) <= 0
|
||||||
|
) and (not prompt_response.sql or len(prompt_response.sql) <= 0):
|
||||||
raise AppActionException("Can not find sql in response", speak)
|
raise AppActionException("Can not find sql in response", speak)
|
||||||
|
|
||||||
df = data(prompt_response.sql)
|
if prompt_response.sql:
|
||||||
param["type"] = prompt_response.display
|
df = data(prompt_response.sql)
|
||||||
param["sql"] = prompt_response.sql
|
param["type"] = prompt_response.display
|
||||||
param["data"] = json.loads(
|
param["sql"] = prompt_response.sql
|
||||||
df.to_json(orient="records", date_format="iso", date_unit="s")
|
param["data"] = json.loads(
|
||||||
)
|
df.to_json(orient="records", date_format="iso", date_unit="s")
|
||||||
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
|
)
|
||||||
success = True
|
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
|
||||||
|
success = True
|
||||||
|
elif prompt_response.direct_response:
|
||||||
|
speak = prompt_response.direct_response
|
||||||
|
view_json_str = ""
|
||||||
|
success = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("parse_view_response error!" + str(e))
|
logger.error("parse_view_response error!" + str(e))
|
||||||
err_param = {
|
err_param = {
|
||||||
@ -93,8 +105,11 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
|
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
|
||||||
|
|
||||||
# api_call_element.text = view_json_str
|
# api_call_element.text = view_json_str
|
||||||
api_call_element.set("content", view_json_str)
|
if len(view_json_str) != 0:
|
||||||
result = ET.tostring(api_call_element, encoding="utf-8")
|
api_call_element.set("content", view_json_str)
|
||||||
|
result = ET.tostring(api_call_element, encoding="utf-8")
|
||||||
|
else:
|
||||||
|
result = b""
|
||||||
if not success:
|
if not success:
|
||||||
view_content = (
|
view_content = (
|
||||||
f'{speak} \\n <span style="color:red">ERROR!</span>'
|
f'{speak} \\n <span style="color:red">ERROR!</span>'
|
||||||
|
@ -71,6 +71,7 @@ PROMPT_SCENE_DEFINE = (
|
|||||||
|
|
||||||
RESPONSE_FORMAT_SIMPLE = {
|
RESPONSE_FORMAT_SIMPLE = {
|
||||||
"thoughts": "thoughts summary to say to user",
|
"thoughts": "thoughts summary to say to user",
|
||||||
|
"direct_response": "If the context is sufficient to answer user, reply directly without sql",
|
||||||
"sql": "SQL Query to run",
|
"sql": "SQL Query to run",
|
||||||
"display_type": "Data display method",
|
"display_type": "Data display method",
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user