feat(agent): Support SSL/TLS for MCP (#2591)

This commit is contained in:
Fangyin Cheng 2025-04-08 08:47:41 +08:00 committed by GitHub
parent 0fd578cf87
commit e9ce534ca1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 261 additions and 11 deletions

View File

@ -180,7 +180,7 @@ class ReActAction(ToolAction):
# Try to parse the action input to dict
if action_input and isinstance(action_input, str):
tool_args = parse_or_raise_error(action_input)
elif isinstance(action_input, dict):
elif isinstance(action_input, dict) or isinstance(action_input, list):
tool_args = action_input
action_input_str = json.dumps(action_input, ensure_ascii=False)
except json.JSONDecodeError:

View File

@ -138,6 +138,10 @@ async def run_tool(
if parsed_args and isinstance(parsed_args, tuple):
args = parsed_args[1]
if args is not None and isinstance(args, list) and len(args) == 0:
# Input args is empty list, just use default args
args = {}
try:
tool_result = await tool_pack.async_execute(resource_name=name, **args)
status = Status.COMPLETE.value

View File

@ -2,13 +2,14 @@
import logging
import os
import ssl
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, cast
from mcp import ClientSession
from mcp.client.sse import sse_client
from dbgpt.util.json_utils import parse_or_raise_error
from ...util.mcp_utils import sse_client
from ..base import EXECUTE_ARGS_TYPE, PARSE_EXECUTE_ARGS_FUNCTION, ResourceType, T
from ..pack import Resource, ResourcePack
from .base import DB_GPT_TOOL_IDENTIFIER, BaseTool, FunctionTool, ToolFunc
@ -66,6 +67,8 @@ def json_parse_execute_args_func(input_str: str) -> Optional[EXECUTE_ARGS_TYPE]:
# The position arguments is empty
args = ()
kwargs = parse_or_raise_error(input_str)
if kwargs is not None and isinstance(kwargs, list) and len(kwargs) == 0:
kwargs = {}
return args, kwargs
@ -303,6 +306,37 @@ class MCPToolPack(ToolPack):
}
}
)
If you want to set the ssl verify, you can use the ssl_verify parameter:
.. code-block:: python
# Default ssl_verify is True
tools = MCPToolPack(
"https://your_ssl_domain/sse",
)
# Set the default ssl_verify to False to disable ssl verify
tools2 = MCPToolPack(
"https://your_ssl_domain/sse", default_ssl_verify=False
)
# With Custom CA file
tools3 = MCPToolPack(
"https://your_ssl_domain/sse", default_ssl_cafile="/path/to/your/ca.crt"
)
# Set the ssl_verify for each server
import ssl
tools4 = MCPToolPack(
"https://your_ssl_domain/sse",
ssl_verify={
"https://your_ssl_domain/sse": ssl.create_default_context(
cafile="/path/to/your/ca.crt"
),
},
)
"""
def __init__(
@ -310,6 +344,9 @@ class MCPToolPack(ToolPack):
mcp_servers: Union[str, List[str]],
headers: Optional[Dict[str, Dict[str, Any]]] = None,
default_headers: Optional[Dict[str, Any]] = None,
ssl_verify: Optional[Dict[str, Union[ssl.SSLContext, str, bool]]] = None,
default_ssl_verify: Union[ssl.SSLContext, str, bool] = True,
default_ssl_cafile: Optional[str] = None,
**kwargs,
):
"""Create an Auto-GPT plugin tool pack."""
@ -320,6 +357,12 @@ class MCPToolPack(ToolPack):
self._default_headers = default_headers or {}
self._headers_map = headers or {}
self.server_headers_map = {}
if default_ssl_cafile and not ssl_verify and default_ssl_verify:
default_ssl_verify = ssl.create_default_context(cafile=default_ssl_cafile)
self._default_ssl_verify = default_ssl_verify
self._ssl_verify_map = ssl_verify or {}
self.server_ssl_verify_map = {}
def switch_mcp_input_schema(self, input_schema: dict):
args = {}
@ -362,8 +405,14 @@ class MCPToolPack(ToolPack):
for server in server_list:
server_headers = self._headers_map.get(server, self._default_headers)
self.server_headers_map[server] = server_headers
server_ssl_verify = self._ssl_verify_map.get(
server, self._default_ssl_verify
)
self.server_ssl_verify_map[server] = server_ssl_verify
async with sse_client(url=server, headers=server_headers) as (read, write):
async with sse_client(
url=server, headers=server_headers, verify=server_ssl_verify
) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
@ -378,8 +427,13 @@ class MCPToolPack(ToolPack):
):
try:
headers_to_use = self.server_headers_map.get(server, {})
ssl_verify_to_use = self.server_ssl_verify_map.get(
server, True
)
async with sse_client(
url=server, headers=headers_to_use
url=server,
headers=headers_to_use,
verify=ssl_verify_to_use,
) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the connection

View File

@ -0,0 +1,147 @@
import logging
import ssl
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urljoin, urlparse
import anyio
import httpx
import mcp.types as types
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
logger = logging.getLogger(__name__)
def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)
@asynccontextmanager
async def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
verify: ssl.SSLContext | str | bool = True,
):
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
async with anyio.create_task_group() as tg:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx.AsyncClient(headers=headers, verify=verify) as client:
async with aconnect_sse(
client,
"GET",
url,
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
async def sse_reader(
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
try:
async for sse in event_source.aiter_sse():
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(
f"Received endpoint URL: {endpoint_url}"
)
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme
!= endpoint_parsed.scheme
):
error_msg = (
"Endpoint origin does not match "
f"connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)
task_status.started(endpoint_url)
case "message":
try:
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
)
logger.debug(
f"Received server message: {message}"
)
except Exception as exc:
logger.error(
f"Error parsing server message: {exc}"
)
await read_stream_writer.send(exc)
continue
await read_stream_writer.send(message)
case _:
logger.warning(
f"Unknown SSE event: {sse.event}"
)
except Exception as exc:
logger.error(f"Error in sse_reader: {exc}")
await read_stream_writer.send(exc)
finally:
await read_stream_writer.aclose()
async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for message in write_stream_reader:
logger.debug(f"Sending client message: {message}")
response = await client.post(
endpoint_url,
json=message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(
"Client message sent successfully: "
f"{response.status_code}"
)
except Exception as exc:
logger.error(f"Error in post_writer: {exc}")
finally:
await write_stream.aclose()
endpoint_url = await tg.start(sse_reader)
logger.info(
f"Starting post writer with endpoint URL: {endpoint_url}"
)
tg.start_soon(post_writer, endpoint_url)
try:
yield read_stream, write_stream
finally:
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()

View File

@ -23,7 +23,7 @@ from dbgpt.model.utils.chatgpt_utils import OpenAIParameters
from dbgpt.util.i18n_utils import _
if TYPE_CHECKING:
from httpx._types import ProxiesTypes
from httpx._types import ProxiesTypes, ProxyTypes
from openai import AsyncAzureOpenAI, AsyncOpenAI
ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]
@ -139,6 +139,7 @@ class OpenAILLMClient(ProxyLLMClient):
api_version: Optional[str] = None,
model: Optional[str] = None,
proxies: Optional["ProxiesTypes"] = None,
proxy: Optional["ProxyTypes"] = None,
timeout: Optional[int] = 240,
model_alias: Optional[str] = "gpt-4o-mini",
context_length: Optional[int] = 8192,
@ -160,6 +161,7 @@ class OpenAILLMClient(ProxyLLMClient):
api_key=self._resolve_env_vars(api_key),
api_version=self._resolve_env_vars(api_version),
proxies=proxies,
proxy=proxy,
full_url=kwargs.get("full_url"),
)
@ -203,7 +205,7 @@ class OpenAILLMClient(ProxyLLMClient):
api_type=model_params.api_type,
api_version=model_params.api_version,
model=model_params.real_provider_model_name,
proxies=model_params.http_proxy,
proxy=model_params.http_proxy,
model_alias=model_params.real_provider_model_name,
context_length=max(model_params.context_length or 8192, 8192),
# full_url=model_params.proxy_server_url,

View File

@ -50,12 +50,14 @@ class MCPPackResourceParameters(PackResourceParameters):
class MCPSSEToolPack(MCPToolPack):
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
"""Initialize the MCPSSEToolPack with the given MCP servers."""
import ssl
headers = {}
# token is not supported in sse mode
servers = (
mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers
)
if "token" in kwargs and kwargs["token"]:
# token is not supported in sse mode
servers = (
mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers
)
tokens = (
kwargs["token"].split(";")
if isinstance(kwargs["token"], str)
@ -69,7 +71,33 @@ class MCPSSEToolPack(MCPToolPack):
for server in servers:
headers[server] = {"Authorization": f"Bearer {token}"}
kwargs.pop("token")
super().__init__(mcp_servers=mcp_servers, headers=headers, **kwargs)
ssl_verify = True
ssl_verify_map = {}
if "no_ssl_verify" in kwargs:
if kwargs["no_ssl_verify"] is True:
ssl_verify = False
kwargs.pop("no_ssl_verify")
if ssl_verify is True and "ssl_ca_cert" in kwargs:
ssl_ca_certs = (
kwargs["ssl_ca_cert"].split(";")
if isinstance(kwargs["ssl_ca_cert"], str)
else kwargs["ssl_ca_cert"]
)
if len(servers) == len(ssl_ca_certs):
for i, ssl_ca_cert in enumerate(ssl_ca_certs):
ssl_verify_map[servers[i]] = ssl.create_default_context(
cafile=ssl_ca_cert
)
else:
ssl_ca_cert = ssl_ca_certs[0]
for server in servers:
ssl_verify_map[server] = ssl.create_default_context(
cafile=ssl_ca_cert
)
verify = ssl_verify_map if ssl_verify_map else ssl_verify
super().__init__(
mcp_servers=mcp_servers, headers=headers, ssl_verify=verify, **kwargs
)
@classmethod
def type_alias(cls) -> str:
@ -97,5 +125,20 @@ class MCPSSEToolPack(MCPToolPack):
"tags": "privacy",
},
)
no_ssl_verify: bool = dataclasses.field(
default=False,
metadata={
"help": _(
"Disable SSL verification. "
"This is not recommended for production use."
),
},
)
ssl_ca_cert: Optional[str] = dataclasses.field(
default=None,
metadata={
"help": _("Path to the CA certificate file. split by ';' "),
},
)
return _DynMCPSSEPackResourceParameters