mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 20:53:48 +00:00
feat(agent): Multi agents v0.1 (#1044)
Co-authored-by: qidanrui <qidanrui@gmail.com> Co-authored-by: csunny <cfqsunny@163.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
from .tags.vis_chart import VisChart
|
||||
from .tags.vis_code import VisCode
|
||||
from .tags.vis_dashboard import VisDashboard
|
||||
from .tags.vis_agent_plans import VisAgentPlans
|
||||
from .tags.vis_agent_message import VisAgentMessages
|
||||
from .tags.vis_plugin import VisPlugin
|
||||
from .client import vis_client
|
||||
|
28
dbgpt/vis/base.py
Normal file
28
dbgpt/vis/base.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
from dbgpt.util.json_utils import serialize
|
||||
|
||||
|
||||
class Vis:
|
||||
@abstractmethod
|
||||
async def generate_param(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
Display corresponding content using vis protocol
|
||||
Args:
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
vis protocol text
|
||||
"""
|
||||
|
||||
async def disply(self, **kwargs) -> Optional[str]:
|
||||
return f"```{self.vis_tag()}\n{json.dumps(await self.generate_param(**kwargs), default=serialize, ensure_ascii=False)}\n```"
|
||||
|
||||
@classmethod
|
||||
def vis_tag(cls) -> str:
|
||||
"""
|
||||
Current vis protocol module tag name
|
||||
Returns:
|
||||
|
||||
"""
|
34
dbgpt/vis/client.py
Normal file
34
dbgpt/vis/client.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
from .tags.vis_code import VisCode
|
||||
from .tags.vis_chart import VisChart
|
||||
from .tags.vis_dashboard import VisDashboard
|
||||
from .tags.vis_agent_plans import VisAgentPlans
|
||||
from .tags.vis_agent_message import VisAgentMessages
|
||||
from .tags.vis_plugin import VisPlugin
|
||||
from .base import Vis
|
||||
|
||||
|
||||
class VisClient:
|
||||
def __init__(self):
|
||||
self._vis_tag: Dict[str, Vis] = {}
|
||||
|
||||
def register(self, vis_cls: Vis):
|
||||
self._vis_tag[vis_cls.vis_tag()] = vis_cls()
|
||||
|
||||
def get(self, tag_name):
|
||||
if tag_name not in self._vis_tag:
|
||||
raise ValueError(f"Vis protocol tags not yet supported![{tag_name}]")
|
||||
return self._vis_tag[tag_name]
|
||||
|
||||
def tag_names(self):
|
||||
self._vis_tag.keys()
|
||||
|
||||
|
||||
vis_client = VisClient()
|
||||
|
||||
vis_client.register(VisCode)
|
||||
vis_client.register(VisChart)
|
||||
vis_client.register(VisDashboard)
|
||||
vis_client.register(VisAgentPlans)
|
||||
vis_client.register(VisAgentMessages)
|
||||
vis_client.register(VisPlugin)
|
0
dbgpt/vis/tags/__init__.py
Normal file
0
dbgpt/vis/tags/__init__.py
Normal file
17
dbgpt/vis/tags/vis_agent_message.py
Normal file
17
dbgpt/vis/tags/vis_agent_message.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
from ..base import Vis
|
||||
|
||||
|
||||
class VisAgentMessages(Vis):
|
||||
async def generate_content(self, **kwargs) -> Optional[str]:
|
||||
param = {
|
||||
"sender": kwargs["sender"],
|
||||
"receiver": kwargs["receiver"],
|
||||
"model": kwargs["model"],
|
||||
"markdown": kwargs.get("markdown", None),
|
||||
}
|
||||
return param
|
||||
|
||||
@classmethod
|
||||
def vis_tag(cls):
|
||||
return "vis-agent-messages"
|
18
dbgpt/vis/tags/vis_agent_plans.py
Normal file
18
dbgpt/vis/tags/vis_agent_plans.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Optional
|
||||
from ..base import Vis
|
||||
|
||||
|
||||
class VisAgentPlans(Vis):
|
||||
async def generate_content(self, **kwargs) -> Optional[str]:
|
||||
param = {
|
||||
"name": kwargs["name"],
|
||||
"num": kwargs["sub_task_num"],
|
||||
"status": kwargs["status"],
|
||||
"agent": kwargs.get("sub_task_agent", None),
|
||||
"markdown": kwargs.get("markdown", None),
|
||||
}
|
||||
return param
|
||||
|
||||
@classmethod
|
||||
def vis_tag(cls):
|
||||
return "vis-agent-plans"
|
72
dbgpt/vis/tags/vis_chart.py
Normal file
72
dbgpt/vis/tags/vis_chart.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import Optional
|
||||
import yaml
|
||||
import json
|
||||
from ..base import Vis
|
||||
|
||||
|
||||
def default_chart_type_promot() -> str:
|
||||
"""this function is moved from excel_analyze/chat.py,and used by subclass.
|
||||
Returns:
|
||||
|
||||
"""
|
||||
antv_charts = [
|
||||
{"response_line_chart": "used to display comparative trend analysis data"},
|
||||
{
|
||||
"response_pie_chart": "suitable for scenarios such as proportion and distribution statistics"
|
||||
},
|
||||
{
|
||||
"response_table": "suitable for display with many display columns or non-numeric columns"
|
||||
},
|
||||
# {"response_data_text":" the default display method, suitable for single-line or simple content display"},
|
||||
{
|
||||
"response_scatter_plot": "Suitable for exploring relationships between variables, detecting outliers, etc."
|
||||
},
|
||||
{
|
||||
"response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc."
|
||||
},
|
||||
{
|
||||
"response_donut_chart": "Suitable for hierarchical structure representation, category proportion display and highlighting key categories, etc."
|
||||
},
|
||||
{
|
||||
"response_area_chart": "Suitable for visualization of time series data, comparison of multiple groups of data, analysis of data change trends, etc."
|
||||
},
|
||||
{
|
||||
"response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc."
|
||||
},
|
||||
]
|
||||
return "\n".join(
|
||||
f"{key}:{value}"
|
||||
for dict_item in antv_charts
|
||||
for key, value in dict_item.items()
|
||||
)
|
||||
|
||||
|
||||
class VisChart(Vis):
|
||||
async def generate_content(self, **kwargs) -> Optional[str]:
|
||||
chart = kwargs.get("chart", None)
|
||||
sql_2_df_func = kwargs.get("sql_2_df_func", None)
|
||||
|
||||
if not chart or not sql_2_df_func:
|
||||
raise ValueError(
|
||||
f"Parameter information is missing and {self.vis_tag} protocol conversion cannot be performed."
|
||||
)
|
||||
|
||||
sql = chart.get("sql", None)
|
||||
param = {}
|
||||
df = sql_2_df_func(sql)
|
||||
if not sql or len(sql) <= 0:
|
||||
return None
|
||||
|
||||
param["sql"] = sql
|
||||
param["type"] = chart.get("display_type", "response_table")
|
||||
param["title"] = chart.get("title", "")
|
||||
param["describe"] = chart.get("thought", "")
|
||||
|
||||
param["data"] = json.loads(
|
||||
df.to_json(orient="records", date_format="iso", date_unit="s")
|
||||
)
|
||||
return param
|
||||
|
||||
@classmethod
|
||||
def vis_tag(cls):
|
||||
return "vis-chart"
|
17
dbgpt/vis/tags/vis_code.py
Normal file
17
dbgpt/vis/tags/vis_code.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
from ..base import Vis
|
||||
|
||||
|
||||
class VisCode(Vis):
|
||||
async def generate_content(self, **kwargs) -> Optional[str]:
|
||||
param = {
|
||||
"exit_success": kwargs["exit_success"],
|
||||
"language": kwargs["language"],
|
||||
"code": kwargs["code"],
|
||||
"log": kwargs.get("log", None),
|
||||
}
|
||||
return param
|
||||
|
||||
@classmethod
|
||||
def vis_tag(cls):
|
||||
return "vis-code"
|
48
dbgpt/vis/tags/vis_dashboard.py
Normal file
48
dbgpt/vis/tags/vis_dashboard.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from ..base import Vis
|
||||
|
||||
|
||||
class VisDashboard(Vis):
|
||||
async def generate_content(self, **kwargs) -> Optional[str]:
|
||||
charts = kwargs.get("charts", None)
|
||||
sql_2_df_func = kwargs.get("sql_2_df_func", None)
|
||||
title = kwargs.get("title", None)
|
||||
if not charts or not sql_2_df_func or not title:
|
||||
raise ValueError(
|
||||
f"Parameter information is missing and {self.vis_tag} protocol conversion cannot be performed."
|
||||
)
|
||||
|
||||
chart_items = []
|
||||
if not charts or len(charts) <= 0:
|
||||
return f"""Have no chart data!"""
|
||||
for chart in charts:
|
||||
param = {}
|
||||
sql = chart.get("sql", "")
|
||||
param["sql"] = sql
|
||||
param["type"] = chart.get("display_type", "response_table")
|
||||
param["title"] = chart.get("title", "")
|
||||
param["describe"] = chart.get("thought", "")
|
||||
try:
|
||||
df = sql_2_df_func(sql)
|
||||
param["data"] = json.loads(
|
||||
df.to_json(orient="records", date_format="iso", date_unit="s")
|
||||
)
|
||||
except Exception as e:
|
||||
param["data"] = []
|
||||
param["err_msg"] = str(e)
|
||||
chart_items.append(param)
|
||||
|
||||
dashboard_param = {
|
||||
"data": chart_items,
|
||||
"chart_count": len(chart_items),
|
||||
"title": title,
|
||||
"display_strategy": "default",
|
||||
"style": "default",
|
||||
}
|
||||
|
||||
return dashboard_param
|
||||
|
||||
@classmethod
|
||||
def vis_tag(cls):
|
||||
return "vis-dashboard"
|
18
dbgpt/vis/tags/vis_plugin.py
Normal file
18
dbgpt/vis/tags/vis_plugin.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Optional
|
||||
from ..base import Vis
|
||||
|
||||
|
||||
class VisPlugin(Vis):
|
||||
async def generate_content(self, **kwargs) -> Optional[str]:
|
||||
param = {
|
||||
"name": kwargs["name"],
|
||||
"status": kwargs["status"],
|
||||
"logo": kwargs.get("logo", None),
|
||||
"result": kwargs.get("result", None),
|
||||
"err_msg": kwargs.get("err_msg", None),
|
||||
}
|
||||
return param
|
||||
|
||||
@classmethod
|
||||
def vis_tag(cls):
|
||||
return "vis-plugin"
|
Reference in New Issue
Block a user