From 2e0e140b51b7f069912c40f13dfaf35175f1517c Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 3 Apr 2025 09:21:05 +0800 Subject: [PATCH] feat(agent): Support MCP authentication --- .../src/dbgpt/agent/core/base_agent.py | 2 +- .../agent/expand/actions/react_action.py | 10 ++++ .../src/dbgpt/agent/expand/react_agent.py | 36 ++++++++++++-- .../src/dbgpt/agent/resource/base.py | 7 ++- .../src/dbgpt/agent/resource/pack.py | 18 ++++++- .../src/dbgpt/agent/resource/tool/pack.py | 48 +++++++++++++++++-- .../src/dbgpt_serve/agent/resource/mcp.py | 34 ++++++++++++- 7 files changed, 143 insertions(+), 12 deletions(-) diff --git a/packages/dbgpt-core/src/dbgpt/agent/core/base_agent.py b/packages/dbgpt-core/src/dbgpt/agent/core/base_agent.py index be5c34872..cafa6315f 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/core/base_agent.py +++ b/packages/dbgpt-core/src/dbgpt/agent/core/base_agent.py @@ -139,7 +139,7 @@ class ConversableAgent(Role, Agent): async def preload_resource(self) -> None: """Preload resources before agent initialization.""" if self.resource: - await self.blocking_func_to_async(self.resource.preload_resource) + await self.resource.preload_resource() async def build(self, is_retry_chat: bool = False) -> "ConversableAgent": """Build the agent.""" diff --git a/packages/dbgpt-core/src/dbgpt/agent/expand/actions/react_action.py b/packages/dbgpt-core/src/dbgpt/agent/expand/actions/react_action.py index b66de5380..5ee73a314 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/expand/actions/react_action.py +++ b/packages/dbgpt-core/src/dbgpt/agent/expand/actions/react_action.py @@ -166,6 +166,16 @@ class ReActAction(ToolAction): name = parsed_step.action action_input = parsed_step.action_input action_input_str = action_input + + if not name: + terminal_content = str(action_input_str if action_input_str else ai_message) + return ActionOutput( + is_exe_success=True, + content=terminal_content, + observations=terminal_content, + terminate=True, + ) + try: # Try to parse the action input to dict if action_input and isinstance(action_input, str): diff --git a/packages/dbgpt-core/src/dbgpt/agent/expand/react_agent.py b/packages/dbgpt-core/src/dbgpt/agent/expand/react_agent.py index 284366582..1135a4694 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/expand/react_agent.py +++ b/packages/dbgpt-core/src/dbgpt/agent/expand/react_agent.py @@ -13,11 +13,11 @@ from dbgpt.agent import ( ResourceType, ) from dbgpt.agent.core.role import AgentRunMode -from dbgpt.agent.resource import BaseTool, ToolPack +from dbgpt.agent.resource import BaseTool, ResourcePack, ToolPack from dbgpt.agent.util.react_parser import ReActOutputParser from dbgpt.util.configure import DynConfig -from .actions.react_action import ReActAction +from .actions.react_action import ReActAction, Terminate logger = logging.getLogger(__name__) @@ -113,7 +113,7 @@ class ReActAgent(ConversableAgent): """Init indicator AssistantAgent.""" super().__init__(**kwargs) - self._init_actions([ReActAction]) + self._init_actions([ReActAction, Terminate]) async def _a_init_reply_message( self, @@ -150,6 +150,36 @@ class ReActAgent(ConversableAgent): } return reply_message + async def preload_resource(self) -> None: + await super().preload_resource() + self._check_and_add_terminate() + + def _check_and_add_terminate(self): + if not self.resource: + return + _is_has_terminal = False + + def _has_terminal(r: Resource): + nonlocal _is_has_terminal + if r.type() == ResourceType.Tool and isinstance(r, Terminate): + _is_has_terminal = True + return r + + _has_add_terminal = False + + def _add_terminate(r: Resource): + nonlocal _has_add_terminal + if not _has_add_terminal and isinstance(r, ResourcePack): + terminal = Terminate() + r._resources[terminal.name] = terminal + _has_add_terminal = True + return r + + self.resource.apply(apply_func=_has_terminal) + if not _is_has_terminal: + # Add terminal action to the resource + self.resource.apply(apply_pack_func=_add_terminate) + async def load_resource(self, question: str, is_retry_chat: bool = False): """Load agent bind resource.""" if self.resource: diff --git a/packages/dbgpt-core/src/dbgpt/agent/resource/base.py b/packages/dbgpt-core/src/dbgpt/agent/resource/base.py index ab06af33b..29715c070 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/resource/base.py +++ b/packages/dbgpt-core/src/dbgpt/agent/resource/base.py @@ -272,7 +272,12 @@ class Resource(ABC, Generic[P]): def apply( self, - apply_func: Callable[["Resource"], Union["Resource", List["Resource"], None]], + apply_func: Optional[ + Callable[["Resource"], Union["Resource", List["Resource"], None]] + ] = None, + apply_pack_func: Optional[ + Callable[["Resource"], Union["Resource", None]] + ] = None, ) -> Union["Resource", None]: """Apply the function to the resource.""" return self diff --git a/packages/dbgpt-core/src/dbgpt/agent/resource/pack.py b/packages/dbgpt-core/src/dbgpt/agent/resource/pack.py index f79c1d954..fa243d39f 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/resource/pack.py +++ b/packages/dbgpt-core/src/dbgpt/agent/resource/pack.py @@ -125,16 +125,27 @@ class ResourcePack(Resource[PackResourceParameters]): return list(self._resources.values()) def apply( - self, apply_func: Callable[[Resource], Union[Resource, List[Resource], None]] + self, + apply_func: Optional[ + Callable[[Resource], Union[Resource, List[Resource], None]] + ] = None, + apply_pack_func: Optional[ + Callable[["Resource"], Union["Resource", None]] + ] = None, ) -> Union[Resource, None]: """Apply the function to the resource.""" if not self.is_pack: return self + if not apply_func and not apply_pack_func: + raise ValueError("No function provided to apply to the resource pack.") + def _apply_func_to_resource( resource: Resource, ) -> Union[Resource, List[Resource], None]: if resource.is_pack: + if apply_pack_func is not None: + return apply_pack_func(resource) resources = [] resource_copy = cast(ResourcePack, copy.copy(resource)) for resource_copy in resource_copy.sub_resources: @@ -149,7 +160,10 @@ class ResourcePack(Resource[PackResourceParameters]): resource.name: resource for resource in resources } else: - return apply_func(resource) + if apply_func is not None: + return apply_func(resource) + else: + return resource new_resource = _apply_func_to_resource(self) resource_copy = cast(ResourcePack, copy.copy(self)) diff --git a/packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py b/packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py index 6ecd92229..e0ec99163 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py +++ b/packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py @@ -278,12 +278,48 @@ class AutoGPTPluginToolPack(ToolPack): class MCPToolPack(ToolPack): - def __init__(self, mcp_servers: Union[str, List[str]], **kwargs): + """MCP tool pack class. + + Wrap the MCP SSE server as a tool pack. + + Example: + .. code-block:: python + + tools = MCPToolPack("http://127.0.0.1:8000/sse") + + If you want to pass the token to the server, you can use the headers parameter: + .. code-block:: python + + tools = MCPToolPack( + "http://127.0.0.1:8000/sse" + default_headers={"Authorization": "Bearer your_token"} + ) + # Set the default headers for ech server + tools2 = MCPToolPack( + "http://127.0.0.1:8000/sse" + headers = { + "http://127.0.0.1:8000/sse": { + "Authorization": "Bearer your_token" + } + } + ) + """ + + def __init__( + self, + mcp_servers: Union[str, List[str]], + headers: Optional[Dict[str, Dict[str, Any]]] = None, + default_headers: Optional[Dict[str, Any]] = None, + **kwargs, + ): """Create an Auto-GPT plugin tool pack.""" super().__init__([], **kwargs) self._mcp_servers = mcp_servers self._loaded = False self.tool_server_map = {} + self._default_headers = default_headers or {} + self._headers_map = headers or {} + self.server_headers_map = {} def switch_mcp_input_schema(self, input_schema: dict): args = {} @@ -324,7 +360,10 @@ class MCPToolPack(ToolPack): server_list = self._mcp_servers.split(";") for server in server_list: - async with sse_client(url=server) as (read, write): + server_headers = self._headers_map.get(server, self._default_headers) + self.server_headers_map[server] = server_headers + + async with sse_client(url=server, headers=server_headers) as (read, write): async with ClientSession(read, write) as session: # Initialize the connection await session.initialize() @@ -338,7 +377,10 @@ class MCPToolPack(ToolPack): tool_name=tool_name, server=server, **kwargs ): try: - async with sse_client(url=server) as (read, write): + headers_to_use = self.server_headers_map.get(server, {}) + async with sse_client( + url=server, headers=headers_to_use + ) as (read, write): async with ClientSession(read, write) as session: # Initialize the connection await session.initialize() diff --git a/packages/dbgpt-serve/src/dbgpt_serve/agent/resource/mcp.py b/packages/dbgpt-serve/src/dbgpt_serve/agent/resource/mcp.py index 3f06ccecd..453982fb7 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/agent/resource/mcp.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/agent/resource/mcp.py @@ -49,7 +49,27 @@ class MCPPackResourceParameters(PackResourceParameters): class MCPSSEToolPack(MCPToolPack): def __init__(self, mcp_servers: Union[str, List[str]], **kwargs): - super().__init__(mcp_servers=mcp_servers, **kwargs) + """Initialize the MCPSSEToolPack with the given MCP servers.""" + headers = {} + if "token" in kwargs and kwargs["token"]: + # token is not supported in sse mode + servers = ( + mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers + ) + tokens = ( + kwargs["token"].split(";") + if isinstance(kwargs["token"], str) + else kwargs["token"] + ) + if len(servers) == len(tokens): + for i, token in enumerate(tokens): + headers[servers[i]] = {"Authorization": f"Bearer {token}"} + else: + token = tokens[0] + for server in servers: + headers[server] = {"Authorization": f"Bearer {token}"} + kwargs.pop("token") + super().__init__(mcp_servers=mcp_servers, headers=headers, **kwargs) @classmethod def type_alias(cls) -> str: @@ -64,7 +84,17 @@ class MCPSSEToolPack(MCPToolPack): mcp_servers: str = dataclasses.field( default="http://127.0.0.1:8000/sse", metadata={ - "help": _("MCP SSE Server URL, split by ':'"), + "help": _("MCP SSE Server URL, split by ';'"), + }, + ) + token: Optional[str] = dataclasses.field( + default=None, + metadata={ + "help": _( + 'MCP SSE Server token, split by ";", It will be ' + 'added to the header({"Authorization": "Bearer your_token"}' + ), + "tags": "privacy", }, )