mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-30 22:28:48 +00:00
66 lines
2.3 KiB
Python
import json
|
|
from typing import Dict, NamedTuple
|
|
import logging
|
|
import xml.etree.ElementTree as ET
|
|
from pilot.common.json_utils import serialize
|
|
from pilot.out_parser.base import BaseOutputParser, T
|
|
from pilot.configs.config import Config
|
|
from pilot.scene.chat_db.data_loader import DbDataLoader
|
|
|
|
# Project configuration object, instantiated once when this module is imported.
# NOTE(review): CFG is not referenced anywhere in this file's visible code —
# confirm whether other modules import it from here before removing it.
CFG = Config()
|
|
|
|
|
|
class SqlAction(NamedTuple):
    """Immutable pair of SQL text and the model's accompanying ``thoughts``.

    Instances are built by the chat-db output parser from the ``"sql"`` and
    ``"thoughts"`` keys of the model's JSON reply.
    """

    # SQL text taken from the model response.
    sql: str
    # The model's "thoughts" value, stored exactly as received.
    thoughts: Dict

    def to_dict(self) -> Dict[str, object]:
        """Return the fields as a new plain dict.

        Returns:
            A dict with keys ``"sql"`` (the SQL string) and ``"thoughts"``
            (the thoughts value). The values have different types, so the
            value type is ``object`` rather than ``Dict``.
        """
        return {"sql": self.sql, "thoughts": self.thoughts}
|
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DbChatOutputParser(BaseOutputParser):
    """Output parser for the chat-with-database scene.

    Turns the model's JSON reply into a :class:`SqlAction`, and renders the
    result of running that SQL as a ``<chart-view>`` element appended to the
    model's spoken text.
    """

    def __init__(self, sep: str, is_stream_out: bool):
        """Forward ``sep`` and ``is_stream_out`` unchanged to ``BaseOutputParser``."""
        super().__init__(sep=sep, is_stream_out=is_stream_out)

    def parse_prompt_response(self, model_out_text):
        """Parse the model output into a :class:`SqlAction`.

        The text returned by ``BaseOutputParser.parse_prompt_response`` is
        decoded as JSON. Keys are compared after stripping surrounding
        whitespace, so ``" sql"`` counts as ``"sql"``. Keys are visited in
        sorted order, so if several keys normalize to the same name, the
        last one in sorted order wins.

        Args:
            model_out_text: Raw model output, handed to the base-class parser.

        Returns:
            SqlAction: ``sql`` defaults to ``""`` and ``thoughts`` to ``{}``
            when the corresponding key is absent.

        Raises:
            json.JSONDecodeError: if the cleaned text is not valid JSON.
        """
        clean_str = super().parse_prompt_response(model_out_text)
        # A %s placeholder is required: an extra argument with no placeholder
        # makes logging emit a formatting error instead of the message.
        logger.info("clean prompt response: %s", clean_str)
        response = json.loads(clean_str)

        # Defaults avoid an UnboundLocalError when the model omits a key.
        sql = ""
        thoughts = {}
        for key in sorted(response):
            normalized = key.strip()
            if normalized == "sql":
                sql = response[key]
            elif normalized == "thoughts":
                thoughts = response[key]

        if not sql:
            logger.warning("model response has no usable 'sql' field: %s", clean_str)
        return SqlAction(sql, thoughts)

    def parse_view_response(self, speak, data, prompt_response) -> str:
        """Run the generated SQL and render the rows as a ``<chart-view>``.

        Args:
            speak: Text placed before the rendered view.
            data: Callable invoked with ``prompt_response.sql``. Its result must
                provide a pandas-style ``to_json(orient=..., date_format=...,
                date_unit=...)``; presumably a DataFrame (confirm with caller).
            prompt_response: Object exposing a ``sql`` attribute, such as the
                ``SqlAction`` returned by :meth:`parse_prompt_response`.

        Returns:
            ``speak`` followed by a ``<chart-view>`` element whose text is a
            JSON object with ``type``, ``sql`` and ``data`` keys. If running
            the query or serializing its result raises, ``data`` is an empty
            list and a red ``ERROR!`` marker with the exception message is
            inserted before the element.
        """
        view = ET.Element("chart-view")
        err_msg = None
        try:
            df = data(prompt_response.sql)
            # Key insertion order is kept because it fixes the JSON output order.
            param = {
                "type": "response_table",
                "sql": prompt_response.sql,
                "data": json.loads(
                    df.to_json(orient="records", date_format="iso", date_unit="s")
                ),
            }
            view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
        except Exception as e:
            # Broad catch is deliberate: any failure is reported to the user
            # in the rendered view instead of aborting the response.
            logger.exception("parse_view_response error!%s", e)
            err_param = {
                "sql": f"{prompt_response.sql}",
                "type": "response_table",
                "data": [],
            }
            err_msg = str(e)
            view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)

        view.text = view_json_str
        rendered = ET.tostring(view, encoding="utf-8").decode("utf-8")

        # Compare against None so an exception with an empty message is still
        # reported as an error rather than rendered as a success.
        if err_msg is not None:
            # NOTE(review): the first "\\n" emits a literal backslash-n while the
            # second is a real newline; confirm what the frontend expects.
            # NOTE(review): err_msg is interpolated into HTML unescaped.
            return f"""{speak} \\n <span style=\"color:red\">ERROR!</span>{err_msg} \n {rendered}"""
        return speak + "\n" + rendered
|