mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 05:59:59 +00:00
refactor(agent): Refactor resource of agents (#1518)
This commit is contained in:
@@ -2,14 +2,14 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_to_json
|
||||
from dbgpt.vis.tags.vis_chart import Vis, VisChart
|
||||
|
||||
from ...core.action.base import Action, ActionOutput
|
||||
from ...resource.resource_api import AgentResource, ResourceType
|
||||
from ...resource.resource_db_api import ResourceDbClient
|
||||
from ...resource.base import AgentResource, ResourceType
|
||||
from ...resource.database import DBResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -69,34 +69,28 @@ class ChartAction(Action[SqlInput]):
|
||||
content="The requested correctly structured answer could not be found.",
|
||||
)
|
||||
try:
|
||||
if not self.resource_loader:
|
||||
raise ValueError("ResourceLoader is not initialized!")
|
||||
resource_db_client: Optional[
|
||||
ResourceDbClient
|
||||
] = self.resource_loader.get_resource_api(
|
||||
self.resource_need, ResourceDbClient
|
||||
)
|
||||
if not resource_db_client:
|
||||
raise ValueError(
|
||||
"There is no implementation class bound to database resource "
|
||||
"execution!"
|
||||
)
|
||||
if not resource:
|
||||
raise ValueError("The data resource is not found!")
|
||||
data_df = await resource_db_client.query_to_df(resource.value, param.sql)
|
||||
if not self.resource_need:
|
||||
raise ValueError("The resource type is not found!")
|
||||
|
||||
if not self.render_protocol:
|
||||
raise ValueError("The rendering protocol is not initialized!")
|
||||
|
||||
db_resources: List[DBResource] = DBResource.from_resource(self.resource)
|
||||
if not db_resources:
|
||||
raise ValueError("The database resource is not found!")
|
||||
|
||||
db = db_resources[0]
|
||||
data_df = await db.query_to_df(param.sql)
|
||||
view = await self.render_protocol.display(
|
||||
chart=json.loads(model_to_json(param)), data_df=data_df
|
||||
)
|
||||
if not self.resource_need:
|
||||
raise ValueError("The resource type is not found!")
|
||||
|
||||
return ActionOutput(
|
||||
is_exe_success=True,
|
||||
content=model_to_json(param),
|
||||
view=view,
|
||||
resource_type=self.resource_need.value,
|
||||
resource_value=resource.value,
|
||||
resource_value=db._db_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Check your answers, the sql run failed!")
|
||||
|
@@ -8,7 +8,7 @@ from dbgpt.util.utils import colored
|
||||
from dbgpt.vis.tags.vis_code import Vis, VisCode
|
||||
|
||||
from ...core.action.base import Action, ActionOutput
|
||||
from ...resource.resource_api import AgentResource
|
||||
from ...resource.base import AgentResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -8,8 +8,8 @@ from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
||||
from dbgpt.vis.tags.vis_dashboard import Vis, VisDashboard
|
||||
|
||||
from ...core.action.base import Action, ActionOutput
|
||||
from ...resource.resource_api import AgentResource, ResourceType
|
||||
from ...resource.resource_db_api import ResourceDbClient
|
||||
from ...resource.base import AgentResource, ResourceType
|
||||
from ...resource.database import DBResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -83,29 +83,20 @@ class DashboardAction(Action[List[ChartItem]]):
|
||||
)
|
||||
chart_items: List[ChartItem] = input_param
|
||||
try:
|
||||
if not self.resource_loader:
|
||||
raise ValueError("Resource loader is not initialized!")
|
||||
resource_db_client: Optional[
|
||||
ResourceDbClient
|
||||
] = self.resource_loader.get_resource_api(
|
||||
self.resource_need, ResourceDbClient
|
||||
)
|
||||
if not resource_db_client:
|
||||
raise ValueError(
|
||||
"There is no implementation class bound to database resource "
|
||||
"execution!"
|
||||
)
|
||||
db_resources: List[DBResource] = DBResource.from_resource(self.resource)
|
||||
if not db_resources:
|
||||
raise ValueError("The database resource is not found!")
|
||||
|
||||
if not resource:
|
||||
raise ValueError("Resource is not initialized!")
|
||||
db = db_resources[0]
|
||||
|
||||
if not db:
|
||||
raise ValueError("The database resource is not found!")
|
||||
|
||||
chart_params = []
|
||||
for chart_item in chart_items:
|
||||
chart_dict = {}
|
||||
try:
|
||||
sql_df = await resource_db_client.query_to_df(
|
||||
resource.value, chart_item.sql
|
||||
)
|
||||
sql_df = await db.query_to_df(chart_item.sql)
|
||||
chart_dict = chart_item.to_dict()
|
||||
|
||||
chart_dict["data"] = sql_df
|
||||
|
@@ -9,7 +9,7 @@ from dbgpt.vis.tags.vis_plugin import Vis, VisPlugin
|
||||
|
||||
from ...core.action.base import Action, ActionOutput
|
||||
from ...core.schema import Status
|
||||
from ...resource.resource_api import AgentResource, ResourceType
|
||||
from ...resource.base import AgentResource, ResourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -9,14 +9,13 @@ from dbgpt.vis.tags.vis_plugin import Vis, VisPlugin
|
||||
|
||||
from ...core.action.base import Action, ActionOutput
|
||||
from ...core.schema import Status
|
||||
from ...plugin.generator import PluginPromptGenerator
|
||||
from ...resource.resource_api import AgentResource, ResourceType
|
||||
from ...resource.resource_plugin_api import ResourcePluginClient
|
||||
from ...resource.base import AgentResource, ResourceType
|
||||
from ...resource.tool.pack import ToolPack
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginInput(BaseModel):
|
||||
class ToolInput(BaseModel):
|
||||
"""Plugin input model."""
|
||||
|
||||
tool_name: str = Field(
|
||||
@@ -32,8 +31,8 @@ class PluginInput(BaseModel):
|
||||
thought: str = Field(..., description="Summary of thoughts to the user")
|
||||
|
||||
|
||||
class PluginAction(Action[PluginInput]):
|
||||
"""Plugin action class."""
|
||||
class ToolAction(Action[ToolInput]):
|
||||
"""Tool action class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a plugin action."""
|
||||
@@ -43,7 +42,7 @@ class PluginAction(Action[PluginInput]):
|
||||
@property
|
||||
def resource_need(self) -> Optional[ResourceType]:
|
||||
"""Return the resource type needed for the action."""
|
||||
return ResourceType.Plugin
|
||||
return ResourceType.Tool
|
||||
|
||||
@property
|
||||
def render_protocol(self) -> Optional[Vis]:
|
||||
@@ -53,19 +52,19 @@ class PluginAction(Action[PluginInput]):
|
||||
@property
|
||||
def out_model_type(self):
|
||||
"""Return the output model type."""
|
||||
return PluginInput
|
||||
return ToolInput
|
||||
|
||||
@property
|
||||
def ai_out_schema(self) -> Optional[str]:
|
||||
"""Return the AI output schema."""
|
||||
out_put_schema = {
|
||||
"thought": "Summary of thoughts to the user",
|
||||
"tool_name": "The name of a tool that can be used to answer the current "
|
||||
"question or solve the current task.",
|
||||
"args": {
|
||||
"arg name1": "arg value1",
|
||||
"arg name2": "arg value2",
|
||||
},
|
||||
"thought": "Summary of thoughts to the user",
|
||||
}
|
||||
|
||||
return f"""Please response in the following json format:
|
||||
@@ -92,13 +91,8 @@ class PluginAction(Action[PluginInput]):
|
||||
need_vis_render (bool, optional): Whether need visualization rendering.
|
||||
Defaults to True.
|
||||
"""
|
||||
plugin_generator: Optional[PluginPromptGenerator] = kwargs.get(
|
||||
"plugin_generator", None
|
||||
)
|
||||
if not plugin_generator:
|
||||
raise ValueError("No plugin generator found!")
|
||||
try:
|
||||
param: PluginInput = self._input_convert(ai_message, PluginInput)
|
||||
param: ToolInput = self._input_convert(ai_message, ToolInput)
|
||||
except Exception as e:
|
||||
logger.exception((str(e)))
|
||||
return ActionOutput(
|
||||
@@ -107,21 +101,16 @@ class PluginAction(Action[PluginInput]):
|
||||
)
|
||||
|
||||
try:
|
||||
if not self.resource_loader:
|
||||
raise ValueError("No resource_loader found!")
|
||||
resource_plugin_client: Optional[
|
||||
ResourcePluginClient
|
||||
] = self.resource_loader.get_resource_api(
|
||||
self.resource_need, ResourcePluginClient
|
||||
)
|
||||
if not resource_plugin_client:
|
||||
raise ValueError("No implementation of the use of plug-in resources!")
|
||||
tool_packs = ToolPack.from_resource(self.resource)
|
||||
if not tool_packs:
|
||||
raise ValueError("The tool resource is not found!")
|
||||
tool_pack = tool_packs[0]
|
||||
response_success = True
|
||||
status = Status.RUNNING.value
|
||||
err_msg = None
|
||||
try:
|
||||
tool_result = await resource_plugin_client.execute_command(
|
||||
param.tool_name, param.args, plugin_generator
|
||||
tool_result = await tool_pack.async_execute(
|
||||
resource_name=param.tool_name, **param.args
|
||||
)
|
||||
status = Status.COMPLETE.value
|
||||
except Exception as e:
|
||||
@@ -146,9 +135,9 @@ class PluginAction(Action[PluginInput]):
|
||||
|
||||
return ActionOutput(
|
||||
is_exe_success=response_success,
|
||||
content=tool_result,
|
||||
content=str(tool_result),
|
||||
view=view,
|
||||
observations=tool_result,
|
||||
observations=str(tool_result),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Tool Action Run Failed!")
|
@@ -1,9 +1,11 @@
|
||||
"""Dashboard Assistant Agent."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from ..core.agent import AgentMessage
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import DynConfig, ProfileConfig
|
||||
from ..resource.resource_db_api import ResourceDbClient
|
||||
from ..resource.database import DBResource
|
||||
from .actions.dashboard_action import DashboardAction
|
||||
|
||||
|
||||
@@ -58,15 +60,16 @@ class DashboardAssistantAgent(ConversableAgent):
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
client = self.not_null_resource_loader.get_resource_api(
|
||||
self.actions[0].resource_need, ResourceDbClient
|
||||
)
|
||||
if not client:
|
||||
|
||||
dbs: List[DBResource] = DBResource.from_resource(self.resource)
|
||||
|
||||
if not dbs:
|
||||
raise ValueError(
|
||||
f"Resource type {self.actions[0].resource_need} is not supported."
|
||||
)
|
||||
db = dbs[0]
|
||||
reply_message.context = {
|
||||
"display_type": self.actions[0].render_prompt(),
|
||||
"dialect": client.get_data_type(self.resources[0]),
|
||||
"dialect": db.dialect,
|
||||
}
|
||||
return reply_message
|
||||
|
@@ -2,14 +2,13 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Tuple, cast
|
||||
from typing import List, Optional, Tuple, cast
|
||||
|
||||
from ..core.action.base import ActionOutput
|
||||
from ..core.agent import AgentMessage
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import DynConfig, ProfileConfig
|
||||
from ..resource.resource_api import ResourceType
|
||||
from ..resource.resource_db_api import ResourceDbClient
|
||||
from ..resource.database import DBResource
|
||||
from .actions.chart_action import ChartAction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -74,18 +73,21 @@ class DataScientistAgent(ConversableAgent):
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
client = self.not_null_resource_loader.get_resource_api(
|
||||
self.actions[0].resource_need, ResourceDbClient
|
||||
)
|
||||
if not client:
|
||||
reply_message.context = {
|
||||
"display_type": self.actions[0].render_prompt(),
|
||||
"dialect": self.database.dialect,
|
||||
}
|
||||
return reply_message
|
||||
|
||||
@property
|
||||
def database(self) -> DBResource:
|
||||
"""Get the database resource."""
|
||||
dbs: List[DBResource] = DBResource.from_resource(self.resource)
|
||||
if not dbs:
|
||||
raise ValueError(
|
||||
f"Resource type {self.actions[0].resource_need} is not supported."
|
||||
)
|
||||
reply_message.context = {
|
||||
"display_type": self.actions[0].render_prompt(),
|
||||
"dialect": client.get_data_type(self.resources[0]),
|
||||
}
|
||||
return reply_message
|
||||
return dbs[0]
|
||||
|
||||
async def correctness_check(
|
||||
self, message: AgentMessage
|
||||
@@ -112,17 +114,6 @@ class DataScientistAgent(ConversableAgent):
|
||||
"generated is not found.",
|
||||
)
|
||||
try:
|
||||
resource_db_client: Optional[
|
||||
ResourceDbClient
|
||||
] = self.not_null_resource_loader.get_resource_api(
|
||||
ResourceType(action_out.resource_type), ResourceDbClient
|
||||
)
|
||||
if not resource_db_client:
|
||||
return (
|
||||
False,
|
||||
"Please check your answer, the data resource type is not "
|
||||
"supported.",
|
||||
)
|
||||
if not action_out.resource_value:
|
||||
return (
|
||||
False,
|
||||
@@ -130,8 +121,9 @@ class DataScientistAgent(ConversableAgent):
|
||||
"found.",
|
||||
)
|
||||
|
||||
columns, values = await resource_db_client.query(
|
||||
db=action_out.resource_value, sql=sql
|
||||
columns, values = await self.database.query(
|
||||
sql=sql,
|
||||
db=action_out.resource_value,
|
||||
)
|
||||
if not values or len(values) <= 0:
|
||||
return (
|
||||
|
1
dbgpt/agent/expand/resources/__init__.py
Normal file
1
dbgpt/agent/expand/resources/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Expand resources for the agent module."""
|
23
dbgpt/agent/expand/resources/dbgpt_tool.py
Normal file
23
dbgpt/agent/expand/resources/dbgpt_tool.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Some internal tools for the DB-GPT project."""
|
||||
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
from ...resource.tool.base import tool
|
||||
|
||||
|
||||
@tool(description="List the supported models in DB-GPT project.")
|
||||
def list_dbgpt_support_models(
|
||||
model_type: Annotated[
|
||||
str, Doc("The model type, LLM(Large Language Model) and EMBEDDING).")
|
||||
] = "LLM",
|
||||
) -> str:
|
||||
"""List the supported models in dbgpt."""
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG
|
||||
|
||||
if model_type.lower() == "llm":
|
||||
supports = list(LLM_MODEL_CONFIG.keys())
|
||||
elif model_type.lower() == "embedding":
|
||||
supports = list(EMBEDDING_MODEL_CONFIG.keys())
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
return "\n\n".join(supports)
|
54
dbgpt/agent/expand/resources/search_tool.py
Normal file
54
dbgpt/agent/expand/resources/search_tool.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Search tools for the agent."""
|
||||
|
||||
import re
|
||||
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
from ...resource.tool.base import tool
|
||||
|
||||
|
||||
@tool(
|
||||
description="Baidu search and return the results as a markdown string. Please set "
|
||||
"number of results not less than 8 for rich search results.",
|
||||
)
|
||||
def baidu_search(
|
||||
query: Annotated[str, Doc("The search query.")],
|
||||
num_results: Annotated[int, Doc("The number of search results to return.")] = 8,
|
||||
) -> str:
|
||||
"""Baidu search and return the results as a markdown string.
|
||||
|
||||
Please set number of results not less than 8 for rich search results.
|
||||
"""
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:112.0) "
|
||||
"Gecko/20100101 Firefox/112.0"
|
||||
}
|
||||
url = f"https://www.baidu.com/s?wd={query}&rn={num_results}"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.encoding = "utf-8"
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
search_results = []
|
||||
for result in soup.find_all("div", class_=re.compile("^result c-container ")):
|
||||
title = result.find("h3", class_="t").get_text()
|
||||
link = result.find("a", href=True)["href"]
|
||||
snippet = result.find("span", class_=re.compile("^content-right_"))
|
||||
if snippet:
|
||||
snippet = snippet.get_text()
|
||||
else:
|
||||
snippet = ""
|
||||
search_results.append({"title": title, "href": link, "snippet": snippet})
|
||||
|
||||
return _search_to_view(search_results)
|
||||
|
||||
|
||||
def _search_to_view(results) -> str:
|
||||
view_results = []
|
||||
for item in results:
|
||||
view_results.append(
|
||||
f"### [{item['title']}]({item['href']})\n{item['snippet']}\n"
|
||||
)
|
||||
return "\n".join(view_results)
|
@@ -14,7 +14,7 @@ from ..core.action.base import Action, ActionOutput
|
||||
from ..core.agent import Agent, AgentMessage, AgentReviewInfo
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import ProfileConfig
|
||||
from ..resource.resource_api import AgentResource
|
||||
from ..resource.base import AgentResource
|
||||
from ..util.cmp import cmp_string_equal
|
||||
|
||||
try:
|
||||
|
@@ -1,22 +1,16 @@
|
||||
"""Plugin Assistant Agent."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import DynConfig, ProfileConfig
|
||||
from ..plugin.generator import PluginPromptGenerator
|
||||
from ..resource.resource_api import ResourceType
|
||||
from ..resource.resource_plugin_api import ResourcePluginClient
|
||||
from .actions.plugin_action import PluginAction
|
||||
from .actions.tool_action import ToolAction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginAssistantAgent(ConversableAgent):
|
||||
"""Plugin Assistant Agent."""
|
||||
|
||||
plugin_generator: Optional[PluginPromptGenerator] = None
|
||||
class ToolAssistantAgent(ConversableAgent):
|
||||
"""Tool Assistant Agent."""
|
||||
|
||||
profile: ProfileConfig = ProfileConfig(
|
||||
name=DynConfig(
|
||||
@@ -57,37 +51,6 @@ class PluginAssistantAgent(ConversableAgent):
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new instance of PluginAssistantAgent."""
|
||||
"""Create a new instance of ToolAssistantAgent."""
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([PluginAction])
|
||||
|
||||
# @property
|
||||
# def introduce(self, **kwargs) -> str:
|
||||
# """Introduce the agent."""
|
||||
# if not self.plugin_generator:
|
||||
# raise ValueError("PluginGenerator is not loaded.")
|
||||
# return self.desc.format(
|
||||
# tool_infos=self.plugin_generator.generate_commands_string()
|
||||
# )
|
||||
|
||||
async def preload_resource(self):
|
||||
"""Preload the resource."""
|
||||
plugin_loader_client: ResourcePluginClient = (
|
||||
self.not_null_resource_loader.get_resource_api(
|
||||
ResourceType.Plugin, ResourcePluginClient
|
||||
)
|
||||
)
|
||||
item_list = []
|
||||
for item in self.resources:
|
||||
if item.type == ResourceType.Plugin:
|
||||
item_list.append(item.value)
|
||||
plugin_generator = self.plugin_generator
|
||||
for item in item_list:
|
||||
plugin_generator = await plugin_loader_client.load_plugin(
|
||||
item, plugin_generator
|
||||
)
|
||||
self.plugin_generator = plugin_generator
|
||||
|
||||
def prepare_act_param(self) -> Dict[str, Any]:
|
||||
"""Prepare the act parameter."""
|
||||
return {"plugin_generator": self.plugin_generator}
|
||||
self._init_actions([ToolAction])
|
Reference in New Issue
Block a user