mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 04:12:13 +00:00
feat(agent): Compatible with MCP and tools (#2566)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
9719c0abcd
commit
59e34e2d89
@ -3,6 +3,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.agent import Action, ActionOutput, AgentResource, Resource, ResourceType
|
||||
from dbgpt.util.json_utils import parse_or_raise_error
|
||||
|
||||
from ...resource.tool.base import BaseTool, ToolParameter
|
||||
from ...util.react_parser import ReActOutputParser, ReActStep
|
||||
@ -166,8 +167,9 @@ class ReActAction(ToolAction):
|
||||
action_input = parsed_step.action_input
|
||||
action_input_str = action_input
|
||||
try:
|
||||
# Try to parse the action input to dict
|
||||
if action_input and isinstance(action_input, str):
|
||||
tool_args = json.loads(action_input)
|
||||
tool_args = parse_or_raise_error(action_input)
|
||||
elif isinstance(action_input, dict):
|
||||
tool_args = action_input
|
||||
action_input_str = json.dumps(action_input, ensure_ascii=False)
|
||||
@ -181,6 +183,7 @@ class ReActAction(ToolAction):
|
||||
self.resource,
|
||||
self.render_protocol,
|
||||
need_vis_render=need_vis_render,
|
||||
raw_tool_input=action_input_str,
|
||||
)
|
||||
if not act_out.action_input:
|
||||
act_out.action_input = action_input_str
|
||||
|
@ -114,6 +114,7 @@ async def run_tool(
|
||||
resource: Resource,
|
||||
render_protocol: Optional[Vis] = None,
|
||||
need_vis_render: bool = False,
|
||||
raw_tool_input: Optional[str] = None,
|
||||
) -> ActionOutput:
|
||||
"""Run the tool."""
|
||||
is_terminal = None
|
||||
@ -121,10 +122,22 @@ async def run_tool(
|
||||
tool_packs = ToolPack.from_resource(resource)
|
||||
if not tool_packs:
|
||||
raise ValueError("The tool resource is not found!")
|
||||
tool_pack = tool_packs[0]
|
||||
tool_pack: ToolPack = tool_packs[0]
|
||||
response_success = True
|
||||
status = Status.RUNNING.value
|
||||
err_msg = None
|
||||
|
||||
if raw_tool_input and tool_pack.parse_execute_args(
|
||||
resource_name=name, input_str=raw_tool_input
|
||||
):
|
||||
# Use real tool to parse the input, it will raise raw error when failed
|
||||
# it will make agent to modify the input and retry
|
||||
parsed_args = tool_pack.parse_execute_args(
|
||||
resource_name=name, input_str=raw_tool_input
|
||||
)
|
||||
if parsed_args and isinstance(parsed_args, tuple):
|
||||
args = parsed_args[1]
|
||||
|
||||
try:
|
||||
tool_result = await tool_pack.async_execute(resource_name=name, **args)
|
||||
status = Status.COMPLETE.value
|
||||
|
@ -31,6 +31,11 @@ T = TypeVar("T", bound="Resource")
|
||||
_DEFAULT_RESOURCE_NAME = _("My Agent Resource")
|
||||
_DEFAULT_RESOURCE_NAME_DESCRIPTION = _("Resource name")
|
||||
|
||||
ARGS_TYPE = Tuple[Any, ...]
|
||||
KWARGS_TYPE = Dict[str, Any]
|
||||
EXECUTE_ARGS_TYPE = Tuple[ARGS_TYPE, KWARGS_TYPE]
|
||||
PARSE_EXECUTE_ARGS_FUNCTION = Callable[[Optional[str]], Optional[EXECUTE_ARGS_TYPE]]
|
||||
|
||||
|
||||
class ResourceType(str, Enum):
|
||||
"""Resource type enumeration."""
|
||||
@ -190,6 +195,31 @@ class Resource(ABC, Generic[P]):
|
||||
"""Get the resources."""
|
||||
raise NotImplementedError
|
||||
|
||||
def parse_execute_args(
|
||||
self, resource_name: Optional[str] = None, input_str: Optional[str] = None
|
||||
) -> Optional[EXECUTE_ARGS_TYPE]:
|
||||
"""Try to parse the execute arguments.
|
||||
|
||||
If return None, the execute parameters are pass by raw arguments generated by
|
||||
Agent.
|
||||
|
||||
What case you should implement this method:
|
||||
1. If the resource has a specific input format, you should implement this
|
||||
2. You want parse and raise the raw error information, you should implement this
|
||||
|
||||
Args:
|
||||
resource_name(str): The resource name.
|
||||
input_str(str): The input string, it is the raw input string generated by
|
||||
Agent.
|
||||
|
||||
Returns:
|
||||
Optional[EXECUTE_ARGS_TYPE]: The execute arguments.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input string is invalid.
|
||||
"""
|
||||
return None
|
||||
|
||||
def execute(self, *args, resource_name: Optional[str] = None, **kwargs) -> Any:
|
||||
"""Execute the resource."""
|
||||
raise NotImplementedError
|
||||
|
@ -49,6 +49,11 @@ class ResourcePack(Resource[PackResourceParameters]):
|
||||
"""Get the resource by name."""
|
||||
return self._resources.get(name, None)
|
||||
|
||||
async def preload_resource(self):
|
||||
"""Preload the resource."""
|
||||
for sub_resource in self.sub_resources:
|
||||
await sub_resource.preload_resource()
|
||||
|
||||
async def get_prompt(
|
||||
self,
|
||||
*,
|
||||
|
@ -12,7 +12,13 @@ from dbgpt._private.pydantic import BaseModel, Field, model_validator
|
||||
from dbgpt.util.configure.base import _MISSING, _MISSING_TYPE
|
||||
from dbgpt.util.function_utils import parse_param_description, type_to_string
|
||||
|
||||
from ..base import Resource, ResourceParameters, ResourceType
|
||||
from ..base import (
|
||||
EXECUTE_ARGS_TYPE,
|
||||
PARSE_EXECUTE_ARGS_FUNCTION,
|
||||
Resource,
|
||||
ResourceParameters,
|
||||
ResourceType,
|
||||
)
|
||||
|
||||
ToolFunc = Union[Callable[..., Any], Callable[..., Awaitable[Any]]]
|
||||
|
||||
@ -144,6 +150,7 @@ class FunctionTool(BaseTool):
|
||||
description: Optional[str] = None,
|
||||
args: Optional[Dict[str, Union[ToolParameter, Dict[str, Any]]]] = None,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
parse_execute_args_func: Optional[PARSE_EXECUTE_ARGS_FUNCTION] = None,
|
||||
):
|
||||
"""Create a tool from a function."""
|
||||
if not description:
|
||||
@ -155,6 +162,7 @@ class FunctionTool(BaseTool):
|
||||
self._args: Dict[str, ToolParameter] = _parse_args(func, args, args_schema)
|
||||
self._func = func
|
||||
self._is_async = asyncio.iscoroutinefunction(func)
|
||||
self._parse_execute_args_func = parse_execute_args_func
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -176,6 +184,14 @@ class FunctionTool(BaseTool):
|
||||
"""Return whether the tool is asynchronous."""
|
||||
return self._is_async
|
||||
|
||||
def parse_execute_args(
|
||||
self, resource_name: Optional[str] = None, input_str: Optional[str] = None
|
||||
) -> Optional[EXECUTE_ARGS_TYPE]:
|
||||
"""Parse the execute arguments."""
|
||||
if self._parse_execute_args_func is not None:
|
||||
return self._parse_execute_args_func(input_str)
|
||||
return None
|
||||
|
||||
def execute(
|
||||
self,
|
||||
*args,
|
||||
|
@ -1,17 +1,22 @@
|
||||
"""Tool resource pack module."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, cast
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from ..base import ResourceType, T
|
||||
from dbgpt.util.json_utils import parse_or_raise_error
|
||||
|
||||
from ..base import EXECUTE_ARGS_TYPE, PARSE_EXECUTE_ARGS_FUNCTION, ResourceType, T
|
||||
from ..pack import Resource, ResourcePack
|
||||
from .base import DB_GPT_TOOL_IDENTIFIER, BaseTool, FunctionTool, ToolFunc
|
||||
from .exceptions import ToolExecutionException, ToolNotFoundException
|
||||
|
||||
ToolResourceType = Union[BaseTool, List[BaseTool], ToolFunc, List[ToolFunc]]
|
||||
ToolResourceType = Union[Resource, BaseTool, List[BaseTool], ToolFunc, List[ToolFunc]]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_function_tool(resources: Any) -> bool:
|
||||
@ -28,22 +33,40 @@ def _is_tool(resources: Any) -> bool:
|
||||
return isinstance(resources, BaseTool) or _is_function_tool(resources)
|
||||
|
||||
|
||||
def _to_tool_list(resources: ToolResourceType) -> List[BaseTool]:
|
||||
if isinstance(resources, BaseTool):
|
||||
return [resources]
|
||||
elif isinstance(resources, Sequence) and all(_is_tool(r) for r in resources):
|
||||
new_resources = []
|
||||
for r in resources:
|
||||
if isinstance(r, BaseTool):
|
||||
new_resources.append(r)
|
||||
else:
|
||||
function_tool = cast(FunctionTool, getattr(r, "_tool"))
|
||||
new_resources.append(function_tool)
|
||||
return new_resources
|
||||
elif _is_function_tool(resources):
|
||||
function_tool = cast(FunctionTool, getattr(resources, "_tool"))
|
||||
return [function_tool]
|
||||
raise ValueError("Invalid tool resource type")
|
||||
def _to_tool_list(
|
||||
resources: ToolResourceType, unpack: bool = False, ignore_error: bool = False
|
||||
) -> List[Resource]:
|
||||
def parse_tool(r):
|
||||
if isinstance(r, BaseTool):
|
||||
return [r]
|
||||
elif _is_function_tool(r):
|
||||
return [cast(FunctionTool, getattr(r, "_tool"))]
|
||||
elif isinstance(r, ResourcePack):
|
||||
if not unpack:
|
||||
return [r]
|
||||
new_list = []
|
||||
for p in r.sub_resources:
|
||||
new_list.extend(parse_tool(p))
|
||||
return new_list
|
||||
elif isinstance(r, Sequence):
|
||||
new_list = []
|
||||
for t in r:
|
||||
new_list.extend(parse_tool(t))
|
||||
return new_list
|
||||
elif ignore_error:
|
||||
return []
|
||||
else:
|
||||
raise ValueError("Invalid tool resource type")
|
||||
|
||||
return parse_tool(resources)
|
||||
|
||||
|
||||
def json_parse_execute_args_func(input_str: str) -> Optional[EXECUTE_ARGS_TYPE]:
|
||||
"""Parse the execute arguments."""
|
||||
# The position arguments is empty
|
||||
args = ()
|
||||
kwargs = parse_or_raise_error(input_str)
|
||||
return args, kwargs
|
||||
|
||||
|
||||
class ToolPack(ResourcePack):
|
||||
@ -65,11 +88,7 @@ class ToolPack(ResourcePack):
|
||||
"""Create a resource from another resource."""
|
||||
if not resource:
|
||||
return []
|
||||
if isinstance(resource, ToolPack):
|
||||
return [cast(T, resource)]
|
||||
tools = super().from_resource(resource, ResourceType.Tool)
|
||||
if not tools:
|
||||
return []
|
||||
tools = _to_tool_list(resource, unpack=True, ignore_error=True)
|
||||
typed_tools = [cast(BaseTool, t) for t in tools]
|
||||
return [ToolPack(typed_tools)] # type: ignore
|
||||
|
||||
@ -79,6 +98,7 @@ class ToolPack(ResourcePack):
|
||||
command_name: str,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
function: Optional[Callable] = None,
|
||||
parse_execute_args_func: Optional[PARSE_EXECUTE_ARGS_FUNCTION] = None,
|
||||
) -> None:
|
||||
"""Add a command to the commands.
|
||||
|
||||
@ -93,6 +113,8 @@ class ToolPack(ResourcePack):
|
||||
values. Defaults to None.
|
||||
function (callable, optional): A callable function to be called when
|
||||
the command is executed. Defaults to None.
|
||||
parse_execute_args (callable, optional): A callable function to parse the
|
||||
execute arguments. Defaults to None.
|
||||
"""
|
||||
if args is not None:
|
||||
tool_args = {}
|
||||
@ -124,6 +146,7 @@ class ToolPack(ResourcePack):
|
||||
func=function,
|
||||
args=tool_args,
|
||||
description=command_label,
|
||||
parse_execute_args_func=parse_execute_args_func,
|
||||
)
|
||||
self.append(ft)
|
||||
|
||||
@ -143,6 +166,16 @@ class ToolPack(ResourcePack):
|
||||
del arguments[arg_name]
|
||||
return arguments
|
||||
|
||||
def parse_execute_args(
|
||||
self, resource_name: Optional[str] = None, input_str: Optional[str] = None
|
||||
) -> Optional[EXECUTE_ARGS_TYPE]:
|
||||
"""Parse the execute arguments."""
|
||||
try:
|
||||
tl = self._get_execution_tool(resource_name)
|
||||
return tl.parse_execute_args(input_str=input_str)
|
||||
except ToolNotFoundException:
|
||||
return None
|
||||
|
||||
def execute(
|
||||
self,
|
||||
*args,
|
||||
@ -244,17 +277,6 @@ class AutoGPTPluginToolPack(ToolPack):
|
||||
self._loaded = True
|
||||
|
||||
|
||||
async def call_mcp_tool(server, tool_name, args: dict):
|
||||
try:
|
||||
async with sse_client(url=server) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
return await session.call_tool(tool_name, arguments=args)
|
||||
except Exception:
|
||||
raise ValueError("MCP Call Exception!{str(e)}")
|
||||
|
||||
|
||||
class MCPToolPack(ToolPack):
|
||||
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
|
||||
"""Create an Auto-GPT plugin tool pack."""
|
||||
@ -263,42 +285,11 @@ class MCPToolPack(ToolPack):
|
||||
self._loaded = False
|
||||
self.tool_server_map = {}
|
||||
|
||||
@classmethod
|
||||
def from_resource(
|
||||
cls: Type[T],
|
||||
resource: Optional[Resource],
|
||||
expected_type: Optional[ResourceType] = None,
|
||||
) -> List[T]:
|
||||
"""Create a resource from another resource."""
|
||||
if not resource:
|
||||
return []
|
||||
if isinstance(resource, ToolPack):
|
||||
return [cast(T, resource)]
|
||||
tools = super().from_resource(resource, ResourceType.Tool)
|
||||
if not tools:
|
||||
return []
|
||||
typed_tools = [cast(BaseTool, t) for t in tools]
|
||||
return [ToolPack(typed_tools)] # type: ignore
|
||||
|
||||
def _get_call_args(self, arguments: Dict[str, Any], tl: BaseTool) -> Dict[str, Any]:
|
||||
"""Get the call arguments."""
|
||||
# Delete non-defined parameters
|
||||
diff_args = list(set(arguments.keys()).difference(set(tl.args.keys())))
|
||||
for arg_name in diff_args:
|
||||
del arguments[arg_name]
|
||||
|
||||
# Rebuild dbgpt mcp call param
|
||||
return {
|
||||
"server": self.tool_server_map[tl.name],
|
||||
"args": arguments,
|
||||
"tool_name": tl.name,
|
||||
}
|
||||
|
||||
def switch_mcp_input_schema(self, input_schema: dict):
|
||||
args = {}
|
||||
try:
|
||||
properties = input_schema["properties"]
|
||||
required = input_schema["required"]
|
||||
required = input_schema.get("required", [])
|
||||
for k, v in properties.items():
|
||||
arg = {}
|
||||
|
||||
@ -339,9 +330,29 @@ class MCPToolPack(ToolPack):
|
||||
await session.initialize()
|
||||
result = await session.list_tools()
|
||||
for tool in result.tools:
|
||||
self.tool_server_map[tool.name] = server
|
||||
tool_name = tool.name
|
||||
self.tool_server_map[tool_name] = server
|
||||
args = self.switch_mcp_input_schema(tool.inputSchema)
|
||||
|
||||
async def call_mcp_tool(
|
||||
tool_name=tool_name, server=server, **kwargs
|
||||
):
|
||||
try:
|
||||
async with sse_client(url=server) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
return await session.call_tool(
|
||||
tool_name, arguments=kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"MCP Call Exception! {str(e)}")
|
||||
|
||||
self.add_command(
|
||||
tool.description, tool.name, args, call_mcp_tool
|
||||
tool.description,
|
||||
tool_name,
|
||||
args,
|
||||
call_mcp_tool,
|
||||
parse_execute_args_func=json_parse_execute_args_func,
|
||||
)
|
||||
self._loaded = True
|
||||
|
@ -45,7 +45,7 @@ def extract_char_position(error_message: str) -> int:
|
||||
raise ValueError("Character position not found in the error message.")
|
||||
|
||||
|
||||
def find_json_objects(text):
|
||||
def find_json_objects(text: str):
|
||||
json_objects = []
|
||||
inside_string = False
|
||||
escape_character = False
|
||||
@ -93,6 +93,16 @@ def find_json_objects(text):
|
||||
return json_objects
|
||||
|
||||
|
||||
def parse_or_raise_error(text: str, is_array: bool = False):
|
||||
if not text:
|
||||
return None
|
||||
parsed_objs = find_json_objects(text)
|
||||
if not parsed_objs:
|
||||
# Use json.loads to raise raw error
|
||||
return json.loads(text)
|
||||
return parsed_objs if is_array else parsed_objs[0]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _format_json_str(jstr):
|
||||
"""Remove newlines outside of quotes, and handle JSON escape sequences.
|
||||
|
Loading…
Reference in New Issue
Block a user