"""Module for managing commands and command plugins."""
import html
import json
import logging
import xml.etree.ElementTree as ET
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.agent.core.schema import Status
from dbgpt.util.json_utils import serialize
from dbgpt.util.string_utils import extract_content, extract_content_open_ending
# Logger for this module, named after the module's dotted path.
logger: logging.Logger = logging.getLogger(name=__name__)
class PluginStatus(BaseModel):
    """Tracked state of one API (plugin) call found in LLM output."""

    # Call name taken from the call's <name> element.
    name: str
    # Indices (keys returned by extract_content) at which this call appeared.
    location: List[int]
    # Argument values keyed by argument tag (e.g. "sql").
    args: dict
    # Stored as a Status ``.value`` string by ApiCall.
    status: Union[Status, str] = Status.TODO.value
    logo_url: Optional[str] = None
    # Execution result: ApiCall stores a list of row dicts or whatever the
    # display registry returns, so this cannot be narrowed to ``str``.
    api_result: Any = None
    err_msg: Optional[str] = None
    # Creation time in epoch milliseconds. A factory is required: a plain
    # default would be evaluated once at import and shared by every instance.
    start_time: float = Field(
        default_factory=lambda: datetime.now().timestamp() * 1000
    )
    # Completion time in epoch milliseconds (ApiCall assigns a float).
    end_time: Optional[float] = None
    # Raw query result; presumably a pandas DataFrame — confirm with callers.
    df: Any = None
class ApiCall:
    """Parse API-call spans out of LLM output, run them and render the results.

    The LLM text is scanned for spans delimited by ``agent_prefix`` and
    ``agent_end``. Each span is parsed as XML with a ``<name>`` and an
    ``<args>`` element and tracked as a ``PluginStatus``. SQL calls are
    executed, and each span is finally replaced in the text by a rendered view.
    """

    # Span delimiters. They default to empty strings here; presumably a
    # subclass or the caller assigns real values — confirm.
    agent_prefix = ""
    agent_end = ""
    name_prefix = ""
    name_end = ""

    # Markdown fence openers an LLM may wrap a call in. They are stripped so the
    # rendered view is not displayed as a code block. Order matters: the bare
    # fence is tried first.
    _ERROR_MD_TAGS = (
        "```",
        "```python",
        "```xml",
        "```json",
        "```markdown",
        "```sql",
    )

    def __init__(
        self,
        plugin_generator: Any = None,
        display_registry: Any = None,
        backend_rendering: bool = False,
    ):
        """Create a new ApiCall object.

        Args:
            plugin_generator: Stored on the instance; not read by this class.
            display_registry: Object providing ``is_valid_command`` and
                ``call``; used by ``run_display_sql``.
            backend_rendering: If True, ``to_view_antv_vis`` renders the result
                DataFrame as an HTML table instead of a ``<chart-view>``.
        """
        # Raw API-call text -> its tracked status.
        self.plugin_status_map: Dict[str, PluginStatus] = {}
        self.plugin_generator = plugin_generator
        self.display_registry = display_registry
        # Creation time in epoch ms; drives the "Waiting..." timer.
        self.start_time = datetime.now().timestamp() * 1000
        self.backend_rendering: bool = backend_rendering

    def _is_need_wait_plugin_call(self, api_call_context):
        """Return True if the text contains, or may be starting, an API call.

        The result is True when ``agent_prefix`` occurs anywhere in the text. It
        is also True when the text ends with a proper prefix of ``agent_prefix``,
        for example a streamed response cut off inside the opening delimiter.
        """
        if self.agent_prefix in api_call_context:
            return True
        return any(
            api_call_context.endswith(self.agent_prefix[:size])
            for size in range(1, len(self.agent_prefix))
        )

    def check_last_plugin_call_ready(self, all_context):
        """Check if the last plugin call is ready.

        Returns True when at least one opening delimiter is present and the
        opening and closing delimiter counts are equal.
        """
        start_agent_count = all_context.count(self.agent_prefix)
        end_agent_count = all_context.count(self.agent_end)
        return start_agent_count > 0 and start_agent_count == end_agent_count

    def _deal_error_md_tags(self, all_context, api_context, include_end: bool = True):
        """Remove markdown code fences wrapped around ``api_context``.

        Args:
            all_context: Full text to clean.
            api_context: API-call text whose surrounding fences are removed.
            include_end: If False, the closing fence is not part of the match
                patterns, so a trailing fence is left in the text.
        """
        md_tag_end = "```" if include_end else ""
        for tag in self._ERROR_MD_TAGS:
            all_context = all_context.replace(
                tag + api_context + md_tag_end, api_context
            )
            all_context = all_context.replace(
                tag + "\n" + api_context + "\n" + md_tag_end, api_context
            )
            all_context = all_context.replace(
                tag + " " + api_context + " " + md_tag_end, api_context
            )
            all_context = all_context.replace(tag + api_context, api_context)
        return all_context

    def api_view_context(self, all_context: str, display_mode: bool = False):
        """Return ``all_context`` with each API call replaced by its view.

        Args:
            all_context: LLM output text.
            display_mode: If True, tracked calls are rendered with
                ``to_view_antv_vis`` (failed ones get an error prefix).
                Otherwise they are rendered with ``to_view_text``.

        Calls with no tracked status are replaced by a "Waiting..." note. The
        note shows the seconds elapsed since this object was created.
        """
        call_context_map = extract_content_open_ending(
            all_context, self.agent_prefix, self.agent_end, True
        )
        for api_context in call_context_map.values():
            api_status = self.plugin_status_map.get(api_context)
            if api_status is None:
                # Not parsed yet: show an elapsed-time placeholder.
                now_time = datetime.now().timestamp() * 1000
                cost_str = "{:.2f}".format((now_time - self.start_time) / 1000)
                all_context = self._deal_error_md_tags(all_context, api_context)
                all_context = all_context.replace(
                    api_context, f"\nWaiting...{cost_str}S\n"
                )
            elif display_mode:
                all_context = self._deal_error_md_tags(all_context, api_context)
                view = self.to_view_antv_vis(api_status)
                if Status.FAILED.value == api_status.status:
                    view = f"\nError:{api_status.err_msg}\n" + view
                all_context = all_context.replace(api_context, view)
            else:
                all_context = self._deal_error_md_tags(
                    all_context, api_context, False
                )
                all_context = all_context.replace(
                    api_context, self.to_view_text(api_status)
                )
        return all_context

    @staticmethod
    def _parse_api_call(api_context: str):
        """Parse one API-call XML snippet into a ``(name, args)`` tuple.

        Line breaks are turned into spaces rather than deleted, so SQL tokens on
        adjacent lines stay separated. This covers real newlines and literal
        backslash-n sequences. Square brackets are removed from the name.
        Surrounding whitespace is stripped from the name and from each
        argument value.

        Raises:
            xml.etree.ElementTree.ParseError: If the snippet is not valid XML.
            ValueError: If the snippet has no ``<name>`` element.
        """
        cleaned = api_context.replace("\\n", " ").replace("\n", " ")
        root = ET.fromstring(cleaned)
        name_element = root.find("name")
        if name_element is None:
            raise ValueError(f"API call has no <name> element: {api_context}")
        api_name = (name_element.text or "").replace("[", "").replace("]", "")
        api_name = api_name.strip()
        api_args: Dict[str, Optional[str]] = {}
        args_element = root.find("args")
        if args_element is not None:
            # Iterate direct children only. Element.iter() would also yield
            # <args> itself and any nested descendants.
            for child in args_element:
                api_args[child.tag] = (
                    child.text.strip() if child.text is not None else None
                )
        return api_name, api_args

    def update_from_context(self, all_context):
        """Parse every complete API call in ``all_context`` into the status map.

        A call seen for the first time is added with the default (TODO)
        status. A call whose text is already tracked only gets the new index
        appended to its ``location``.

        Raises:
            xml.etree.ElementTree.ParseError: If a new call is not valid XML.
            ValueError: If a new call has no ``<name>`` element.
        """
        api_context_map: Dict[int, str] = extract_content(
            all_context, self.agent_prefix, self.agent_end, True
        )
        for api_index, api_context in api_context_map.items():
            # Key by the raw text. api_view_context looks calls up by the same
            # raw text, so a normalized key would never match a multi-line call.
            api_status = self.plugin_status_map.get(api_context)
            if api_status is None:
                api_name, api_args = self._parse_api_call(api_context)
                self.plugin_status_map[api_context] = PluginStatus(
                    name=api_name, location=[api_index], args=api_args
                )
            else:
                api_status.location.append(api_index)

    def _to_view_param_str(self, api_status: "PluginStatus"):
        """Serialize the displayable fields of ``api_status`` to JSON.

        Falsy optional fields (name, logo_url, err_msg, api_result) are
        omitted. The status is always included.
        """
        param = {}
        if api_status.name:
            param["name"] = api_status.name
        param["status"] = api_status.status
        if api_status.logo_url:
            param["logo"] = api_status.logo_url
        if api_status.err_msg:
            param["err_msg"] = api_status.err_msg
        if api_status.api_result:
            param["result"] = api_status.api_result
        return json.dumps(param, default=serialize, ensure_ascii=False)

    def to_view_text(self, api_status: "PluginStatus"):
        """Return a ``<dbgpt-view>`` element whose text is the status JSON."""
        api_call_element = ET.Element("dbgpt-view")
        api_call_element.text = self._to_view_param_str(api_status)
        result = ET.tostring(api_call_element, encoding="utf-8")
        return result.decode("utf-8")

    def to_view_antv_vis(self, api_status: "PluginStatus"):
        """Return the vis content for ``api_status``.

        With ``backend_rendering`` enabled, the output is the HTML-escaped SQL
        followed by the result DataFrame rendered as a single-line HTML table.
        The table is empty when there is no DataFrame, e.g. when the call
        failed. Otherwise the output is a ``<chart-view>`` element whose
        ``content`` attribute holds the chart parameters as JSON.
        """
        if self.backend_rendering:
            sql = html.escape(str(api_status.args.get("sql") or ""))
            table_str = ""
            if api_status.df is not None:
                html_table = api_status.df.to_html(
                    index=False, escape=False, sparsify=False
                )
                # Collapse each whitespace run to one space so the table fits on
                # one line without gluing tag attributes or cell words together.
                table_str = " ".join(html_table.split())
            # NOTE(review): the earlier literal spanned raw line breaks (a syntax
            # error). They are encoded as "\n" here; confirm whether wrapper
            # markup around the SQL/table was intended.
            return f" \n\n[SQL]{sql}\n{table_str}\n\n "
        api_call_element = ET.Element("chart-view")
        api_call_element.attrib["content"] = self._to_antv_vis_param(api_status)
        api_call_element.text = "\n"
        result = ET.tostring(api_call_element, encoding="utf-8")
        return result.decode("utf-8")

    def _to_antv_vis_param(self, api_status: "PluginStatus"):
        """Build the chart-view JSON: ``type``, ``sql`` (if present), ``data``."""
        param = {}
        if api_status.name:
            param["type"] = api_status.name
        # A call may lack a "sql" arg (it then failed); don't crash rendering.
        if api_status.args and "sql" in api_status.args:
            param["sql"] = api_status.args["sql"]
        param["data"] = api_status.api_result or []
        return json.dumps(param, default=serialize, ensure_ascii=False)

    @staticmethod
    def _df_to_records(df) -> List[Dict[str, Any]]:
        """Convert a DataFrame to JSON-safe row dicts (ISO dates, second unit).

        ``df`` must provide a pandas-compatible ``to_json``.
        """
        return json.loads(
            df.to_json(orient="records", date_format="iso", date_unit="s")
        )

    def run_display_sql(self, llm_text, sql_run_func):
        """Run pending SQL display calls and return the rendered text.

        Calls are parsed only when the text holds at least one API call and the
        opening and closing delimiter counts match. Each call still in TODO
        state is then executed: its ``sql`` arg is run with ``sql_run_func``.
        The result is passed as ``df`` to the display-registry command named by
        the call, falling back to ``response_table``.

        Args:
            llm_text: LLM response text.
            sql_run_func: Callable that executes a SQL string; its return value
                is stored in ``PluginStatus.df``.

        Returns:
            ``llm_text`` with calls replaced by their views (display mode).
        """
        if self._is_need_wait_plugin_call(
            llm_text
        ) and self.check_last_plugin_call_ready(llm_text):
            # wait api call generate complete
            self.update_from_context(llm_text)
            for value in self.plugin_status_map.values():
                if value.status != Status.TODO.value:
                    continue
                value.status = Status.RUNNING.value
                logger.info("sql display execution:%s,%s", value.name, value.args)
                try:
                    sql = value.args["sql"]
                    if sql:
                        df = sql_run_func(sql)
                        value.df = df
                        command = (
                            value.name
                            if self.display_registry.is_valid_command(value.name)
                            else "response_table"
                        )
                        value.api_result = self.display_registry.call(command, df=df)
                        value.status = Status.COMPLETE.value
                    else:
                        # Same as display_sql_llmvis: an empty query is a
                        # failure, not a call stuck in RUNNING forever.
                        value.status = Status.FAILED.value
                        value.err_msg = "No executable sql!"
                except Exception as e:
                    value.status = Status.FAILED.value
                    value.err_msg = str(e)
                value.end_time = datetime.now().timestamp() * 1000
        return self.api_view_context(llm_text, True)

    def display_sql_llmvis(self, llm_text, sql_run_func):
        """Render charts using the Antv standard protocol.

        Args:
            llm_text: LLM response text
            sql_run_func: sql run function

        Returns:
            ChartView protocol text

        Raises:
            ValueError: If parsing or processing the API calls fails outside
                the per-call execution (e.g. malformed call XML).
        """
        try:
            if self._is_need_wait_plugin_call(
                llm_text
            ) and self.check_last_plugin_call_ready(llm_text):
                # wait api call generate complete
                self.update_from_context(llm_text)
                for value in self.plugin_status_map.values():
                    if value.status != Status.TODO.value:
                        continue
                    value.status = Status.RUNNING.value
                    logger.info("SQL execution:%s,%s", value.name, value.args)
                    try:
                        sql = value.args["sql"]
                        if sql:
                            data_df = sql_run_func(sql)
                            value.df = data_df
                            value.api_result = self._df_to_records(data_df)
                            value.status = Status.COMPLETE.value
                        else:
                            value.status = Status.FAILED.value
                            value.err_msg = "No executable sql!"
                    except Exception as e:
                        logger.error("data prepare exception!%s", e)
                        value.status = Status.FAILED.value
                        value.err_msg = str(e)
                    value.end_time = datetime.now().timestamp() * 1000
        except Exception as e:
            logger.exception("Api parsing exception")
            raise ValueError("Api parsing exception," + str(e)) from e
        return self.api_view_context(llm_text, True)

    def display_only_sql_vis(self, chart: dict, sql_2_df_func):
        """Display the chart using the vis standard protocol.

        Args:
            chart: Chart spec; reads ``sql``, ``display_type``, ``title`` and
                ``thought``.
            sql_2_df_func: Callable that executes a SQL string and returns a
                DataFrame.

        Returns:
            None when the chart has no SQL (nothing is executed). Otherwise a
            fenced ``vis-chart`` JSON block, prefixed with ``ERROR!<message>``
            when execution or serialization failed.
        """
        sql = chart.get("sql", None)
        if not sql:
            return None
        err_msg = None
        try:
            df = sql_2_df_func(sql)
            param = {
                "sql": sql,
                "type": chart.get("display_type", "response_table"),
                "title": chart.get("title", ""),
                "describe": chart.get("thought", ""),
                "data": self._df_to_records(df),
            }
            view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
        except Exception as e:
            logger.error("parse_view_response error!%s", e)
            err_param = {"sql": f"{sql}", "type": "response_table", "data": []}
            err_msg = str(e)
            view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
        result = f"```vis-chart\n{view_json_str}\n```"
        if err_msg:
            return f"""ERROR!{err_msg} \n {result}"""
        return result

    def display_dashboard_vis(
        self, charts: List[dict], sql_2_df_func, title: Optional[str] = None
    ):
        """Display the dashboard using the vis standard protocol.

        Each chart's SQL is executed independently. A failing chart gets empty
        ``data`` plus an ``err_msg`` instead of aborting the dashboard.

        Args:
            charts: Chart specs; each reads ``sql``, ``display_type``,
                ``title`` and ``thought``.
            sql_2_df_func: Callable that executes a SQL string and returns a
                DataFrame.
            title: Dashboard title.

        Returns:
            A fenced ``vis-dashboard`` JSON block. Returns "Have no chart data!"
            when ``charts`` is empty, or a fenced ``error`` block if building
            the dashboard fails.
        """
        if not charts:
            return "Have no chart data!"
        try:
            chart_items = []
            for chart in charts:
                sql = chart.get("sql", "")
                param = {
                    "sql": sql,
                    "type": chart.get("display_type", "response_table"),
                    "title": chart.get("title", ""),
                    "describe": chart.get("thought", ""),
                }
                try:
                    param["data"] = self._df_to_records(sql_2_df_func(sql))
                except Exception as e:
                    param["data"] = []
                    param["err_msg"] = str(e)
                chart_items.append(param)
            dashboard_param = {
                "data": chart_items,
                "chart_count": len(chart_items),
                "title": title,
                "display_strategy": "default",
                "style": "default",
            }
            view_json_str = json.dumps(
                dashboard_param, default=serialize, ensure_ascii=False
            )
        except Exception as e:
            logger.error("parse_view_response error!%s", e)
            return f"```error\nReport rendering exception!{str(e)}\n```"
        return f"```vis-dashboard\n{view_json_str}\n```"