"""Module for managing commands and command plugins."""

import json
import logging
import xml.etree.ElementTree as ET
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.agent.core.schema import Status
from dbgpt.util.json_utils import serialize
from dbgpt.util.string_utils import extract_content, extract_content_open_ending

logger = logging.getLogger(__name__)


def _now_ms() -> float:
    """Return the current local time as a millisecond timestamp."""
    return datetime.now().timestamp() * 1000


class PluginStatus(BaseModel):
    """A class representing the status of a plugin."""

    name: str
    # Keys of the extract_content map (api_index) at which this call was seen.
    location: List[int]
    args: dict
    status: Union[Status, str] = Status.TODO.value
    logo_url: Optional[str] = None
    api_result: Optional[str] = None
    err_msg: Optional[str] = None
    # Creation time in milliseconds. A default_factory is required: a plain
    # default expression is evaluated once at import time, which gave every
    # instance the same (module-load) start time.
    start_time: float = Field(default_factory=_now_ms)
    # Completion time in milliseconds; the run/display helpers assign floats.
    end_time: Optional[float] = None
    df: Any = None


class ApiCall:
    """A class representing an API call."""

    # NOTE(review): these delimiters are empty in the reviewed copy, yet
    # update_from_context passes the text between them to ET.fromstring, which
    # implies they should be XML tags. With "" every str.count() check below is
    # always positive. Confirm the intended values against upstream.
    agent_prefix = ""
    agent_end = ""
    name_prefix = ""
    name_end = ""

    def __init__(
        self,
        plugin_generator: Any = None,
        display_registry: Any = None,
        backend_rendering: bool = False,
    ):
        """Create a new ApiCall object.

        Args:
            plugin_generator: Stored on the instance; not read by this class.
            display_registry: Object exposing ``is_valid_command`` and ``call``;
                used by ``run_display_sql`` to render query results.
            backend_rendering: When True, ``to_view_antv_vis`` renders the
                result table as HTML instead of a ``<chart-view>`` element.
        """
        # Plugin-call XML text (newlines removed) -> its execution status.
        self.plugin_status_map: Dict[str, PluginStatus] = {}
        self.plugin_generator = plugin_generator
        self.display_registry = display_registry
        # Used to report elapsed time while a call is still being generated.
        self.start_time = _now_ms()
        self.backend_rendering: bool = backend_rendering

    def _is_need_wait_plugin_call(self, api_call_context):
        """Return True if the text contains, or may be starting, a plugin call.

        A call is present when ``agent_prefix`` occurs in the text. While the
        text is still streaming it may instead end with only the first few
        characters of ``agent_prefix``; that case also returns True.
        """
        if api_call_context.count(self.agent_prefix) > 0:
            return True
        # Check every proper, non-empty prefix of the opening delimiter. (The
        # previous loop started at length 0, comparing the tail against "",
        # which made an empty text count as "waiting".)
        prefix = self.agent_prefix
        return any(
            api_call_context.endswith(prefix[:size]) for size in range(1, len(prefix))
        )

    def check_last_plugin_call_ready(self, all_context):
        """Check if the last plugin call is ready.

        Returns True when at least one call was opened and the number of
        opening delimiters equals the number of closing delimiters.
        """
        start_agent_count = all_context.count(self.agent_prefix)
        end_agent_count = all_context.count(self.agent_end)
        return start_agent_count > 0 and start_agent_count == end_agent_count

    def _deal_error_md_tags(self, all_context, api_context, include_end: bool = True):
        """Unwrap ``api_context`` from any Markdown code fence around it.

        Each fence tag is tried with no separator, with newlines, and with
        spaces around the call, and finally as an opening fence alone.

        Args:
            all_context: Full text to clean.
            api_context: Plugin-call text that may be fenced.
            include_end: When False, no closing fence is expected after the call.

        Returns:
            ``all_context`` with matching fences removed.
        """
        error_md_tags = [
            "```",
            "```python",
            "```xml",
            "```json",
            "```markdown",
            "```sql",
        ]
        if not include_end:
            md_tag_end = ""
        else:
            md_tag_end = "```"
        for tag in error_md_tags:
            all_context = all_context.replace(
                tag + api_context + md_tag_end, api_context
            )
            all_context = all_context.replace(
                tag + "\n" + api_context + "\n" + md_tag_end, api_context
            )
            all_context = all_context.replace(
                tag + " " + api_context + " " + md_tag_end, api_context
            )
            all_context = all_context.replace(tag + api_context, api_context)
        return all_context

    def api_view_context(self, all_context: str, display_mode: bool = False):
        """Replace every plugin call in ``all_context`` with its rendered view.

        Args:
            all_context: LLM output text.
            display_mode: When True, known calls are rendered with
                ``to_view_antv_vis`` (failed ones prefixed with the error);
                otherwise with ``to_view_text``.

        Returns:
            The text with calls replaced. Calls without a recorded status are
            replaced by a waiting message showing seconds since construction.
        """
        call_context_map = extract_content_open_ending(
            all_context, self.agent_prefix, self.agent_end, True
        )
        # NOTE(review): the single-quoted f-strings below suggest they once held
        # markup containing double quotes that is missing in this copy - confirm.
        for api_context in call_context_map.values():
            api_status = self.plugin_status_map.get(api_context)
            if api_status is not None:
                if display_mode:
                    all_context = self._deal_error_md_tags(all_context, api_context)
                    if Status.FAILED.value == api_status.status:
                        err_msg = api_status.err_msg
                        all_context = all_context.replace(
                            api_context,
                            f'\nError:{err_msg}\n'
                            + self.to_view_antv_vis(api_status),
                        )
                    else:
                        all_context = all_context.replace(
                            api_context, self.to_view_antv_vis(api_status)
                        )
                else:
                    all_context = self._deal_error_md_tags(
                        all_context, api_context, False
                    )
                    all_context = all_context.replace(
                        api_context, self.to_view_text(api_status)
                    )
            else:
                # not ready api call view change
                now_time = _now_ms()
                cost = (now_time - self.start_time) / 1000
                cost_str = "{:.2f}".format(cost)
                all_context = self._deal_error_md_tags(all_context, api_context)
                all_context = all_context.replace(
                    api_context,
                    f'\nWaiting...{cost_str}S\n',
                )
        return all_context

    def update_from_context(self, all_context):
        """Record every complete plugin call in ``all_context``.

        New calls get a TODO ``PluginStatus``; for calls already in
        ``plugin_status_map`` the current index is appended to ``location``.

        Raises:
            xml.etree.ElementTree.ParseError: If a call is not well-formed XML.
            ValueError: If a call has no ``name`` element or the name is empty.
        """
        api_context_map: Dict[int, str] = extract_content(
            all_context, self.agent_prefix, self.agent_end, True
        )
        for api_index, api_context in api_context_map.items():
            # Drop both escaped ("\\n") and real newlines before XML parsing.
            api_context = api_context.replace("\\n", "").replace("\n", "")
            # NOTE(review): entries are keyed by this newline-stripped text,
            # while api_view_context looks them up with the text returned by
            # extract_content_open_ending. Confirm both agree for multi-line calls.
            api_call_element = ET.fromstring(api_context)

            name_element = api_call_element.find("name")
            if name_element is None or name_element.text is None:
                raise ValueError(f"Plugin call has no name: {api_context}")
            api_name = name_element.text.replace("[", "").replace("]", "")

            api_args = {}
            args_element = api_call_element.find("args")
            if args_element is not None:
                for child_element in args_element.iter():
                    # iter() yields args_element itself first; skip it so the
                    # container is not recorded as a bogus "args" argument.
                    if child_element is args_element:
                        continue
                    api_args[child_element.tag] = child_element.text

            api_status = self.plugin_status_map.get(api_context)
            if api_status is None:
                api_status = PluginStatus(
                    name=api_name, location=[api_index], args=api_args
                )
                self.plugin_status_map[api_context] = api_status
            else:
                api_status.location.append(api_index)

    def _to_view_param_str(self, api_status):
        """Serialize a status to JSON for ``to_view_text``.

        ``status`` is always included; name, logo, err_msg and result only when
        truthy.
        """
        param = {}
        if api_status.name:
            param["name"] = api_status.name
        param["status"] = api_status.status
        if api_status.logo_url:
            param["logo"] = api_status.logo_url
        if api_status.err_msg:
            param["err_msg"] = api_status.err_msg
        if api_status.api_result:
            param["result"] = api_status.api_result
        return json.dumps(param, default=serialize, ensure_ascii=False)

    def to_view_text(self, api_status: PluginStatus):
        """Return a ``<dbgpt-view>`` element whose text is the status JSON."""
        api_call_element = ET.Element("dbgpt-view")
        api_call_element.text = self._to_view_param_str(api_status)
        result = ET.tostring(api_call_element, encoding="utf-8")
        return result.decode("utf-8")

    def to_view_antv_vis(self, api_status: PluginStatus):
        """Return the vis content.

        With ``backend_rendering`` the result table is rendered to HTML here;
        otherwise a ``<chart-view>`` element is returned whose ``content``
        attribute holds the JSON from ``_to_antv_vis_param``.
        """
        if self.backend_rendering:
            if api_status.df is None:
                # A failed call may never have produced a result; render an
                # empty table rather than raising AttributeError on None.
                table_str = ""
            else:
                # NOTE(review): escape=False inserts cell values unescaped; if
                # query results can hold user-controlled text this permits
                # markup injection.
                html_table = api_status.df.to_html(
                    index=False, escape=False, sparsify=False
                )
                # Collapse whitespace runs to single spaces. Joining with ""
                # glued tag attributes together ('<table border="1" ...>'
                # became '<tableborder="1"...>') and merged words in cells.
                table_str = " ".join(html_table.split())
            sql = api_status.args.get("sql", "")
            # NOTE(review): this literal spanned several physical lines in the
            # reviewed copy (likely lost markup); the breaks are kept as "\n".
            # Confirm the intended HTML against upstream.
            html = (
                f" \n\n[SQL]{sql}\n"
                f"{table_str}\n\n "
            )
            return html
        else:
            api_call_element = ET.Element("chart-view")
            api_call_element.attrib["content"] = self._to_antv_vis_param(api_status)
            api_call_element.text = "\n"
            result = ET.tostring(api_call_element, encoding="utf-8")
            return result.decode("utf-8")

    def _to_antv_vis_param(self, api_status: PluginStatus):
        """Build the chart JSON (type, sql, data) for a ``<chart-view>``."""
        param = {}
        if api_status.name:
            param["type"] = api_status.name
        if api_status.args:
            # .get: a call without an "sql" argument must not raise KeyError
            # while rendering.
            param["sql"] = api_status.args.get("sql")
        data: Any = []
        if api_status.api_result:
            data = api_status.api_result
        param["data"] = data
        return json.dumps(param, ensure_ascii=False)

    def run_display_sql(self, llm_text, sql_run_func):
        """Run the API calls for displaying SQL data.

        Once all plugin calls in ``llm_text`` are closed, each TODO call's
        ``sql`` argument is executed with ``sql_run_func`` and the result is
        passed as ``df`` to ``display_registry.call`` (using the call name when
        it is a valid command, otherwise "response_table").

        Returns:
            ``llm_text`` rendered by ``api_view_context`` in display mode.
        """
        if self._is_need_wait_plugin_call(
            llm_text
        ) and self.check_last_plugin_call_ready(llm_text):
            # wait api call generate complete
            self.update_from_context(llm_text)
            for value in self.plugin_status_map.values():
                if value.status != Status.TODO.value:
                    continue
                value.status = Status.RUNNING.value
                logger.info(f"sql display execution:{value.name},{value.args}")
                try:
                    sql = value.args["sql"]
                    # NOTE(review): an empty "sql" marks the call COMPLETE with
                    # no result, whereas display_sql_llmvis marks it FAILED.
                    if sql:
                        param = {
                            "df": sql_run_func(sql),
                        }
                        value.df = param["df"]
                        if self.display_registry.is_valid_command(value.name):
                            value.api_result = self.display_registry.call(
                                value.name, **param
                            )
                        else:
                            value.api_result = self.display_registry.call(
                                "response_table", **param
                            )
                    value.status = Status.COMPLETE.value
                except Exception as e:
                    value.status = Status.FAILED.value
                    value.err_msg = str(e)
                value.end_time = _now_ms()
        return self.api_view_context(llm_text, True)

    def display_sql_llmvis(self, llm_text, sql_run_func):
        """Render charts using the Antv standard protocol.

        Args:
            llm_text: LLM response text
            sql_run_func: sql run function

        Returns:
            ChartView protocol text

        Raises:
            ValueError: If parsing the plugin calls fails.
        """
        try:
            if self._is_need_wait_plugin_call(
                llm_text
            ) and self.check_last_plugin_call_ready(llm_text):
                # wait api call generate complete
                self.update_from_context(llm_text)
                for value in self.plugin_status_map.values():
                    if value.status != Status.TODO.value:
                        continue
                    value.status = Status.RUNNING.value
                    logger.info(f"SQL execution:{value.name},{value.args}")
                    try:
                        sql = value.args["sql"]
                        if sql:
                            data_df = sql_run_func(sql)
                            value.df = data_df
                            value.api_result = json.loads(
                                data_df.to_json(
                                    orient="records",
                                    date_format="iso",
                                    date_unit="s",
                                )
                            )
                            value.status = Status.COMPLETE.value
                        else:
                            value.status = Status.FAILED.value
                            value.err_msg = "No executable sql!"
                    except Exception as e:
                        logger.error(f"data prepare exception!{str(e)}")
                        value.status = Status.FAILED.value
                        value.err_msg = str(e)
                    value.end_time = _now_ms()
        except Exception as e:
            # The old call logger.error("...", e) passed an argument with no
            # %-placeholder, which makes logging itself report a format error.
            logger.exception("Api parsing exception: %s", e)
            raise ValueError("Api parsing exception," + str(e)) from e

        return self.api_view_context(llm_text, True)

    def display_only_sql_vis(self, chart: dict, sql_2_df_func):
        """Display the chart using the vis standard protocol.

        Returns:
            None when the chart has no SQL; otherwise a ```vis-chart``` block,
            prefixed with the error text if the query or serialization failed.
        """
        err_msg = None
        sql = chart.get("sql", None)
        # Check before executing: the query used to run first, so an empty or
        # missing statement was still sent to sql_2_df_func.
        if not sql:
            return None
        try:
            df = sql_2_df_func(sql)
            param = {
                "sql": sql,
                "type": chart.get("display_type", "response_table"),
                "title": chart.get("title", ""),
                "describe": chart.get("thought", ""),
                "data": json.loads(
                    df.to_json(orient="records", date_format="iso", date_unit="s")
                ),
            }
            view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
        except Exception as e:
            logger.error("parse_view_response error!" + str(e))
            err_param = {"sql": f"{sql}", "type": "response_table", "data": []}
            err_msg = str(e)
            view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)

        result = f"```vis-chart\n{view_json_str}\n```"
        if err_msg:
            return f"""ERROR!{err_msg} \n {result}"""
        return result

    def display_dashboard_vis(
        self, charts: List[dict], sql_2_df_func, title: Optional[str] = None
    ):
        """Display the dashboard using the vis standard protocol.

        A failing chart query yields an empty ``data`` list plus ``err_msg`` for
        that chart only; any other failure returns an ```error``` block.
        """
        chart_items = []
        try:
            if not charts:
                return "Have no chart data!"
            for chart in charts:
                param = {}
                sql = chart.get("sql", "")
                param["sql"] = sql
                param["type"] = chart.get("display_type", "response_table")
                param["title"] = chart.get("title", "")
                param["describe"] = chart.get("thought", "")
                try:
                    df = sql_2_df_func(sql)
                    param["data"] = json.loads(
                        df.to_json(orient="records", date_format="iso", date_unit="s")
                    )
                except Exception as e:
                    param["data"] = []
                    param["err_msg"] = str(e)
                chart_items.append(param)

            dashboard_param = {
                "data": chart_items,
                "chart_count": len(chart_items),
                "title": title,
                "display_strategy": "default",
                "style": "default",
            }
            view_json_str = json.dumps(
                dashboard_param, default=serialize, ensure_ascii=False
            )
        except Exception as e:
            logger.error("parse_view_response error!" + str(e))
            return f"```error\nReport rendering exception!{str(e)}\n```"

        return f"```vis-dashboard\n{view_json_str}\n```"