refactor(agent): Refactor resource of agents (#1518)

This commit is contained in:
Fangyin Cheng
2024-05-15 09:57:19 +08:00
committed by GitHub
parent db4d318a5f
commit 559affe87d
102 changed files with 2633 additions and 2549 deletions

View File

@@ -2,14 +2,14 @@
import json
import logging
from typing import Optional
from typing import List, Optional
from dbgpt._private.pydantic import BaseModel, Field, model_to_json
from dbgpt.vis.tags.vis_chart import Vis, VisChart
from ...core.action.base import Action, ActionOutput
from ...resource.resource_api import AgentResource, ResourceType
from ...resource.resource_db_api import ResourceDbClient
from ...resource.base import AgentResource, ResourceType
from ...resource.database import DBResource
logger = logging.getLogger(__name__)
@@ -69,34 +69,28 @@ class ChartAction(Action[SqlInput]):
content="The requested correctly structured answer could not be found.",
)
try:
if not self.resource_loader:
raise ValueError("ResourceLoader is not initialized")
resource_db_client: Optional[
ResourceDbClient
] = self.resource_loader.get_resource_api(
self.resource_need, ResourceDbClient
)
if not resource_db_client:
raise ValueError(
"There is no implementation class bound to database resource "
"execution"
)
if not resource:
raise ValueError("The data resource is not found")
data_df = await resource_db_client.query_to_df(resource.value, param.sql)
if not self.resource_need:
raise ValueError("The resource type is not found")
if not self.render_protocol:
raise ValueError("The rendering protocol is not initialized")
db_resources: List[DBResource] = DBResource.from_resource(self.resource)
if not db_resources:
raise ValueError("The database resource is not found")
db = db_resources[0]
data_df = await db.query_to_df(param.sql)
view = await self.render_protocol.display(
chart=json.loads(model_to_json(param)), data_df=data_df
)
if not self.resource_need:
raise ValueError("The resource type is not found")
return ActionOutput(
is_exe_success=True,
content=model_to_json(param),
view=view,
resource_type=self.resource_need.value,
resource_value=resource.value,
resource_value=db._db_name,
)
except Exception as e:
logger.exception("Check your answers, the sql run failed")

View File

@@ -8,7 +8,7 @@ from dbgpt.util.utils import colored
from dbgpt.vis.tags.vis_code import Vis, VisCode
from ...core.action.base import Action, ActionOutput
from ...resource.resource_api import AgentResource
from ...resource.base import AgentResource
logger = logging.getLogger(__name__)

View File

@@ -8,8 +8,8 @@ from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
from dbgpt.vis.tags.vis_dashboard import Vis, VisDashboard
from ...core.action.base import Action, ActionOutput
from ...resource.resource_api import AgentResource, ResourceType
from ...resource.resource_db_api import ResourceDbClient
from ...resource.base import AgentResource, ResourceType
from ...resource.database import DBResource
logger = logging.getLogger(__name__)
@@ -83,29 +83,20 @@ class DashboardAction(Action[List[ChartItem]]):
)
chart_items: List[ChartItem] = input_param
try:
if not self.resource_loader:
raise ValueError("Resource loader is not initialized!")
resource_db_client: Optional[
ResourceDbClient
] = self.resource_loader.get_resource_api(
self.resource_need, ResourceDbClient
)
if not resource_db_client:
raise ValueError(
"There is no implementation class bound to database resource "
"execution"
)
db_resources: List[DBResource] = DBResource.from_resource(self.resource)
if not db_resources:
raise ValueError("The database resource is not found")
if not resource:
raise ValueError("Resource is not initialized!")
db = db_resources[0]
if not db:
raise ValueError("The database resource is not found")
chart_params = []
for chart_item in chart_items:
chart_dict = {}
try:
sql_df = await resource_db_client.query_to_df(
resource.value, chart_item.sql
)
sql_df = await db.query_to_df(chart_item.sql)
chart_dict = chart_item.to_dict()
chart_dict["data"] = sql_df

View File

@@ -9,7 +9,7 @@ from dbgpt.vis.tags.vis_plugin import Vis, VisPlugin
from ...core.action.base import Action, ActionOutput
from ...core.schema import Status
from ...resource.resource_api import AgentResource, ResourceType
from ...resource.base import AgentResource, ResourceType
logger = logging.getLogger(__name__)

View File

@@ -9,14 +9,13 @@ from dbgpt.vis.tags.vis_plugin import Vis, VisPlugin
from ...core.action.base import Action, ActionOutput
from ...core.schema import Status
from ...plugin.generator import PluginPromptGenerator
from ...resource.resource_api import AgentResource, ResourceType
from ...resource.resource_plugin_api import ResourcePluginClient
from ...resource.base import AgentResource, ResourceType
from ...resource.tool.pack import ToolPack
logger = logging.getLogger(__name__)
class PluginInput(BaseModel):
class ToolInput(BaseModel):
"""Plugin input model."""
tool_name: str = Field(
@@ -32,8 +31,8 @@ class PluginInput(BaseModel):
thought: str = Field(..., description="Summary of thoughts to the user")
class PluginAction(Action[PluginInput]):
"""Plugin action class."""
class ToolAction(Action[ToolInput]):
"""Tool action class."""
def __init__(self):
"""Create a plugin action."""
@@ -43,7 +42,7 @@ class PluginAction(Action[PluginInput]):
@property
def resource_need(self) -> Optional[ResourceType]:
"""Return the resource type needed for the action."""
return ResourceType.Plugin
return ResourceType.Tool
@property
def render_protocol(self) -> Optional[Vis]:
@@ -53,19 +52,19 @@ class PluginAction(Action[PluginInput]):
@property
def out_model_type(self):
"""Return the output model type."""
return PluginInput
return ToolInput
@property
def ai_out_schema(self) -> Optional[str]:
"""Return the AI output schema."""
out_put_schema = {
"thought": "Summary of thoughts to the user",
"tool_name": "The name of a tool that can be used to answer the current "
"question or solve the current task.",
"args": {
"arg name1": "arg value1",
"arg name2": "arg value2",
},
"thought": "Summary of thoughts to the user",
}
return f"""Please response in the following json format:
@@ -92,13 +91,8 @@ class PluginAction(Action[PluginInput]):
need_vis_render (bool, optional): Whether need visualization rendering.
Defaults to True.
"""
plugin_generator: Optional[PluginPromptGenerator] = kwargs.get(
"plugin_generator", None
)
if not plugin_generator:
raise ValueError("No plugin generator found!")
try:
param: PluginInput = self._input_convert(ai_message, PluginInput)
param: ToolInput = self._input_convert(ai_message, ToolInput)
except Exception as e:
logger.exception((str(e)))
return ActionOutput(
@@ -107,21 +101,16 @@ class PluginAction(Action[PluginInput]):
)
try:
if not self.resource_loader:
raise ValueError("No resource_loader found!")
resource_plugin_client: Optional[
ResourcePluginClient
] = self.resource_loader.get_resource_api(
self.resource_need, ResourcePluginClient
)
if not resource_plugin_client:
raise ValueError("No implementation of the use of plug-in resources")
tool_packs = ToolPack.from_resource(self.resource)
if not tool_packs:
raise ValueError("The tool resource is not found")
tool_pack = tool_packs[0]
response_success = True
status = Status.RUNNING.value
err_msg = None
try:
tool_result = await resource_plugin_client.execute_command(
param.tool_name, param.args, plugin_generator
tool_result = await tool_pack.async_execute(
resource_name=param.tool_name, **param.args
)
status = Status.COMPLETE.value
except Exception as e:
@@ -146,9 +135,9 @@ class PluginAction(Action[PluginInput]):
return ActionOutput(
is_exe_success=response_success,
content=tool_result,
content=str(tool_result),
view=view,
observations=tool_result,
observations=str(tool_result),
)
except Exception as e:
logger.exception("Tool Action Run Failed")

View File

@@ -1,9 +1,11 @@
"""Dashboard Assistant Agent."""
from typing import List
from ..core.agent import AgentMessage
from ..core.base_agent import ConversableAgent
from ..core.profile import DynConfig, ProfileConfig
from ..resource.resource_db_api import ResourceDbClient
from ..resource.database import DBResource
from .actions.dashboard_action import DashboardAction
@@ -58,15 +60,16 @@ class DashboardAssistantAgent(ConversableAgent):
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
reply_message = super()._init_reply_message(received_message)
client = self.not_null_resource_loader.get_resource_api(
self.actions[0].resource_need, ResourceDbClient
)
if not client:
dbs: List[DBResource] = DBResource.from_resource(self.resource)
if not dbs:
raise ValueError(
f"Resource type {self.actions[0].resource_need} is not supported."
)
db = dbs[0]
reply_message.context = {
"display_type": self.actions[0].render_prompt(),
"dialect": client.get_data_type(self.resources[0]),
"dialect": db.dialect,
}
return reply_message

View File

@@ -2,14 +2,13 @@
import json
import logging
from typing import Optional, Tuple, cast
from typing import List, Optional, Tuple, cast
from ..core.action.base import ActionOutput
from ..core.agent import AgentMessage
from ..core.base_agent import ConversableAgent
from ..core.profile import DynConfig, ProfileConfig
from ..resource.resource_api import ResourceType
from ..resource.resource_db_api import ResourceDbClient
from ..resource.database import DBResource
from .actions.chart_action import ChartAction
logger = logging.getLogger(__name__)
@@ -74,18 +73,21 @@ class DataScientistAgent(ConversableAgent):
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
reply_message = super()._init_reply_message(received_message)
client = self.not_null_resource_loader.get_resource_api(
self.actions[0].resource_need, ResourceDbClient
)
if not client:
reply_message.context = {
"display_type": self.actions[0].render_prompt(),
"dialect": self.database.dialect,
}
return reply_message
@property
def database(self) -> DBResource:
"""Get the database resource."""
dbs: List[DBResource] = DBResource.from_resource(self.resource)
if not dbs:
raise ValueError(
f"Resource type {self.actions[0].resource_need} is not supported."
)
reply_message.context = {
"display_type": self.actions[0].render_prompt(),
"dialect": client.get_data_type(self.resources[0]),
}
return reply_message
return dbs[0]
async def correctness_check(
self, message: AgentMessage
@@ -112,17 +114,6 @@ class DataScientistAgent(ConversableAgent):
"generated is not found.",
)
try:
resource_db_client: Optional[
ResourceDbClient
] = self.not_null_resource_loader.get_resource_api(
ResourceType(action_out.resource_type), ResourceDbClient
)
if not resource_db_client:
return (
False,
"Please check your answer, the data resource type is not "
"supported.",
)
if not action_out.resource_value:
return (
False,
@@ -130,8 +121,9 @@ class DataScientistAgent(ConversableAgent):
"found.",
)
columns, values = await resource_db_client.query(
db=action_out.resource_value, sql=sql
columns, values = await self.database.query(
sql=sql,
db=action_out.resource_value,
)
if not values or len(values) <= 0:
return (

View File

@@ -0,0 +1 @@
"""Expand resources for the agent module."""

View File

@@ -0,0 +1,23 @@
"""Some internal tools for the DB-GPT project."""
from typing_extensions import Annotated, Doc
from ...resource.tool.base import tool
@tool(description="List the supported models in DB-GPT project.")
def list_dbgpt_support_models(
model_type: Annotated[
str, Doc("The model type, LLM(Large Language Model) and EMBEDDING).")
] = "LLM",
) -> str:
"""List the supported models in dbgpt."""
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG
if model_type.lower() == "llm":
supports = list(LLM_MODEL_CONFIG.keys())
elif model_type.lower() == "embedding":
supports = list(EMBEDDING_MODEL_CONFIG.keys())
else:
raise ValueError(f"Unsupported model type: {model_type}")
return "\n\n".join(supports)

View File

@@ -0,0 +1,54 @@
"""Search tools for the agent."""
import re
from typing_extensions import Annotated, Doc
from ...resource.tool.base import tool
@tool(
description="Baidu search and return the results as a markdown string. Please set "
"number of results not less than 8 for rich search results.",
)
def baidu_search(
query: Annotated[str, Doc("The search query.")],
num_results: Annotated[int, Doc("The number of search results to return.")] = 8,
) -> str:
"""Baidu search and return the results as a markdown string.
Please set number of results not less than 8 for rich search results.
"""
import requests
from bs4 import BeautifulSoup
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:112.0) "
"Gecko/20100101 Firefox/112.0"
}
url = f"https://www.baidu.com/s?wd={query}&rn={num_results}"
response = requests.get(url, headers=headers)
response.encoding = "utf-8"
soup = BeautifulSoup(response.text, "html.parser")
search_results = []
for result in soup.find_all("div", class_=re.compile("^result c-container ")):
title = result.find("h3", class_="t").get_text()
link = result.find("a", href=True)["href"]
snippet = result.find("span", class_=re.compile("^content-right_"))
if snippet:
snippet = snippet.get_text()
else:
snippet = ""
search_results.append({"title": title, "href": link, "snippet": snippet})
return _search_to_view(search_results)
def _search_to_view(results) -> str:
view_results = []
for item in results:
view_results.append(
f"### [{item['title']}]({item['href']})\n{item['snippet']}\n"
)
return "\n".join(view_results)

View File

@@ -14,7 +14,7 @@ from ..core.action.base import Action, ActionOutput
from ..core.agent import Agent, AgentMessage, AgentReviewInfo
from ..core.base_agent import ConversableAgent
from ..core.profile import ProfileConfig
from ..resource.resource_api import AgentResource
from ..resource.base import AgentResource
from ..util.cmp import cmp_string_equal
try:

View File

@@ -1,22 +1,16 @@
"""Plugin Assistant Agent."""
import logging
from typing import Any, Dict, Optional
from ..core.base_agent import ConversableAgent
from ..core.profile import DynConfig, ProfileConfig
from ..plugin.generator import PluginPromptGenerator
from ..resource.resource_api import ResourceType
from ..resource.resource_plugin_api import ResourcePluginClient
from .actions.plugin_action import PluginAction
from .actions.tool_action import ToolAction
logger = logging.getLogger(__name__)
class PluginAssistantAgent(ConversableAgent):
"""Plugin Assistant Agent."""
plugin_generator: Optional[PluginPromptGenerator] = None
class ToolAssistantAgent(ConversableAgent):
"""Tool Assistant Agent."""
profile: ProfileConfig = ProfileConfig(
name=DynConfig(
@@ -57,37 +51,6 @@ class PluginAssistantAgent(ConversableAgent):
)
def __init__(self, **kwargs):
"""Create a new instance of PluginAssistantAgent."""
"""Create a new instance of ToolAssistantAgent."""
super().__init__(**kwargs)
self._init_actions([PluginAction])
# @property
# def introduce(self, **kwargs) -> str:
# """Introduce the agent."""
# if not self.plugin_generator:
# raise ValueError("PluginGenerator is not loaded.")
# return self.desc.format(
# tool_infos=self.plugin_generator.generate_commands_string()
# )
async def preload_resource(self):
"""Preload the resource."""
plugin_loader_client: ResourcePluginClient = (
self.not_null_resource_loader.get_resource_api(
ResourceType.Plugin, ResourcePluginClient
)
)
item_list = []
for item in self.resources:
if item.type == ResourceType.Plugin:
item_list.append(item.value)
plugin_generator = self.plugin_generator
for item in item_list:
plugin_generator = await plugin_loader_client.load_plugin(
item, plugin_generator
)
self.plugin_generator = plugin_generator
def prepare_act_param(self) -> Dict[str, Any]:
"""Prepare the act parameter."""
return {"plugin_generator": self.plugin_generator}
self._init_actions([ToolAction])