feat(agent): Compatible with MCP and tools (#2566)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Fangyin Cheng 2025-04-01 16:03:52 +08:00 committed by GitHub
parent 9719c0abcd
commit 59e34e2d89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 160 additions and 72 deletions

View File

@ -3,6 +3,7 @@ import logging
from typing import Optional
from dbgpt.agent import Action, ActionOutput, AgentResource, Resource, ResourceType
from dbgpt.util.json_utils import parse_or_raise_error
from ...resource.tool.base import BaseTool, ToolParameter
from ...util.react_parser import ReActOutputParser, ReActStep
@ -166,8 +167,9 @@ class ReActAction(ToolAction):
action_input = parsed_step.action_input
action_input_str = action_input
try:
# Try to parse the action input to dict
if action_input and isinstance(action_input, str):
tool_args = json.loads(action_input)
tool_args = parse_or_raise_error(action_input)
elif isinstance(action_input, dict):
tool_args = action_input
action_input_str = json.dumps(action_input, ensure_ascii=False)
@ -181,6 +183,7 @@ class ReActAction(ToolAction):
self.resource,
self.render_protocol,
need_vis_render=need_vis_render,
raw_tool_input=action_input_str,
)
if not act_out.action_input:
act_out.action_input = action_input_str

View File

@ -114,6 +114,7 @@ async def run_tool(
resource: Resource,
render_protocol: Optional[Vis] = None,
need_vis_render: bool = False,
raw_tool_input: Optional[str] = None,
) -> ActionOutput:
"""Run the tool."""
is_terminal = None
@ -121,10 +122,22 @@ async def run_tool(
tool_packs = ToolPack.from_resource(resource)
if not tool_packs:
raise ValueError("The tool resource is not found")
tool_pack = tool_packs[0]
tool_pack: ToolPack = tool_packs[0]
response_success = True
status = Status.RUNNING.value
err_msg = None
if raw_tool_input and tool_pack.parse_execute_args(
resource_name=name, input_str=raw_tool_input
):
# Let the real tool parse the input; on failure it raises the raw error,
# which lets the agent correct the input and retry
parsed_args = tool_pack.parse_execute_args(
resource_name=name, input_str=raw_tool_input
)
if parsed_args and isinstance(parsed_args, tuple):
args = parsed_args[1]
try:
tool_result = await tool_pack.async_execute(resource_name=name, **args)
status = Status.COMPLETE.value

View File

@ -31,6 +31,11 @@ T = TypeVar("T", bound="Resource")
_DEFAULT_RESOURCE_NAME = _("My Agent Resource")
_DEFAULT_RESOURCE_NAME_DESCRIPTION = _("Resource name")
# Positional arguments of a parsed execute call.
ARGS_TYPE = Tuple[Any, ...]
# Keyword arguments of a parsed execute call.
KWARGS_TYPE = Dict[str, Any]
# Parsed execute arguments as an ``(args, kwargs)`` pair.
EXECUTE_ARGS_TYPE = Tuple[ARGS_TYPE, KWARGS_TYPE]
# Parser that turns the raw input string into execute arguments, or None.
PARSE_EXECUTE_ARGS_FUNCTION = Callable[[Optional[str]], Optional[EXECUTE_ARGS_TYPE]]
class ResourceType(str, Enum):
"""Resource type enumeration."""
@ -190,6 +195,31 @@ class Resource(ABC, Generic[P]):
"""Get the resources."""
raise NotImplementedError
def parse_execute_args(
    self, resource_name: Optional[str] = None, input_str: Optional[str] = None
) -> Optional[EXECUTE_ARGS_TYPE]:
    """Try to parse the raw agent input into execute arguments.

    This base implementation does no parsing and always returns ``None``,
    which means the execute parameters are passed as the raw arguments
    generated by the Agent.

    Implement this method when:
        1. The resource has a specific input format.
        2. You want to parse the input yourself and raise the raw error
           information.

    Args:
        resource_name(str): The resource name.
        input_str(str): The raw input string generated by the Agent.

    Returns:
        Optional[EXECUTE_ARGS_TYPE]: The execute arguments.

    Raises:
        ValueError: If the input string is invalid.
    """
    # No custom parsing at this level: fall back to the raw arguments
    return None
def execute(self, *args, resource_name: Optional[str] = None, **kwargs) -> Any:
    """Execute the resource.

    This definition provides no implementation of its own.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

View File

@ -49,6 +49,11 @@ class ResourcePack(Resource[PackResourceParameters]):
"""Get the resource by name."""
return self._resources.get(name, None)
async def preload_resource(self):
    """Preload every sub-resource of this pack, one after another."""
    for child in self.sub_resources:
        # Awaited sequentially, in the order given by ``sub_resources``
        await child.preload_resource()
async def get_prompt(
self,
*,

View File

@ -12,7 +12,13 @@ from dbgpt._private.pydantic import BaseModel, Field, model_validator
from dbgpt.util.configure.base import _MISSING, _MISSING_TYPE
from dbgpt.util.function_utils import parse_param_description, type_to_string
from ..base import Resource, ResourceParameters, ResourceType
from ..base import (
EXECUTE_ARGS_TYPE,
PARSE_EXECUTE_ARGS_FUNCTION,
Resource,
ResourceParameters,
ResourceType,
)
ToolFunc = Union[Callable[..., Any], Callable[..., Awaitable[Any]]]
@ -144,6 +150,7 @@ class FunctionTool(BaseTool):
description: Optional[str] = None,
args: Optional[Dict[str, Union[ToolParameter, Dict[str, Any]]]] = None,
args_schema: Optional[Type[BaseModel]] = None,
parse_execute_args_func: Optional[PARSE_EXECUTE_ARGS_FUNCTION] = None,
):
"""Create a tool from a function."""
if not description:
@ -155,6 +162,7 @@ class FunctionTool(BaseTool):
self._args: Dict[str, ToolParameter] = _parse_args(func, args, args_schema)
self._func = func
self._is_async = asyncio.iscoroutinefunction(func)
self._parse_execute_args_func = parse_execute_args_func
@property
def name(self) -> str:
@ -176,6 +184,14 @@ class FunctionTool(BaseTool):
"""Return whether the tool is asynchronous."""
return self._is_async
def parse_execute_args(
    self, resource_name: Optional[str] = None, input_str: Optional[str] = None
) -> Optional[EXECUTE_ARGS_TYPE]:
    """Parse the execute arguments with the tool's custom parser, if any.

    Args:
        resource_name(str): The resource name; not used by this method.
        input_str(str): The raw input string handed to the custom parser.

    Returns:
        Optional[EXECUTE_ARGS_TYPE]: Whatever the custom parser returns, or
            ``None`` when no custom parser was configured.
    """
    custom_parser = self._parse_execute_args_func
    if custom_parser is None:
        return None
    return custom_parser(input_str)
def execute(
self,
*args,

View File

@ -1,17 +1,22 @@
"""Tool resource pack module."""
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, cast
from mcp import ClientSession
from mcp.client.sse import sse_client
from ..base import ResourceType, T
from dbgpt.util.json_utils import parse_or_raise_error
from ..base import EXECUTE_ARGS_TYPE, PARSE_EXECUTE_ARGS_FUNCTION, ResourceType, T
from ..pack import Resource, ResourcePack
from .base import DB_GPT_TOOL_IDENTIFIER, BaseTool, FunctionTool, ToolFunc
from .exceptions import ToolExecutionException, ToolNotFoundException
ToolResourceType = Union[BaseTool, List[BaseTool], ToolFunc, List[ToolFunc]]
ToolResourceType = Union[Resource, BaseTool, List[BaseTool], ToolFunc, List[ToolFunc]]
logger = logging.getLogger(__name__)
def _is_function_tool(resources: Any) -> bool:
@ -28,22 +33,40 @@ def _is_tool(resources: Any) -> bool:
return isinstance(resources, BaseTool) or _is_function_tool(resources)
def _to_tool_list(resources: ToolResourceType) -> List[BaseTool]:
if isinstance(resources, BaseTool):
return [resources]
elif isinstance(resources, Sequence) and all(_is_tool(r) for r in resources):
new_resources = []
for r in resources:
if isinstance(r, BaseTool):
new_resources.append(r)
else:
function_tool = cast(FunctionTool, getattr(r, "_tool"))
new_resources.append(function_tool)
return new_resources
elif _is_function_tool(resources):
function_tool = cast(FunctionTool, getattr(resources, "_tool"))
return [function_tool]
raise ValueError("Invalid tool resource type")
def _to_tool_list(
    resources: ToolResourceType, unpack: bool = False, ignore_error: bool = False
) -> List[Resource]:
    """Normalize a tool resource, or a nested collection of them, to a flat list.

    Args:
        resources: A ``BaseTool``, a function tool (as recognized by
            ``_is_function_tool``), a ``ResourcePack``, or a sequence of these.
        unpack: If True, replace every ``ResourcePack`` with its recursively
            flattened ``sub_resources``; otherwise keep the pack itself.
        ignore_error: If True, drop unsupported items instead of raising.

    Returns:
        List[Resource]: The flattened resources, in encounter order.

    Raises:
        ValueError: If an unsupported item is found and ``ignore_error`` is
            False.
    """

    def parse_tool(r):
        if isinstance(r, BaseTool):
            return [r]
        if _is_function_tool(r):
            return [cast(FunctionTool, getattr(r, "_tool"))]
        if isinstance(r, ResourcePack):
            if not unpack:
                return [r]
            return [tool for p in r.sub_resources for tool in parse_tool(p)]
        # A str is a Sequence whose items are themselves str, so descending
        # into it would recurse forever; treat text as an unsupported item.
        if isinstance(r, Sequence) and not isinstance(r, (str, bytes)):
            return [tool for t in r for tool in parse_tool(t)]
        if ignore_error:
            return []
        raise ValueError("Invalid tool resource type")

    return parse_tool(resources)
def json_parse_execute_args_func(
    input_str: Optional[str],
) -> Optional[EXECUTE_ARGS_TYPE]:
    """Parse a raw JSON input string into keyword-only execute arguments.

    Args:
        input_str: The raw input string generated by the agent.

    Returns:
        ``None`` when the input is empty, otherwise an ``(args, kwargs)`` pair
        whose positional part is always empty and whose keyword part is the
        parsed JSON object.

    Raises:
        ValueError: If the input cannot be parsed as JSON, or if the parsed
            value is not a JSON object and so cannot be used as keyword
            arguments.
    """
    if not input_str:
        # Nothing to parse: signal "no parsed arguments" instead of returning
        # a pair whose kwargs are None
        return None
    kwargs = parse_or_raise_error(input_str)
    if not isinstance(kwargs, dict):
        raise ValueError(
            "Tool input must be a JSON object, but got "
            f"{type(kwargs).__name__}: {input_str}"
        )
    # The positional arguments are always empty
    return (), kwargs
class ToolPack(ResourcePack):
@ -65,11 +88,7 @@ class ToolPack(ResourcePack):
"""Create a resource from another resource."""
if not resource:
return []
if isinstance(resource, ToolPack):
return [cast(T, resource)]
tools = super().from_resource(resource, ResourceType.Tool)
if not tools:
return []
tools = _to_tool_list(resource, unpack=True, ignore_error=True)
typed_tools = [cast(BaseTool, t) for t in tools]
return [ToolPack(typed_tools)] # type: ignore
@ -79,6 +98,7 @@ class ToolPack(ResourcePack):
command_name: str,
args: Optional[Dict[str, Any]] = None,
function: Optional[Callable] = None,
parse_execute_args_func: Optional[PARSE_EXECUTE_ARGS_FUNCTION] = None,
) -> None:
"""Add a command to the commands.
@ -93,6 +113,8 @@ class ToolPack(ResourcePack):
values. Defaults to None.
function (callable, optional): A callable function to be called when
the command is executed. Defaults to None.
parse_execute_args_func (callable, optional): A callable function to parse
the execute arguments. Defaults to None.
"""
if args is not None:
tool_args = {}
@ -124,6 +146,7 @@ class ToolPack(ResourcePack):
func=function,
args=tool_args,
description=command_label,
parse_execute_args_func=parse_execute_args_func,
)
self.append(ft)
@ -143,6 +166,16 @@ class ToolPack(ResourcePack):
del arguments[arg_name]
return arguments
def parse_execute_args(
    self, resource_name: Optional[str] = None, input_str: Optional[str] = None
) -> Optional[EXECUTE_ARGS_TYPE]:
    """Delegate argument parsing to the tool selected by ``resource_name``.

    Args:
        resource_name(str): Name used to look up the tool in this pack.
        input_str(str): The raw input string forwarded to that tool.

    Returns:
        Optional[EXECUTE_ARGS_TYPE]: The selected tool's parse result, or
            ``None`` when ``ToolNotFoundException`` is raised.
    """
    try:
        # Lookup and delegation stay in one try so both share the same handler
        return self._get_execution_tool(resource_name).parse_execute_args(
            input_str=input_str
        )
    except ToolNotFoundException:
        return None
def execute(
self,
*args,
@ -244,17 +277,6 @@ class AutoGPTPluginToolPack(ToolPack):
self._loaded = True
async def call_mcp_tool(server, tool_name, args: dict):
try:
async with sse_client(url=server) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
return await session.call_tool(tool_name, arguments=args)
except Exception:
raise ValueError("MCP Call Exception!{str(e)}")
class MCPToolPack(ToolPack):
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
"""Create an Auto-GPT plugin tool pack."""
@ -263,42 +285,11 @@ class MCPToolPack(ToolPack):
self._loaded = False
self.tool_server_map = {}
@classmethod
def from_resource(
cls: Type[T],
resource: Optional[Resource],
expected_type: Optional[ResourceType] = None,
) -> List[T]:
"""Create a resource from another resource."""
if not resource:
return []
if isinstance(resource, ToolPack):
return [cast(T, resource)]
tools = super().from_resource(resource, ResourceType.Tool)
if not tools:
return []
typed_tools = [cast(BaseTool, t) for t in tools]
return [ToolPack(typed_tools)] # type: ignore
def _get_call_args(self, arguments: Dict[str, Any], tl: BaseTool) -> Dict[str, Any]:
"""Get the call arguments."""
# Delete non-defined parameters
diff_args = list(set(arguments.keys()).difference(set(tl.args.keys())))
for arg_name in diff_args:
del arguments[arg_name]
# Rebuild dbgpt mcp call param
return {
"server": self.tool_server_map[tl.name],
"args": arguments,
"tool_name": tl.name,
}
def switch_mcp_input_schema(self, input_schema: dict):
args = {}
try:
properties = input_schema["properties"]
required = input_schema["required"]
required = input_schema.get("required", [])
for k, v in properties.items():
arg = {}
@ -339,9 +330,29 @@ class MCPToolPack(ToolPack):
await session.initialize()
result = await session.list_tools()
for tool in result.tools:
self.tool_server_map[tool.name] = server
tool_name = tool.name
self.tool_server_map[tool_name] = server
args = self.switch_mcp_input_schema(tool.inputSchema)
async def call_mcp_tool(
    tool_name=tool_name, server=server, **kwargs
):
    """Call one MCP tool over a fresh SSE connection.

    ``tool_name`` and ``server`` are bound as default values, so each
    generated function keeps the values that were current when it was
    defined.

    Args:
        tool_name: Name of the tool to call.
        server: URL passed to ``sse_client``.
        **kwargs: Forwarded to the tool as its ``arguments``.

    Returns:
        The result of ``session.call_tool``.

    Raises:
        ValueError: If connecting, initializing or calling the tool fails;
            the original exception is attached as the cause.
    """
    try:
        async with sse_client(url=server) as (read, write):
            async with ClientSession(read, write) as session:
                # Initialize the connection
                await session.initialize()
                return await session.call_tool(
                    tool_name, arguments=kwargs
                )
    except Exception as e:
        # Chain the original error so its traceback is not lost
        raise ValueError(f"MCP Call Exception! {e}") from e
self.add_command(
tool.description, tool.name, args, call_mcp_tool
tool.description,
tool_name,
args,
call_mcp_tool,
parse_execute_args_func=json_parse_execute_args_func,
)
self._loaded = True

View File

@ -45,7 +45,7 @@ def extract_char_position(error_message: str) -> int:
raise ValueError("Character position not found in the error message.")
def find_json_objects(text):
def find_json_objects(text: str):
json_objects = []
inside_string = False
escape_character = False
@ -93,6 +93,16 @@ def find_json_objects(text):
return json_objects
def parse_or_raise_error(text: str, is_array: bool = False):
    """Extract JSON from ``text``, or surface the raw JSON parsing error.

    Args:
        text: The text to parse.
        is_array: If True, return every object found by
            ``find_json_objects``; otherwise return only the first one.

    Returns:
        ``None`` for empty text, the extracted object(s) when any are found,
        otherwise the result of ``json.loads(text)``.

    Raises:
        json.JSONDecodeError: From ``json.loads`` when nothing was extracted
            and ``text`` is not valid JSON.
    """
    if not text:
        return None
    found = find_json_objects(text)
    if found:
        return found if is_array else found[0]
    # Nothing extracted: use json.loads so the caller sees the raw error
    return json.loads(text)
@staticmethod
def _format_json_str(jstr):
"""Remove newlines outside of quotes, and handle JSON escape sequences.