mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 02:51:07 +00:00
refactor(agent): Refactor resource of agents (#1518)
This commit is contained in:
397
dbgpt/agent/util/api_call.py
Normal file
397
dbgpt/agent/util/api_call.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""Module for managing commands and command plugins."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.agent.core.schema import Status
|
||||
from dbgpt.util.json_utils import serialize
|
||||
from dbgpt.util.string_utils import extract_content, extract_content_open_ending
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginStatus(BaseModel):
|
||||
"""A class representing the status of a plugin."""
|
||||
|
||||
name: str
|
||||
location: List[int]
|
||||
args: dict
|
||||
status: Union[Status, str] = Status.TODO.value
|
||||
logo_url: Optional[str] = None
|
||||
api_result: Optional[str] = None
|
||||
err_msg: Optional[str] = None
|
||||
start_time: float = datetime.now().timestamp() * 1000
|
||||
end_time: Optional[str] = None
|
||||
|
||||
df: Any = None
|
||||
|
||||
|
||||
class ApiCall:
|
||||
"""A class representing an API call."""
|
||||
|
||||
agent_prefix = "<api-call>"
|
||||
agent_end = "</api-call>"
|
||||
name_prefix = "<name>"
|
||||
name_end = "</name>"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
plugin_generator: Any = None,
|
||||
display_registry: Any = None,
|
||||
backend_rendering: bool = False,
|
||||
):
|
||||
"""Create a new ApiCall object."""
|
||||
self.plugin_status_map: Dict[str, PluginStatus] = {}
|
||||
|
||||
self.plugin_generator = plugin_generator
|
||||
self.display_registry = display_registry
|
||||
self.start_time = datetime.now().timestamp() * 1000
|
||||
self.backend_rendering: bool = backend_rendering
|
||||
|
||||
def _is_need_wait_plugin_call(self, api_call_context):
|
||||
start_agent_count = api_call_context.count(self.agent_prefix)
|
||||
|
||||
if start_agent_count > 0:
|
||||
return True
|
||||
else:
|
||||
# Check the new character at the end
|
||||
check_len = len(self.agent_prefix)
|
||||
last_text = api_call_context[-check_len:]
|
||||
for i in range(check_len):
|
||||
text_tmp = last_text[-i:]
|
||||
prefix_tmp = self.agent_prefix[:i]
|
||||
if text_tmp == prefix_tmp:
|
||||
return True
|
||||
else:
|
||||
i += 1
|
||||
return False
|
||||
|
||||
def check_last_plugin_call_ready(self, all_context):
|
||||
"""Check if the last plugin call is ready."""
|
||||
start_agent_count = all_context.count(self.agent_prefix)
|
||||
end_agent_count = all_context.count(self.agent_end)
|
||||
|
||||
if start_agent_count > 0 and start_agent_count == end_agent_count:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _deal_error_md_tags(self, all_context, api_context, include_end: bool = True):
|
||||
error_md_tags = [
|
||||
"```",
|
||||
"```python",
|
||||
"```xml",
|
||||
"```json",
|
||||
"```markdown",
|
||||
"```sql",
|
||||
]
|
||||
if not include_end:
|
||||
md_tag_end = ""
|
||||
else:
|
||||
md_tag_end = "```"
|
||||
for tag in error_md_tags:
|
||||
all_context = all_context.replace(
|
||||
tag + api_context + md_tag_end, api_context
|
||||
)
|
||||
all_context = all_context.replace(
|
||||
tag + "\n" + api_context + "\n" + md_tag_end, api_context
|
||||
)
|
||||
all_context = all_context.replace(
|
||||
tag + " " + api_context + " " + md_tag_end, api_context
|
||||
)
|
||||
all_context = all_context.replace(tag + api_context, api_context)
|
||||
return all_context
|
||||
|
||||
def api_view_context(self, all_context: str, display_mode: bool = False):
|
||||
"""Return the view content."""
|
||||
call_context_map = extract_content_open_ending(
|
||||
all_context, self.agent_prefix, self.agent_end, True
|
||||
)
|
||||
for api_index, api_context in call_context_map.items():
|
||||
api_status = self.plugin_status_map.get(api_context)
|
||||
if api_status is not None:
|
||||
if display_mode:
|
||||
all_context = self._deal_error_md_tags(all_context, api_context)
|
||||
if Status.FAILED.value == api_status.status:
|
||||
err_msg = api_status.err_msg
|
||||
all_context = all_context.replace(
|
||||
api_context,
|
||||
f'\n<span style="color:red">Error:</span>{err_msg}\n'
|
||||
+ self.to_view_antv_vis(api_status),
|
||||
)
|
||||
else:
|
||||
all_context = all_context.replace(
|
||||
api_context, self.to_view_antv_vis(api_status)
|
||||
)
|
||||
|
||||
else:
|
||||
all_context = self._deal_error_md_tags(
|
||||
all_context, api_context, False
|
||||
)
|
||||
all_context = all_context.replace(
|
||||
api_context, self.to_view_text(api_status)
|
||||
)
|
||||
|
||||
else:
|
||||
# not ready api call view change
|
||||
now_time = datetime.now().timestamp() * 1000
|
||||
cost = (now_time - self.start_time) / 1000
|
||||
cost_str = "{:.2f}".format(cost)
|
||||
all_context = self._deal_error_md_tags(all_context, api_context)
|
||||
|
||||
all_context = all_context.replace(
|
||||
api_context,
|
||||
f'\n<span style="color:green">Waiting...{cost_str}S</span>\n',
|
||||
)
|
||||
|
||||
return all_context
|
||||
|
||||
def update_from_context(self, all_context):
|
||||
"""Modify the plugin status map based on the context."""
|
||||
api_context_map: Dict[int, str] = extract_content(
|
||||
all_context, self.agent_prefix, self.agent_end, True
|
||||
)
|
||||
for api_index, api_context in api_context_map.items():
|
||||
api_context = api_context.replace("\\n", "").replace("\n", "")
|
||||
api_call_element = ET.fromstring(api_context)
|
||||
api_name = api_call_element.find("name").text
|
||||
if api_name.find("[") >= 0 or api_name.find("]") >= 0:
|
||||
api_name = api_name.replace("[", "").replace("]", "")
|
||||
api_args = {}
|
||||
args_elements = api_call_element.find("args")
|
||||
for child_element in args_elements.iter():
|
||||
api_args[child_element.tag] = child_element.text
|
||||
|
||||
api_status = self.plugin_status_map.get(api_context)
|
||||
if api_status is None:
|
||||
api_status = PluginStatus(
|
||||
name=api_name, location=[api_index], args=api_args
|
||||
)
|
||||
self.plugin_status_map[api_context] = api_status
|
||||
else:
|
||||
api_status.location.append(api_index)
|
||||
|
||||
def _to_view_param_str(self, api_status):
|
||||
param = {}
|
||||
if api_status.name:
|
||||
param["name"] = api_status.name
|
||||
param["status"] = api_status.status
|
||||
if api_status.logo_url:
|
||||
param["logo"] = api_status.logo_url
|
||||
|
||||
if api_status.err_msg:
|
||||
param["err_msg"] = api_status.err_msg
|
||||
|
||||
if api_status.api_result:
|
||||
param["result"] = api_status.api_result
|
||||
|
||||
return json.dumps(param, default=serialize, ensure_ascii=False)
|
||||
|
||||
def to_view_text(self, api_status: PluginStatus):
|
||||
"""Return the view content."""
|
||||
api_call_element = ET.Element("dbgpt-view")
|
||||
api_call_element.text = self._to_view_param_str(api_status)
|
||||
result = ET.tostring(api_call_element, encoding="utf-8")
|
||||
return result.decode("utf-8")
|
||||
|
||||
def to_view_antv_vis(self, api_status: PluginStatus):
|
||||
"""Return the vis content."""
|
||||
if self.backend_rendering:
|
||||
html_table = api_status.df.to_html(
|
||||
index=False, escape=False, sparsify=False
|
||||
)
|
||||
table_str = "".join(html_table.split())
|
||||
table_str = table_str.replace("\n", " ")
|
||||
sql = api_status.args["sql"]
|
||||
html = (
|
||||
f' \n<div><b>[SQL]{sql}</b></div><div class="w-full overflow-auto">'
|
||||
f"{table_str}</div>\n "
|
||||
)
|
||||
return html
|
||||
else:
|
||||
api_call_element = ET.Element("chart-view")
|
||||
api_call_element.attrib["content"] = self._to_antv_vis_param(api_status)
|
||||
api_call_element.text = "\n"
|
||||
result = ET.tostring(api_call_element, encoding="utf-8")
|
||||
return result.decode("utf-8")
|
||||
|
||||
def _to_antv_vis_param(self, api_status: PluginStatus):
|
||||
param = {}
|
||||
if api_status.name:
|
||||
param["type"] = api_status.name
|
||||
if api_status.args:
|
||||
param["sql"] = api_status.args["sql"]
|
||||
|
||||
data: Any = []
|
||||
if api_status.api_result:
|
||||
data = api_status.api_result
|
||||
param["data"] = data
|
||||
return json.dumps(param, ensure_ascii=False)
|
||||
|
||||
def run_display_sql(self, llm_text, sql_run_func):
|
||||
"""Run the API calls for displaying SQL data."""
|
||||
if self._is_need_wait_plugin_call(
|
||||
llm_text
|
||||
) and self.check_last_plugin_call_ready(llm_text):
|
||||
# wait api call generate complete
|
||||
self.update_from_context(llm_text)
|
||||
for key, value in self.plugin_status_map.items():
|
||||
if value.status == Status.TODO.value:
|
||||
value.status = Status.RUNNING.value
|
||||
logger.info(f"sql display execution:{value.name},{value.args}")
|
||||
try:
|
||||
sql = value.args["sql"]
|
||||
if sql:
|
||||
param = {
|
||||
"df": sql_run_func(sql),
|
||||
}
|
||||
value.df = param["df"]
|
||||
if self.display_registry.is_valid_command(value.name):
|
||||
value.api_result = self.display_registry.call(
|
||||
value.name, **param
|
||||
)
|
||||
else:
|
||||
value.api_result = self.display_registry.call(
|
||||
"response_table", **param
|
||||
)
|
||||
|
||||
value.status = Status.COMPLETE.value
|
||||
except Exception as e:
|
||||
value.status = Status.FAILED.value
|
||||
value.err_msg = str(e)
|
||||
value.end_time = datetime.now().timestamp() * 1000
|
||||
return self.api_view_context(llm_text, True)
|
||||
|
||||
def display_sql_llmvis(self, llm_text, sql_run_func):
|
||||
"""Render charts using the Antv standard protocol.
|
||||
|
||||
Args:
|
||||
llm_text: LLM response text
|
||||
sql_run_func: sql run function
|
||||
|
||||
Returns:
|
||||
ChartView protocol text
|
||||
"""
|
||||
try:
|
||||
if self._is_need_wait_plugin_call(
|
||||
llm_text
|
||||
) and self.check_last_plugin_call_ready(llm_text):
|
||||
# wait api call generate complete
|
||||
self.update_from_context(llm_text)
|
||||
for key, value in self.plugin_status_map.items():
|
||||
if value.status == Status.TODO.value:
|
||||
value.status = Status.RUNNING.value
|
||||
logger.info(f"SQL execution:{value.name},{value.args}")
|
||||
try:
|
||||
sql = value.args["sql"]
|
||||
if sql is not None and len(sql) > 0:
|
||||
data_df = sql_run_func(sql)
|
||||
value.df = data_df
|
||||
value.api_result = json.loads(
|
||||
data_df.to_json(
|
||||
orient="records",
|
||||
date_format="iso",
|
||||
date_unit="s",
|
||||
)
|
||||
)
|
||||
value.status = Status.COMPLETE.value
|
||||
else:
|
||||
value.status = Status.FAILED.value
|
||||
value.err_msg = "No executable sql!"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"data prepare exception!{str(e)}")
|
||||
value.status = Status.FAILED.value
|
||||
value.err_msg = str(e)
|
||||
value.end_time = datetime.now().timestamp() * 1000
|
||||
except Exception as e:
|
||||
logger.error("Api parsing exception", e)
|
||||
raise ValueError("Api parsing exception," + str(e))
|
||||
|
||||
return self.api_view_context(llm_text, True)
|
||||
|
||||
def display_only_sql_vis(self, chart: dict, sql_2_df_func):
|
||||
"""Display the chart using the vis standard protocol."""
|
||||
err_msg = None
|
||||
sql = chart.get("sql", None)
|
||||
try:
|
||||
param = {}
|
||||
df = sql_2_df_func(sql)
|
||||
if not sql or len(sql) <= 0:
|
||||
return None
|
||||
|
||||
param["sql"] = sql
|
||||
param["type"] = chart.get("display_type", "response_table")
|
||||
param["title"] = chart.get("title", "")
|
||||
param["describe"] = chart.get("thought", "")
|
||||
|
||||
param["data"] = json.loads(
|
||||
df.to_json(orient="records", date_format="iso", date_unit="s")
|
||||
)
|
||||
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error("parse_view_response error!" + str(e))
|
||||
err_param = {"sql": f"{sql}", "type": "response_table", "data": []}
|
||||
err_msg = str(e)
|
||||
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
|
||||
|
||||
# api_call_element.text = view_json_str
|
||||
result = f"```vis-chart\n{view_json_str}\n```"
|
||||
if err_msg:
|
||||
return f"""<span style=\"color:red\">ERROR!</span>{err_msg} \n {result}"""
|
||||
else:
|
||||
return result
|
||||
|
||||
def display_dashboard_vis(
|
||||
self, charts: List[dict], sql_2_df_func, title: Optional[str] = None
|
||||
):
|
||||
"""Display the dashboard using the vis standard protocol."""
|
||||
err_msg = None
|
||||
view_json_str = None
|
||||
|
||||
chart_items = []
|
||||
try:
|
||||
if not charts or len(charts) <= 0:
|
||||
return "Have no chart data!"
|
||||
for chart in charts:
|
||||
param = {}
|
||||
sql = chart.get("sql", "")
|
||||
param["sql"] = sql
|
||||
param["type"] = chart.get("display_type", "response_table")
|
||||
param["title"] = chart.get("title", "")
|
||||
param["describe"] = chart.get("thought", "")
|
||||
try:
|
||||
df = sql_2_df_func(sql)
|
||||
param["data"] = json.loads(
|
||||
df.to_json(orient="records", date_format="iso", date_unit="s")
|
||||
)
|
||||
except Exception as e:
|
||||
param["data"] = []
|
||||
param["err_msg"] = str(e)
|
||||
chart_items.append(param)
|
||||
|
||||
dashboard_param = {
|
||||
"data": chart_items,
|
||||
"chart_count": len(chart_items),
|
||||
"title": title,
|
||||
"display_strategy": "default",
|
||||
"style": "default",
|
||||
}
|
||||
view_json_str = json.dumps(
|
||||
dashboard_param, default=serialize, ensure_ascii=False
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("parse_view_response error!" + str(e))
|
||||
return f"```error\nReport rendering exception!{str(e)}\n```"
|
||||
|
||||
result = f"```vis-dashboard\n{view_json_str}\n```"
|
||||
if err_msg:
|
||||
return (
|
||||
f"""\\n <span style=\"color:red\">ERROR!</span>{err_msg} \n {result}"""
|
||||
)
|
||||
else:
|
||||
return result
|
Reference in New Issue
Block a user