feat(agent): Compatible with MCP and tools (#2566)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Fangyin Cheng 2025-04-01 16:03:52 +08:00 committed by GitHub
parent 9719c0abcd
commit 59e34e2d89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 160 additions and 72 deletions

View File

@ -3,6 +3,7 @@ import logging
from typing import Optional from typing import Optional
from dbgpt.agent import Action, ActionOutput, AgentResource, Resource, ResourceType from dbgpt.agent import Action, ActionOutput, AgentResource, Resource, ResourceType
from dbgpt.util.json_utils import parse_or_raise_error
from ...resource.tool.base import BaseTool, ToolParameter from ...resource.tool.base import BaseTool, ToolParameter
from ...util.react_parser import ReActOutputParser, ReActStep from ...util.react_parser import ReActOutputParser, ReActStep
@ -166,8 +167,9 @@ class ReActAction(ToolAction):
action_input = parsed_step.action_input action_input = parsed_step.action_input
action_input_str = action_input action_input_str = action_input
try: try:
# Try to parse the action input to dict
if action_input and isinstance(action_input, str): if action_input and isinstance(action_input, str):
tool_args = json.loads(action_input) tool_args = parse_or_raise_error(action_input)
elif isinstance(action_input, dict): elif isinstance(action_input, dict):
tool_args = action_input tool_args = action_input
action_input_str = json.dumps(action_input, ensure_ascii=False) action_input_str = json.dumps(action_input, ensure_ascii=False)
@ -181,6 +183,7 @@ class ReActAction(ToolAction):
self.resource, self.resource,
self.render_protocol, self.render_protocol,
need_vis_render=need_vis_render, need_vis_render=need_vis_render,
raw_tool_input=action_input_str,
) )
if not act_out.action_input: if not act_out.action_input:
act_out.action_input = action_input_str act_out.action_input = action_input_str

View File

@ -114,6 +114,7 @@ async def run_tool(
resource: Resource, resource: Resource,
render_protocol: Optional[Vis] = None, render_protocol: Optional[Vis] = None,
need_vis_render: bool = False, need_vis_render: bool = False,
raw_tool_input: Optional[str] = None,
) -> ActionOutput: ) -> ActionOutput:
"""Run the tool.""" """Run the tool."""
is_terminal = None is_terminal = None
@ -121,10 +122,22 @@ async def run_tool(
tool_packs = ToolPack.from_resource(resource) tool_packs = ToolPack.from_resource(resource)
if not tool_packs: if not tool_packs:
raise ValueError("The tool resource is not found") raise ValueError("The tool resource is not found")
tool_pack = tool_packs[0] tool_pack: ToolPack = tool_packs[0]
response_success = True response_success = True
status = Status.RUNNING.value status = Status.RUNNING.value
err_msg = None err_msg = None
if raw_tool_input and tool_pack.parse_execute_args(
resource_name=name, input_str=raw_tool_input
):
# Let the real tool parse the raw input: if parsing fails it raises the raw
# error, which lets the agent modify the input and retry.
parsed_args = tool_pack.parse_execute_args(
resource_name=name, input_str=raw_tool_input
)
if parsed_args and isinstance(parsed_args, tuple):
args = parsed_args[1]
try: try:
tool_result = await tool_pack.async_execute(resource_name=name, **args) tool_result = await tool_pack.async_execute(resource_name=name, **args)
status = Status.COMPLETE.value status = Status.COMPLETE.value

View File

@ -31,6 +31,11 @@ T = TypeVar("T", bound="Resource")
_DEFAULT_RESOURCE_NAME = _("My Agent Resource") _DEFAULT_RESOURCE_NAME = _("My Agent Resource")
_DEFAULT_RESOURCE_NAME_DESCRIPTION = _("Resource name") _DEFAULT_RESOURCE_NAME_DESCRIPTION = _("Resource name")
ARGS_TYPE = Tuple[Any, ...]
KWARGS_TYPE = Dict[str, Any]
EXECUTE_ARGS_TYPE = Tuple[ARGS_TYPE, KWARGS_TYPE]
PARSE_EXECUTE_ARGS_FUNCTION = Callable[[Optional[str]], Optional[EXECUTE_ARGS_TYPE]]
class ResourceType(str, Enum): class ResourceType(str, Enum):
"""Resource type enumeration.""" """Resource type enumeration."""
@ -190,6 +195,31 @@ class Resource(ABC, Generic[P]):
"""Get the resources.""" """Get the resources."""
raise NotImplementedError raise NotImplementedError
def parse_execute_args(
    self, resource_name: Optional[str] = None, input_str: Optional[str] = None
) -> Optional[EXECUTE_ARGS_TYPE]:
    """Try to parse the execute arguments from the raw agent input.

    This base implementation does no parsing and always returns None. When
    None is returned, the execute parameters are passed as the raw arguments
    generated by the agent.

    Override this method when:
        1. The resource has a specific input format.
        2. You want to parse the input yourself and raise the raw error
           information.

    Args:
        resource_name(str): The resource name.
        input_str(str): The raw input string generated by the agent.

    Returns:
        Optional[EXECUTE_ARGS_TYPE]: The execute arguments, or None to fall
            back to the raw arguments generated by the agent.

    Raises:
        ValueError: If the input string is invalid. This is the contract for
            overriding implementations; this base method never raises.
    """
    return None
def execute(self, *args, resource_name: Optional[str] = None, **kwargs) -> Any: def execute(self, *args, resource_name: Optional[str] = None, **kwargs) -> Any:
"""Execute the resource.""" """Execute the resource."""
raise NotImplementedError raise NotImplementedError

View File

@ -49,6 +49,11 @@ class ResourcePack(Resource[PackResourceParameters]):
"""Get the resource by name.""" """Get the resource by name."""
return self._resources.get(name, None) return self._resources.get(name, None)
async def preload_resource(self):
    """Preload every sub resource of this pack.

    The sub resources are awaited one after another, in the order given by
    ``self.sub_resources``.
    """
    for child in self.sub_resources:
        await child.preload_resource()
async def get_prompt( async def get_prompt(
self, self,
*, *,

View File

@ -12,7 +12,13 @@ from dbgpt._private.pydantic import BaseModel, Field, model_validator
from dbgpt.util.configure.base import _MISSING, _MISSING_TYPE from dbgpt.util.configure.base import _MISSING, _MISSING_TYPE
from dbgpt.util.function_utils import parse_param_description, type_to_string from dbgpt.util.function_utils import parse_param_description, type_to_string
from ..base import Resource, ResourceParameters, ResourceType from ..base import (
EXECUTE_ARGS_TYPE,
PARSE_EXECUTE_ARGS_FUNCTION,
Resource,
ResourceParameters,
ResourceType,
)
ToolFunc = Union[Callable[..., Any], Callable[..., Awaitable[Any]]] ToolFunc = Union[Callable[..., Any], Callable[..., Awaitable[Any]]]
@ -144,6 +150,7 @@ class FunctionTool(BaseTool):
description: Optional[str] = None, description: Optional[str] = None,
args: Optional[Dict[str, Union[ToolParameter, Dict[str, Any]]]] = None, args: Optional[Dict[str, Union[ToolParameter, Dict[str, Any]]]] = None,
args_schema: Optional[Type[BaseModel]] = None, args_schema: Optional[Type[BaseModel]] = None,
parse_execute_args_func: Optional[PARSE_EXECUTE_ARGS_FUNCTION] = None,
): ):
"""Create a tool from a function.""" """Create a tool from a function."""
if not description: if not description:
@ -155,6 +162,7 @@ class FunctionTool(BaseTool):
self._args: Dict[str, ToolParameter] = _parse_args(func, args, args_schema) self._args: Dict[str, ToolParameter] = _parse_args(func, args, args_schema)
self._func = func self._func = func
self._is_async = asyncio.iscoroutinefunction(func) self._is_async = asyncio.iscoroutinefunction(func)
self._parse_execute_args_func = parse_execute_args_func
@property @property
def name(self) -> str: def name(self) -> str:
@ -176,6 +184,14 @@ class FunctionTool(BaseTool):
"""Return whether the tool is asynchronous.""" """Return whether the tool is asynchronous."""
return self._is_async return self._is_async
def parse_execute_args(
    self, resource_name: Optional[str] = None, input_str: Optional[str] = None
) -> Optional[EXECUTE_ARGS_TYPE]:
    """Parse the execute arguments with the tool's custom parser, if any.

    Args:
        resource_name(str): The resource name. Not used by this method.
        input_str(str): The raw input string handed to the custom parser.

    Returns:
        Optional[EXECUTE_ARGS_TYPE]: Whatever the custom parser returns, or
            None when no custom parser was configured for this tool.
    """
    custom_parser = self._parse_execute_args_func
    if custom_parser is None:
        return None
    return custom_parser(input_str)
def execute( def execute(
self, self,
*args, *args,

View File

@ -1,17 +1,22 @@
"""Tool resource pack module.""" """Tool resource pack module."""
import logging
import os import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, cast from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, cast
from mcp import ClientSession from mcp import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from ..base import ResourceType, T from dbgpt.util.json_utils import parse_or_raise_error
from ..base import EXECUTE_ARGS_TYPE, PARSE_EXECUTE_ARGS_FUNCTION, ResourceType, T
from ..pack import Resource, ResourcePack from ..pack import Resource, ResourcePack
from .base import DB_GPT_TOOL_IDENTIFIER, BaseTool, FunctionTool, ToolFunc from .base import DB_GPT_TOOL_IDENTIFIER, BaseTool, FunctionTool, ToolFunc
from .exceptions import ToolExecutionException, ToolNotFoundException from .exceptions import ToolExecutionException, ToolNotFoundException
ToolResourceType = Union[BaseTool, List[BaseTool], ToolFunc, List[ToolFunc]] ToolResourceType = Union[Resource, BaseTool, List[BaseTool], ToolFunc, List[ToolFunc]]
logger = logging.getLogger(__name__)
def _is_function_tool(resources: Any) -> bool: def _is_function_tool(resources: Any) -> bool:
@ -28,23 +33,41 @@ def _is_tool(resources: Any) -> bool:
return isinstance(resources, BaseTool) or _is_function_tool(resources) return isinstance(resources, BaseTool) or _is_function_tool(resources)
def _to_tool_list(
    resources: ToolResourceType, unpack: bool = False, ignore_error: bool = False
) -> List[Resource]:
    """Flatten a tool-like value into a list of resources.

    Args:
        resources: A tool, a function tool (a callable carrying a ``_tool``
            attribute), a resource pack, or a (possibly nested) sequence of
            those.
        unpack: If True, a ``ResourcePack`` is replaced by the tools found in
            its ``sub_resources`` (recursively). If False, the pack itself is
            kept as a single item.
        ignore_error: If True, unsupported values contribute nothing to the
            result. If False, they raise ``ValueError``.

    Returns:
        List[Resource]: The flattened resources, in input order.

    Raises:
        ValueError: If an unsupported value is found and ``ignore_error`` is
            False.
    """

    def parse_tool(r: Any) -> List[Resource]:
        if isinstance(r, BaseTool):
            return [r]
        if _is_function_tool(r):
            return [cast(FunctionTool, getattr(r, "_tool"))]
        if isinstance(r, ResourcePack):
            if not unpack:
                return [r]
            return [tool for sub in r.sub_resources for tool in parse_tool(sub)]
        # str/bytes are Sequences too, but iterating a one-character string
        # yields the same string again, so recursing into them never ends.
        # Treat them as unsupported values instead.
        if isinstance(r, Sequence) and not isinstance(r, (str, bytes, bytearray)):
            return [tool for item in r for tool in parse_tool(item)]
        if ignore_error:
            return []
        raise ValueError("Invalid tool resource type")

    return parse_tool(resources)
def json_parse_execute_args_func(
    input_str: Optional[str],
) -> Optional[EXECUTE_ARGS_TYPE]:
    """Parse a JSON input string into execute arguments.

    Args:
        input_str: The raw input string generated by the agent.

    Returns:
        Optional[EXECUTE_ARGS_TYPE]: A tuple ``(args, kwargs)`` where ``args``
            is always empty and ``kwargs`` is the parsed JSON object. None is
            returned when there is nothing to parse, so the caller can fall
            back to the raw arguments.

    Raises:
        ValueError: If the input parses to something other than a JSON
            object. Errors raised by ``parse_or_raise_error`` for malformed
            input are propagated unchanged.
    """
    kwargs = parse_or_raise_error(input_str)
    if kwargs is None:
        # Empty input (or a JSON null): there are no arguments to hand over.
        return None
    if not isinstance(kwargs, dict):
        # The result is later expanded with ``**kwargs``; fail here with a
        # message the agent can act on instead of a TypeError at call time.
        raise ValueError(
            f"Tool input must be a JSON object, got {type(kwargs).__name__}"
        )
    # The positional arguments are always empty.
    return (), kwargs
class ToolPack(ResourcePack): class ToolPack(ResourcePack):
"""Tool resource pack class.""" """Tool resource pack class."""
@ -65,11 +88,7 @@ class ToolPack(ResourcePack):
"""Create a resource from another resource.""" """Create a resource from another resource."""
if not resource: if not resource:
return [] return []
if isinstance(resource, ToolPack): tools = _to_tool_list(resource, unpack=True, ignore_error=True)
return [cast(T, resource)]
tools = super().from_resource(resource, ResourceType.Tool)
if not tools:
return []
typed_tools = [cast(BaseTool, t) for t in tools] typed_tools = [cast(BaseTool, t) for t in tools]
return [ToolPack(typed_tools)] # type: ignore return [ToolPack(typed_tools)] # type: ignore
@ -79,6 +98,7 @@ class ToolPack(ResourcePack):
command_name: str, command_name: str,
args: Optional[Dict[str, Any]] = None, args: Optional[Dict[str, Any]] = None,
function: Optional[Callable] = None, function: Optional[Callable] = None,
parse_execute_args_func: Optional[PARSE_EXECUTE_ARGS_FUNCTION] = None,
) -> None: ) -> None:
"""Add a command to the commands. """Add a command to the commands.
@ -93,6 +113,8 @@ class ToolPack(ResourcePack):
values. Defaults to None. values. Defaults to None.
function (callable, optional): A callable function to be called when function (callable, optional): A callable function to be called when
the command is executed. Defaults to None. the command is executed. Defaults to None.
parse_execute_args_func (callable, optional): A callable function to parse
the execute arguments. Defaults to None.
""" """
if args is not None: if args is not None:
tool_args = {} tool_args = {}
@ -124,6 +146,7 @@ class ToolPack(ResourcePack):
func=function, func=function,
args=tool_args, args=tool_args,
description=command_label, description=command_label,
parse_execute_args_func=parse_execute_args_func,
) )
self.append(ft) self.append(ft)
@ -143,6 +166,16 @@ class ToolPack(ResourcePack):
del arguments[arg_name] del arguments[arg_name]
return arguments return arguments
def parse_execute_args(
    self, resource_name: Optional[str] = None, input_str: Optional[str] = None
) -> Optional[EXECUTE_ARGS_TYPE]:
    """Parse the execute arguments with the tool selected by name.

    Args:
        resource_name(str): The name used to look up the tool in this pack.
        input_str(str): The raw input string forwarded to the tool's parser.

    Returns:
        Optional[EXECUTE_ARGS_TYPE]: The result of the tool's own
            ``parse_execute_args``, or None if a ``ToolNotFoundException``
            is raised.
    """
    try:
        return self._get_execution_tool(resource_name).parse_execute_args(
            input_str=input_str
        )
    except ToolNotFoundException:
        return None
def execute( def execute(
self, self,
*args, *args,
@ -244,17 +277,6 @@ class AutoGPTPluginToolPack(ToolPack):
self._loaded = True self._loaded = True
async def call_mcp_tool(server, tool_name, args: dict):
try:
async with sse_client(url=server) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
return await session.call_tool(tool_name, arguments=args)
except Exception:
raise ValueError("MCP Call Exception!{str(e)}")
class MCPToolPack(ToolPack): class MCPToolPack(ToolPack):
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs): def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
"""Create an Auto-GPT plugin tool pack.""" """Create an Auto-GPT plugin tool pack."""
@ -263,42 +285,11 @@ class MCPToolPack(ToolPack):
self._loaded = False self._loaded = False
self.tool_server_map = {} self.tool_server_map = {}
@classmethod
def from_resource(
cls: Type[T],
resource: Optional[Resource],
expected_type: Optional[ResourceType] = None,
) -> List[T]:
"""Create a resource from another resource."""
if not resource:
return []
if isinstance(resource, ToolPack):
return [cast(T, resource)]
tools = super().from_resource(resource, ResourceType.Tool)
if not tools:
return []
typed_tools = [cast(BaseTool, t) for t in tools]
return [ToolPack(typed_tools)] # type: ignore
def _get_call_args(self, arguments: Dict[str, Any], tl: BaseTool) -> Dict[str, Any]:
"""Get the call arguments."""
# Delete non-defined parameters
diff_args = list(set(arguments.keys()).difference(set(tl.args.keys())))
for arg_name in diff_args:
del arguments[arg_name]
# Rebuild dbgpt mcp call param
return {
"server": self.tool_server_map[tl.name],
"args": arguments,
"tool_name": tl.name,
}
def switch_mcp_input_schema(self, input_schema: dict): def switch_mcp_input_schema(self, input_schema: dict):
args = {} args = {}
try: try:
properties = input_schema["properties"] properties = input_schema["properties"]
required = input_schema["required"] required = input_schema.get("required", [])
for k, v in properties.items(): for k, v in properties.items():
arg = {} arg = {}
@ -339,9 +330,29 @@ class MCPToolPack(ToolPack):
await session.initialize() await session.initialize()
result = await session.list_tools() result = await session.list_tools()
for tool in result.tools: for tool in result.tools:
self.tool_server_map[tool.name] = server tool_name = tool.name
self.tool_server_map[tool_name] = server
args = self.switch_mcp_input_schema(tool.inputSchema) args = self.switch_mcp_input_schema(tool.inputSchema)
async def call_mcp_tool(
tool_name=tool_name, server=server, **kwargs
):
try:
async with sse_client(url=server) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
return await session.call_tool(
tool_name, arguments=kwargs
)
except Exception as e:
raise ValueError(f"MCP Call Exception! {str(e)}")
self.add_command( self.add_command(
tool.description, tool.name, args, call_mcp_tool tool.description,
tool_name,
args,
call_mcp_tool,
parse_execute_args_func=json_parse_execute_args_func,
) )
self._loaded = True self._loaded = True

View File

@ -45,7 +45,7 @@ def extract_char_position(error_message: str) -> int:
raise ValueError("Character position not found in the error message.") raise ValueError("Character position not found in the error message.")
def find_json_objects(text): def find_json_objects(text: str):
json_objects = [] json_objects = []
inside_string = False inside_string = False
escape_character = False escape_character = False
@ -93,6 +93,16 @@ def find_json_objects(text):
return json_objects return json_objects
def parse_or_raise_error(text: str, is_array: bool = False):
    """Extract JSON from text, or raise the raw JSON parsing error.

    Args:
        text: The text to parse.
        is_array: If True, return every object found by
            ``find_json_objects``; otherwise return only the first one.

    Returns:
        None for empty text; otherwise the object(s) found by
        ``find_json_objects``, or the result of ``json.loads(text)`` when
        none were found.

    Raises:
        json.JSONDecodeError: If no object was found and the text is not
            valid JSON.
    """
    if not text:
        return None
    found = find_json_objects(text)
    if found:
        return found if is_array else found[0]
    # Nothing could be extracted: let json.loads raise the raw error
    return json.loads(text)
@staticmethod @staticmethod
def _format_json_str(jstr): def _format_json_str(jstr):
"""Remove newlines outside of quotes, and handle JSON escape sequences. """Remove newlines outside of quotes, and handle JSON escape sequences.