feat(agent): Support MCP authentication

This commit is contained in:
Fangyin Cheng 2025-04-03 09:21:05 +08:00
parent 0b6633703d
commit 2e0e140b51
7 changed files with 143 additions and 12 deletions

View File

@ -139,7 +139,7 @@ class ConversableAgent(Role, Agent):
async def preload_resource(self) -> None:
"""Preload resources before agent initialization."""
if self.resource:
await self.blocking_func_to_async(self.resource.preload_resource)
await self.resource.preload_resource()
async def build(self, is_retry_chat: bool = False) -> "ConversableAgent":
"""Build the agent."""

View File

@ -166,6 +166,16 @@ class ReActAction(ToolAction):
name = parsed_step.action
action_input = parsed_step.action_input
action_input_str = action_input
if not name:
terminal_content = str(action_input_str if action_input_str else ai_message)
return ActionOutput(
is_exe_success=True,
content=terminal_content,
observations=terminal_content,
terminate=True,
)
try:
# Try to parse the action input to dict
if action_input and isinstance(action_input, str):

View File

@ -13,11 +13,11 @@ from dbgpt.agent import (
ResourceType,
)
from dbgpt.agent.core.role import AgentRunMode
from dbgpt.agent.resource import BaseTool, ToolPack
from dbgpt.agent.resource import BaseTool, ResourcePack, ToolPack
from dbgpt.agent.util.react_parser import ReActOutputParser
from dbgpt.util.configure import DynConfig
from .actions.react_action import ReActAction
from .actions.react_action import ReActAction, Terminate
logger = logging.getLogger(__name__)
@ -113,7 +113,7 @@ class ReActAgent(ConversableAgent):
"""Init indicator AssistantAgent."""
super().__init__(**kwargs)
self._init_actions([ReActAction])
self._init_actions([ReActAction, Terminate])
async def _a_init_reply_message(
self,
@ -150,6 +150,36 @@ class ReActAgent(ConversableAgent):
}
return reply_message
async def preload_resource(self) -> None:
await super().preload_resource()
self._check_and_add_terminate()
def _check_and_add_terminate(self):
if not self.resource:
return
_is_has_terminal = False
def _has_terminal(r: Resource):
nonlocal _is_has_terminal
if r.type() == ResourceType.Tool and isinstance(r, Terminate):
_is_has_terminal = True
return r
_has_add_terminal = False
def _add_terminate(r: Resource):
nonlocal _has_add_terminal
if not _has_add_terminal and isinstance(r, ResourcePack):
terminal = Terminate()
r._resources[terminal.name] = terminal
_has_add_terminal = True
return r
self.resource.apply(apply_func=_has_terminal)
if not _is_has_terminal:
# Add terminal action to the resource
self.resource.apply(apply_pack_func=_add_terminate)
async def load_resource(self, question: str, is_retry_chat: bool = False):
"""Load agent bind resource."""
if self.resource:

View File

@ -272,7 +272,12 @@ class Resource(ABC, Generic[P]):
def apply(
self,
apply_func: Callable[["Resource"], Union["Resource", List["Resource"], None]],
apply_func: Optional[
Callable[["Resource"], Union["Resource", List["Resource"], None]]
] = None,
apply_pack_func: Optional[
Callable[["Resource"], Union["Resource", None]]
] = None,
) -> Union["Resource", None]:
"""Apply the function to the resource."""
return self

View File

@ -125,16 +125,27 @@ class ResourcePack(Resource[PackResourceParameters]):
return list(self._resources.values())
def apply(
self, apply_func: Callable[[Resource], Union[Resource, List[Resource], None]]
self,
apply_func: Optional[
Callable[[Resource], Union[Resource, List[Resource], None]]
] = None,
apply_pack_func: Optional[
Callable[["Resource"], Union["Resource", None]]
] = None,
) -> Union[Resource, None]:
"""Apply the function to the resource."""
if not self.is_pack:
return self
if not apply_func and not apply_pack_func:
raise ValueError("No function provided to apply to the resource pack.")
def _apply_func_to_resource(
resource: Resource,
) -> Union[Resource, List[Resource], None]:
if resource.is_pack:
if apply_pack_func is not None:
return apply_pack_func(resource)
resources = []
resource_copy = cast(ResourcePack, copy.copy(resource))
for resource_copy in resource_copy.sub_resources:
@ -149,7 +160,10 @@ class ResourcePack(Resource[PackResourceParameters]):
resource.name: resource for resource in resources
}
else:
return apply_func(resource)
if apply_func is not None:
return apply_func(resource)
else:
return resource
new_resource = _apply_func_to_resource(self)
resource_copy = cast(ResourcePack, copy.copy(self))

View File

@ -278,12 +278,48 @@ class AutoGPTPluginToolPack(ToolPack):
class MCPToolPack(ToolPack):
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
"""MCP tool pack class.
Wrap the MCP SSE server as a tool pack.
Example:
.. code-block:: python
tools = MCPToolPack("http://127.0.0.1:8000/sse")
If you want to pass the token to the server, you can use the headers parameter:
.. code-block:: python
tools = MCPToolPack(
"http://127.0.0.1:8000/sse"
default_headers={"Authorization": "Bearer your_token"}
)
# Set the default headers for ech server
tools2 = MCPToolPack(
"http://127.0.0.1:8000/sse"
headers = {
"http://127.0.0.1:8000/sse": {
"Authorization": "Bearer your_token"
}
}
)
"""
def __init__(
self,
mcp_servers: Union[str, List[str]],
headers: Optional[Dict[str, Dict[str, Any]]] = None,
default_headers: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""Create an Auto-GPT plugin tool pack."""
super().__init__([], **kwargs)
self._mcp_servers = mcp_servers
self._loaded = False
self.tool_server_map = {}
self._default_headers = default_headers or {}
self._headers_map = headers or {}
self.server_headers_map = {}
def switch_mcp_input_schema(self, input_schema: dict):
args = {}
@ -324,7 +360,10 @@ class MCPToolPack(ToolPack):
server_list = self._mcp_servers.split(";")
for server in server_list:
async with sse_client(url=server) as (read, write):
server_headers = self._headers_map.get(server, self._default_headers)
self.server_headers_map[server] = server_headers
async with sse_client(url=server, headers=server_headers) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
@ -338,7 +377,10 @@ class MCPToolPack(ToolPack):
tool_name=tool_name, server=server, **kwargs
):
try:
async with sse_client(url=server) as (read, write):
headers_to_use = self.server_headers_map.get(server, {})
async with sse_client(
url=server, headers=headers_to_use
) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()

View File

@ -49,7 +49,27 @@ class MCPPackResourceParameters(PackResourceParameters):
class MCPSSEToolPack(MCPToolPack):
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
super().__init__(mcp_servers=mcp_servers, **kwargs)
"""Initialize the MCPSSEToolPack with the given MCP servers."""
headers = {}
if "token" in kwargs and kwargs["token"]:
# token is not supported in sse mode
servers = (
mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers
)
tokens = (
kwargs["token"].split(";")
if isinstance(kwargs["token"], str)
else kwargs["token"]
)
if len(servers) == len(tokens):
for i, token in enumerate(tokens):
headers[servers[i]] = {"Authorization": f"Bearer {token}"}
else:
token = tokens[0]
for server in servers:
headers[server] = {"Authorization": f"Bearer {token}"}
kwargs.pop("token")
super().__init__(mcp_servers=mcp_servers, headers=headers, **kwargs)
@classmethod
def type_alias(cls) -> str:
@ -64,7 +84,17 @@ class MCPSSEToolPack(MCPToolPack):
mcp_servers: str = dataclasses.field(
default="http://127.0.0.1:8000/sse",
metadata={
"help": _("MCP SSE Server URL, split by ':'"),
"help": _("MCP SSE Server URL, split by ';'"),
},
)
token: Optional[str] = dataclasses.field(
default=None,
metadata={
"help": _(
'MCP SSE Server token, split by ";", It will be '
'added to the header({"Authorization": "Bearer your_token"}'
),
"tags": "privacy",
},
)