From e9ce534ca1ed004b0ecf8e7db7e4b09589d68f5c Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Tue, 8 Apr 2025 08:47:41 +0800 Subject: [PATCH] feat(agent): Support SSL/TLS for MCP (#2591) --- .../agent/expand/actions/react_action.py | 2 +- .../dbgpt/agent/expand/actions/tool_action.py | 4 + .../src/dbgpt/agent/resource/tool/pack.py | 60 ++++++- .../src/dbgpt/agent/util/mcp_utils.py | 147 ++++++++++++++++++ .../src/dbgpt/model/proxy/llms/chatgpt.py | 6 +- .../src/dbgpt_serve/agent/resource/mcp.py | 53 ++++++- 6 files changed, 261 insertions(+), 11 deletions(-) create mode 100644 packages/dbgpt-core/src/dbgpt/agent/util/mcp_utils.py diff --git a/packages/dbgpt-core/src/dbgpt/agent/expand/actions/react_action.py b/packages/dbgpt-core/src/dbgpt/agent/expand/actions/react_action.py index 5ee73a314..566244633 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/expand/actions/react_action.py +++ b/packages/dbgpt-core/src/dbgpt/agent/expand/actions/react_action.py @@ -180,7 +180,7 @@ class ReActAction(ToolAction): # Try to parse the action input to dict if action_input and isinstance(action_input, str): tool_args = parse_or_raise_error(action_input) - elif isinstance(action_input, dict): + elif isinstance(action_input, dict) or isinstance(action_input, list): tool_args = action_input action_input_str = json.dumps(action_input, ensure_ascii=False) except json.JSONDecodeError: diff --git a/packages/dbgpt-core/src/dbgpt/agent/expand/actions/tool_action.py b/packages/dbgpt-core/src/dbgpt/agent/expand/actions/tool_action.py index 5937d183f..dc65cf6b9 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/expand/actions/tool_action.py +++ b/packages/dbgpt-core/src/dbgpt/agent/expand/actions/tool_action.py @@ -138,6 +138,10 @@ async def run_tool( if parsed_args and isinstance(parsed_args, tuple): args = parsed_args[1] + if args is not None and isinstance(args, list) and len(args) == 0: + # Input args is empty list, just use default args + args = {} + try: tool_result = await tool_pack.async_execute(resource_name=name, **args) status = Status.COMPLETE.value diff --git a/packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py b/packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py index e0ec99163..1cbb8a4c1 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py +++ b/packages/dbgpt-core/src/dbgpt/agent/resource/tool/pack.py @@ -2,13 +2,14 @@ import logging import os +import ssl from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, cast from mcp import ClientSession -from mcp.client.sse import sse_client from dbgpt.util.json_utils import parse_or_raise_error +from ...util.mcp_utils import sse_client from ..base import EXECUTE_ARGS_TYPE, PARSE_EXECUTE_ARGS_FUNCTION, ResourceType, T from ..pack import Resource, ResourcePack from .base import DB_GPT_TOOL_IDENTIFIER, BaseTool, FunctionTool, ToolFunc @@ -66,6 +67,8 @@ def json_parse_execute_args_func(input_str: str) -> Optional[EXECUTE_ARGS_TYPE]: # The position arguments is empty args = () kwargs = parse_or_raise_error(input_str) + if kwargs is not None and isinstance(kwargs, list) and len(kwargs) == 0: + kwargs = {} return args, kwargs @@ -303,6 +306,37 @@ class MCPToolPack(ToolPack): } } ) + + If you want to set the ssl verify, you can use the ssl_verify parameter: + .. code-block:: python + + # Default ssl_verify is True + tools = MCPToolPack( + "https://your_ssl_domain/sse", + ) + + # Set the default ssl_verify to False to disable ssl verify + tools2 = MCPToolPack( + "https://your_ssl_domain/sse", default_ssl_verify=False + ) + + # With Custom CA file + tools3 = MCPToolPack( + "https://your_ssl_domain/sse", default_ssl_cafile="/path/to/your/ca.crt" + ) + + # Set the ssl_verify for each server + import ssl + + tools4 = MCPToolPack( + "https://your_ssl_domain/sse", + ssl_verify={ + "https://your_ssl_domain/sse": ssl.create_default_context( + cafile="/path/to/your/ca.crt" + ), + }, + ) + """ def __init__( @@ -310,6 +344,9 @@ class MCPToolPack(ToolPack): mcp_servers: Union[str, List[str]], headers: Optional[Dict[str, Dict[str, Any]]] = None, default_headers: Optional[Dict[str, Any]] = None, + ssl_verify: Optional[Dict[str, Union[ssl.SSLContext, str, bool]]] = None, + default_ssl_verify: Union[ssl.SSLContext, str, bool] = True, + default_ssl_cafile: Optional[str] = None, **kwargs, ): """Create an Auto-GPT plugin tool pack.""" @@ -320,6 +357,12 @@ class MCPToolPack(ToolPack): self._default_headers = default_headers or {} self._headers_map = headers or {} self.server_headers_map = {} + if default_ssl_cafile and not ssl_verify and default_ssl_verify: + default_ssl_verify = ssl.create_default_context(cafile=default_ssl_cafile) + + self._default_ssl_verify = default_ssl_verify + self._ssl_verify_map = ssl_verify or {} + self.server_ssl_verify_map = {} def switch_mcp_input_schema(self, input_schema: dict): args = {} @@ -362,8 +405,14 @@ class MCPToolPack(ToolPack): for server in server_list: server_headers = self._headers_map.get(server, self._default_headers) self.server_headers_map[server] = server_headers + server_ssl_verify = self._ssl_verify_map.get( + server, self._default_ssl_verify + ) + self.server_ssl_verify_map[server] = server_ssl_verify - async with sse_client(url=server, headers=server_headers) as (read, write): + async with sse_client( + url=server, headers=server_headers, verify=server_ssl_verify + ) as (read, write): async with ClientSession(read, write) as session: # Initialize the connection await session.initialize() @@ -378,8 +427,13 @@ class MCPToolPack(ToolPack): ): try: headers_to_use = self.server_headers_map.get(server, {}) + ssl_verify_to_use = self.server_ssl_verify_map.get( + server, True + ) async with sse_client( - url=server, headers=headers_to_use + url=server, + headers=headers_to_use, + verify=ssl_verify_to_use, ) as (read, write): async with ClientSession(read, write) as session: # Initialize the connection diff --git a/packages/dbgpt-core/src/dbgpt/agent/util/mcp_utils.py b/packages/dbgpt-core/src/dbgpt/agent/util/mcp_utils.py new file mode 100644 index 000000000..8f7a96297 --- /dev/null +++ b/packages/dbgpt-core/src/dbgpt/agent/util/mcp_utils.py @@ -0,0 +1,147 @@ +import logging +import ssl +from contextlib import asynccontextmanager +from typing import Any +from urllib.parse import urljoin, urlparse + +import anyio +import httpx +import mcp.types as types +from anyio.abc import TaskStatus +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from httpx_sse import aconnect_sse + +logger = logging.getLogger(__name__) + + +def remove_request_params(url: str) -> str: + return urljoin(url, urlparse(url).path) + + +@asynccontextmanager +async def sse_client( + url: str, + headers: dict[str, Any] | None = None, + timeout: float = 5, + sse_read_timeout: float = 60 * 5, + verify: ssl.SSLContext | str | bool = True, +): + """ + Client transport for SSE. + + `sse_read_timeout` determines how long (in seconds) the client will wait for a new + event before disconnecting. All other HTTP operations are controlled by `timeout`. + """ + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + async with anyio.create_task_group() as tg: + try: + logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") + async with httpx.AsyncClient(headers=headers, verify=verify) as client: + async with aconnect_sse( + client, + "GET", + url, + timeout=httpx.Timeout(timeout, read=sse_read_timeout), + ) as event_source: + event_source.response.raise_for_status() + logger.debug("SSE connection established") + + async def sse_reader( + task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, + ): + try: + async for sse in event_source.aiter_sse(): + logger.debug(f"Received SSE event: {sse.event}") + match sse.event: + case "endpoint": + endpoint_url = urljoin(url, sse.data) + logger.info( + f"Received endpoint URL: {endpoint_url}" + ) + + url_parsed = urlparse(url) + endpoint_parsed = urlparse(endpoint_url) + if ( + url_parsed.netloc != endpoint_parsed.netloc + or url_parsed.scheme + != endpoint_parsed.scheme + ): + error_msg = ( + "Endpoint origin does not match " + f"connection origin: {endpoint_url}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + + task_status.started(endpoint_url) + + case "message": + try: + message = types.JSONRPCMessage.model_validate_json( # noqa: E501 + sse.data + ) + logger.debug( + f"Received server message: {message}" + ) + except Exception as exc: + logger.error( + f"Error parsing server message: {exc}" + ) + await read_stream_writer.send(exc) + continue + + await read_stream_writer.send(message) + case _: + logger.warning( + f"Unknown SSE event: {sse.event}" + ) + except Exception as exc: + logger.error(f"Error in sse_reader: {exc}") + await read_stream_writer.send(exc) + finally: + await read_stream_writer.aclose() + + async def post_writer(endpoint_url: str): + try: + async with write_stream_reader: + async for message in write_stream_reader: + logger.debug(f"Sending client message: {message}") + response = await client.post( + endpoint_url, + json=message.model_dump( + by_alias=True, + mode="json", + exclude_none=True, + ), + ) + response.raise_for_status() + logger.debug( + "Client message sent successfully: " + f"{response.status_code}" + ) + except Exception as exc: + logger.error(f"Error in post_writer: {exc}") + finally: + await write_stream.aclose() + + endpoint_url = await tg.start(sse_reader) + logger.info( + f"Starting post writer with endpoint URL: {endpoint_url}" + ) + tg.start_soon(post_writer, endpoint_url) + + try: + yield read_stream, write_stream + finally: + tg.cancel_scope.cancel() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/packages/dbgpt-core/src/dbgpt/model/proxy/llms/chatgpt.py b/packages/dbgpt-core/src/dbgpt/model/proxy/llms/chatgpt.py index a7cfe57ee..c6ee32683 100755 --- a/packages/dbgpt-core/src/dbgpt/model/proxy/llms/chatgpt.py +++ b/packages/dbgpt-core/src/dbgpt/model/proxy/llms/chatgpt.py @@ -23,7 +23,7 @@ from dbgpt.model.utils.chatgpt_utils import OpenAIParameters from dbgpt.util.i18n_utils import _ if TYPE_CHECKING: - from httpx._types import ProxiesTypes + from httpx._types import ProxiesTypes, ProxyTypes from openai import AsyncAzureOpenAI, AsyncOpenAI ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI] @@ -139,6 +139,7 @@ class OpenAILLMClient(ProxyLLMClient): api_version: Optional[str] = None, model: Optional[str] = None, proxies: Optional["ProxiesTypes"] = None, + proxy: Optional["ProxyTypes"] = None, timeout: Optional[int] = 240, model_alias: Optional[str] = "gpt-4o-mini", context_length: Optional[int] = 8192, @@ -160,6 +161,7 @@ class OpenAILLMClient(ProxyLLMClient): api_key=self._resolve_env_vars(api_key), api_version=self._resolve_env_vars(api_version), proxies=proxies, + proxy=proxy, full_url=kwargs.get("full_url"), ) @@ -203,7 +205,7 @@ class OpenAILLMClient(ProxyLLMClient): api_type=model_params.api_type, api_version=model_params.api_version, model=model_params.real_provider_model_name, - proxies=model_params.http_proxy, + proxy=model_params.http_proxy, model_alias=model_params.real_provider_model_name, context_length=max(model_params.context_length or 8192, 8192), # full_url=model_params.proxy_server_url, diff --git a/packages/dbgpt-serve/src/dbgpt_serve/agent/resource/mcp.py b/packages/dbgpt-serve/src/dbgpt_serve/agent/resource/mcp.py index 453982fb7..418178c92 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/agent/resource/mcp.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/agent/resource/mcp.py @@ -50,12 +50,14 @@ class MCPPackResourceParameters(PackResourceParameters): class MCPSSEToolPack(MCPToolPack): def __init__(self, mcp_servers: Union[str, List[str]], **kwargs): """Initialize the MCPSSEToolPack with the given MCP servers.""" + import ssl + headers = {} + # token is not supported in sse mode + servers = ( + mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers + ) if "token" in kwargs and kwargs["token"]: - # token is not supported in sse mode - servers = ( - mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers - ) tokens = ( kwargs["token"].split(";") if isinstance(kwargs["token"], str) @@ -69,7 +71,33 @@ class MCPSSEToolPack(MCPToolPack): for server in servers: headers[server] = {"Authorization": f"Bearer {token}"} kwargs.pop("token") - super().__init__(mcp_servers=mcp_servers, headers=headers, **kwargs) + ssl_verify = True + ssl_verify_map = {} + if "no_ssl_verify" in kwargs: + if kwargs["no_ssl_verify"] is True: + ssl_verify = False + kwargs.pop("no_ssl_verify") + if ssl_verify is True and "ssl_ca_cert" in kwargs: + ssl_ca_certs = ( + kwargs["ssl_ca_cert"].split(";") + if isinstance(kwargs["ssl_ca_cert"], str) + else kwargs["ssl_ca_cert"] + ) + if len(servers) == len(ssl_ca_certs): + for i, ssl_ca_cert in enumerate(ssl_ca_certs): + ssl_verify_map[servers[i]] = ssl.create_default_context( + cafile=ssl_ca_cert + ) + else: + ssl_ca_cert = ssl_ca_certs[0] + for server in servers: + ssl_verify_map[server] = ssl.create_default_context( + cafile=ssl_ca_cert + ) + verify = ssl_verify_map if ssl_verify_map else ssl_verify + super().__init__( + mcp_servers=mcp_servers, headers=headers, ssl_verify=verify, **kwargs + ) @classmethod def type_alias(cls) -> str: @@ -97,5 +125,20 @@ class MCPSSEToolPack(MCPToolPack): "tags": "privacy", }, ) + no_ssl_verify: bool = dataclasses.field( + default=False, + metadata={ + "help": _( + "Disable SSL verification. " + "This is not recommended for production use." + ), + }, + ) + ssl_ca_cert: Optional[str] = dataclasses.field( + default=None, + metadata={ + "help": _("Path to the CA certificate file. split by ';' "), + }, + ) return _DynMCPSSEPackResourceParameters