mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-02 16:50:24 +00:00
398 lines
15 KiB
Python
398 lines
15 KiB
Python
"""Parse ``<api-call>`` blocks in LLM output, run their SQL and render the results as views."""
|
||
|
||
import json
|
||
import logging
|
||
import xml.etree.ElementTree as ET
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional, Union
|
||
|
||
from dbgpt._private.pydantic import BaseModel, Field
|
||
from dbgpt.agent.core.schema import Status
|
||
from dbgpt.util.json_utils import serialize
|
||
from dbgpt.util.string_utils import extract_content, extract_content_open_ending
|
||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class PluginStatus(BaseModel):
    """A class representing the status of a plugin.

    One instance tracks a single ``<api-call>`` parsed from LLM output.
    """

    name: str
    # Keys from ``extract_content`` at which this call was found in the text.
    location: List[int]
    args: dict
    status: Union[Status, str] = Status.TODO.value
    logo_url: Optional[str] = None
    # Result of running the call; callers store non-string values here
    # (e.g. a list of records), so the annotation must not be ``str``.
    api_result: Any = None
    err_msg: Optional[str] = None
    # Creation time in milliseconds since the epoch. A default_factory is
    # required: a plain default would be evaluated once at import time and
    # shared by every instance.
    start_time: float = Field(
        default_factory=lambda: datetime.now().timestamp() * 1000
    )
    # Completion time in milliseconds since the epoch (callers assign a float).
    end_time: Optional[float] = None

    # Query result object kept for rendering (presumably a DataFrame).
    df: Any = None
|
||
|
||
|
||
class ApiCall:
|
||
"""A class representing an API call."""
|
||
|
||
agent_prefix = "<api-call>"
|
||
agent_end = "</api-call>"
|
||
name_prefix = "<name>"
|
||
name_end = "</name>"
|
||
|
||
def __init__(
|
||
self,
|
||
plugin_generator: Any = None,
|
||
display_registry: Any = None,
|
||
backend_rendering: bool = False,
|
||
):
|
||
"""Create a new ApiCall object."""
|
||
self.plugin_status_map: Dict[str, PluginStatus] = {}
|
||
|
||
self.plugin_generator = plugin_generator
|
||
self.display_registry = display_registry
|
||
self.start_time = datetime.now().timestamp() * 1000
|
||
self.backend_rendering: bool = backend_rendering
|
||
|
||
def _is_need_wait_plugin_call(self, api_call_context):
|
||
start_agent_count = api_call_context.count(self.agent_prefix)
|
||
|
||
if start_agent_count > 0:
|
||
return True
|
||
else:
|
||
# Check the new character at the end
|
||
check_len = len(self.agent_prefix)
|
||
last_text = api_call_context[-check_len:]
|
||
for i in range(check_len):
|
||
text_tmp = last_text[-i:]
|
||
prefix_tmp = self.agent_prefix[:i]
|
||
if text_tmp == prefix_tmp:
|
||
return True
|
||
else:
|
||
i += 1
|
||
return False
|
||
|
||
def check_last_plugin_call_ready(self, all_context):
|
||
"""Check if the last plugin call is ready."""
|
||
start_agent_count = all_context.count(self.agent_prefix)
|
||
end_agent_count = all_context.count(self.agent_end)
|
||
|
||
if start_agent_count > 0 and start_agent_count == end_agent_count:
|
||
return True
|
||
return False
|
||
|
||
def _deal_error_md_tags(self, all_context, api_context, include_end: bool = True):
|
||
error_md_tags = [
|
||
"```",
|
||
"```python",
|
||
"```xml",
|
||
"```json",
|
||
"```markdown",
|
||
"```sql",
|
||
]
|
||
if not include_end:
|
||
md_tag_end = ""
|
||
else:
|
||
md_tag_end = "```"
|
||
for tag in error_md_tags:
|
||
all_context = all_context.replace(
|
||
tag + api_context + md_tag_end, api_context
|
||
)
|
||
all_context = all_context.replace(
|
||
tag + "\n" + api_context + "\n" + md_tag_end, api_context
|
||
)
|
||
all_context = all_context.replace(
|
||
tag + " " + api_context + " " + md_tag_end, api_context
|
||
)
|
||
all_context = all_context.replace(tag + api_context, api_context)
|
||
return all_context
|
||
|
||
def api_view_context(self, all_context: str, display_mode: bool = False):
|
||
"""Return the view content."""
|
||
call_context_map = extract_content_open_ending(
|
||
all_context, self.agent_prefix, self.agent_end, True
|
||
)
|
||
for api_index, api_context in call_context_map.items():
|
||
api_status = self.plugin_status_map.get(api_context)
|
||
if api_status is not None:
|
||
if display_mode:
|
||
all_context = self._deal_error_md_tags(all_context, api_context)
|
||
if Status.FAILED.value == api_status.status:
|
||
err_msg = api_status.err_msg
|
||
all_context = all_context.replace(
|
||
api_context,
|
||
f'\n<span style="color:red">Error:</span>{err_msg}\n'
|
||
+ self.to_view_antv_vis(api_status),
|
||
)
|
||
else:
|
||
all_context = all_context.replace(
|
||
api_context, self.to_view_antv_vis(api_status)
|
||
)
|
||
|
||
else:
|
||
all_context = self._deal_error_md_tags(
|
||
all_context, api_context, False
|
||
)
|
||
all_context = all_context.replace(
|
||
api_context, self.to_view_text(api_status)
|
||
)
|
||
|
||
else:
|
||
# not ready api call view change
|
||
now_time = datetime.now().timestamp() * 1000
|
||
cost = (now_time - self.start_time) / 1000
|
||
cost_str = "{:.2f}".format(cost)
|
||
all_context = self._deal_error_md_tags(all_context, api_context)
|
||
|
||
all_context = all_context.replace(
|
||
api_context,
|
||
f'\n<span style="color:green">Waiting...{cost_str}S</span>\n',
|
||
)
|
||
|
||
return all_context
|
||
|
||
def update_from_context(self, all_context):
|
||
"""Modify the plugin status map based on the context."""
|
||
api_context_map: Dict[int, str] = extract_content(
|
||
all_context, self.agent_prefix, self.agent_end, True
|
||
)
|
||
for api_index, api_context in api_context_map.items():
|
||
api_context = api_context.replace("\\n", "").replace("\n", "")
|
||
api_call_element = ET.fromstring(api_context)
|
||
api_name = api_call_element.find("name").text
|
||
if api_name.find("[") >= 0 or api_name.find("]") >= 0:
|
||
api_name = api_name.replace("[", "").replace("]", "")
|
||
api_args = {}
|
||
args_elements = api_call_element.find("args")
|
||
for child_element in args_elements.iter():
|
||
api_args[child_element.tag] = child_element.text
|
||
|
||
api_status = self.plugin_status_map.get(api_context)
|
||
if api_status is None:
|
||
api_status = PluginStatus(
|
||
name=api_name, location=[api_index], args=api_args
|
||
)
|
||
self.plugin_status_map[api_context] = api_status
|
||
else:
|
||
api_status.location.append(api_index)
|
||
|
||
def _to_view_param_str(self, api_status):
|
||
param = {}
|
||
if api_status.name:
|
||
param["name"] = api_status.name
|
||
param["status"] = api_status.status
|
||
if api_status.logo_url:
|
||
param["logo"] = api_status.logo_url
|
||
|
||
if api_status.err_msg:
|
||
param["err_msg"] = api_status.err_msg
|
||
|
||
if api_status.api_result:
|
||
param["result"] = api_status.api_result
|
||
|
||
return json.dumps(param, default=serialize, ensure_ascii=False)
|
||
|
||
def to_view_text(self, api_status: PluginStatus):
|
||
"""Return the view content."""
|
||
api_call_element = ET.Element("dbgpt-view")
|
||
api_call_element.text = self._to_view_param_str(api_status)
|
||
result = ET.tostring(api_call_element, encoding="utf-8")
|
||
return result.decode("utf-8")
|
||
|
||
def to_view_antv_vis(self, api_status: PluginStatus):
|
||
"""Return the vis content."""
|
||
if self.backend_rendering:
|
||
html_table = api_status.df.to_html(
|
||
index=False, escape=False, sparsify=False
|
||
)
|
||
table_str = "".join(html_table.split())
|
||
table_str = table_str.replace("\n", " ")
|
||
sql = api_status.args["sql"]
|
||
html = (
|
||
f' \n<div><b>[SQL]{sql}</b></div><div class="w-full overflow-auto">'
|
||
f"{table_str}</div>\n "
|
||
)
|
||
return html
|
||
else:
|
||
api_call_element = ET.Element("chart-view")
|
||
api_call_element.attrib["content"] = self._to_antv_vis_param(api_status)
|
||
api_call_element.text = "\n"
|
||
result = ET.tostring(api_call_element, encoding="utf-8")
|
||
return result.decode("utf-8")
|
||
|
||
def _to_antv_vis_param(self, api_status: PluginStatus):
|
||
param = {}
|
||
if api_status.name:
|
||
param["type"] = api_status.name
|
||
if api_status.args:
|
||
param["sql"] = api_status.args["sql"]
|
||
|
||
data: Any = []
|
||
if api_status.api_result:
|
||
data = api_status.api_result
|
||
param["data"] = data
|
||
return json.dumps(param, ensure_ascii=False)
|
||
|
||
def run_display_sql(self, llm_text, sql_run_func):
|
||
"""Run the API calls for displaying SQL data."""
|
||
if self._is_need_wait_plugin_call(
|
||
llm_text
|
||
) and self.check_last_plugin_call_ready(llm_text):
|
||
# wait api call generate complete
|
||
self.update_from_context(llm_text)
|
||
for key, value in self.plugin_status_map.items():
|
||
if value.status == Status.TODO.value:
|
||
value.status = Status.RUNNING.value
|
||
logger.info(f"sql display execution:{value.name},{value.args}")
|
||
try:
|
||
sql = value.args["sql"]
|
||
if sql:
|
||
param = {
|
||
"df": sql_run_func(sql),
|
||
}
|
||
value.df = param["df"]
|
||
if self.display_registry.is_valid_command(value.name):
|
||
value.api_result = self.display_registry.call(
|
||
value.name, **param
|
||
)
|
||
else:
|
||
value.api_result = self.display_registry.call(
|
||
"response_table", **param
|
||
)
|
||
|
||
value.status = Status.COMPLETE.value
|
||
except Exception as e:
|
||
value.status = Status.FAILED.value
|
||
value.err_msg = str(e)
|
||
value.end_time = datetime.now().timestamp() * 1000
|
||
return self.api_view_context(llm_text, True)
|
||
|
||
def display_sql_llmvis(self, llm_text, sql_run_func):
|
||
"""Render charts using the Antv standard protocol.
|
||
|
||
Args:
|
||
llm_text: LLM response text
|
||
sql_run_func: sql run function
|
||
|
||
Returns:
|
||
ChartView protocol text
|
||
"""
|
||
try:
|
||
if self._is_need_wait_plugin_call(
|
||
llm_text
|
||
) and self.check_last_plugin_call_ready(llm_text):
|
||
# wait api call generate complete
|
||
self.update_from_context(llm_text)
|
||
for key, value in self.plugin_status_map.items():
|
||
if value.status == Status.TODO.value:
|
||
value.status = Status.RUNNING.value
|
||
logger.info(f"SQL execution:{value.name},{value.args}")
|
||
try:
|
||
sql = value.args["sql"]
|
||
if sql is not None and len(sql) > 0:
|
||
data_df = sql_run_func(sql)
|
||
value.df = data_df
|
||
value.api_result = json.loads(
|
||
data_df.to_json(
|
||
orient="records",
|
||
date_format="iso",
|
||
date_unit="s",
|
||
)
|
||
)
|
||
value.status = Status.COMPLETE.value
|
||
else:
|
||
value.status = Status.FAILED.value
|
||
value.err_msg = "No executable sql!"
|
||
|
||
except Exception as e:
|
||
logger.error(f"data prepare exception!{str(e)}")
|
||
value.status = Status.FAILED.value
|
||
value.err_msg = str(e)
|
||
value.end_time = datetime.now().timestamp() * 1000
|
||
except Exception as e:
|
||
logger.error("Api parsing exception", e)
|
||
raise ValueError("Api parsing exception," + str(e))
|
||
|
||
return self.api_view_context(llm_text, True)
|
||
|
||
def display_only_sql_vis(self, chart: dict, sql_2_df_func):
|
||
"""Display the chart using the vis standard protocol."""
|
||
err_msg = None
|
||
sql = chart.get("sql", None)
|
||
try:
|
||
param = {}
|
||
df = sql_2_df_func(sql)
|
||
if not sql or len(sql) <= 0:
|
||
return None
|
||
|
||
param["sql"] = sql
|
||
param["type"] = chart.get("display_type", "response_table")
|
||
param["title"] = chart.get("title", "")
|
||
param["describe"] = chart.get("thought", "")
|
||
|
||
param["data"] = json.loads(
|
||
df.to_json(orient="records", date_format="iso", date_unit="s")
|
||
)
|
||
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
|
||
except Exception as e:
|
||
logger.error("parse_view_response error!" + str(e))
|
||
err_param = {"sql": f"{sql}", "type": "response_table", "data": []}
|
||
err_msg = str(e)
|
||
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
|
||
|
||
# api_call_element.text = view_json_str
|
||
result = f"```vis-chart\n{view_json_str}\n```"
|
||
if err_msg:
|
||
return f"""<span style=\"color:red\">ERROR!</span>{err_msg} \n {result}"""
|
||
else:
|
||
return result
|
||
|
||
def display_dashboard_vis(
|
||
self, charts: List[dict], sql_2_df_func, title: Optional[str] = None
|
||
):
|
||
"""Display the dashboard using the vis standard protocol."""
|
||
err_msg = None
|
||
view_json_str = None
|
||
|
||
chart_items = []
|
||
try:
|
||
if not charts or len(charts) <= 0:
|
||
return "Have no chart data!"
|
||
for chart in charts:
|
||
param = {}
|
||
sql = chart.get("sql", "")
|
||
param["sql"] = sql
|
||
param["type"] = chart.get("display_type", "response_table")
|
||
param["title"] = chart.get("title", "")
|
||
param["describe"] = chart.get("thought", "")
|
||
try:
|
||
df = sql_2_df_func(sql)
|
||
param["data"] = json.loads(
|
||
df.to_json(orient="records", date_format="iso", date_unit="s")
|
||
)
|
||
except Exception as e:
|
||
param["data"] = []
|
||
param["err_msg"] = str(e)
|
||
chart_items.append(param)
|
||
|
||
dashboard_param = {
|
||
"data": chart_items,
|
||
"chart_count": len(chart_items),
|
||
"title": title,
|
||
"display_strategy": "default",
|
||
"style": "default",
|
||
}
|
||
view_json_str = json.dumps(
|
||
dashboard_param, default=serialize, ensure_ascii=False
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error("parse_view_response error!" + str(e))
|
||
return f"```error\nReport rendering exception!{str(e)}\n```"
|
||
|
||
result = f"```vis-dashboard\n{view_json_str}\n```"
|
||
if err_msg:
|
||
return (
|
||
f"""\\n <span style=\"color:red\">ERROR!</span>{err_msg} \n {result}"""
|
||
)
|
||
else:
|
||
return result
|