mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-22 20:01:46 +00:00
feat(agent): Support SSL/TLS for MCP (#2591)
This commit is contained in:
parent
0fd578cf87
commit
e9ce534ca1
@ -180,7 +180,7 @@ class ReActAction(ToolAction):
|
||||
# Try to parse the action input to dict
|
||||
if action_input and isinstance(action_input, str):
|
||||
tool_args = parse_or_raise_error(action_input)
|
||||
elif isinstance(action_input, dict):
|
||||
elif isinstance(action_input, dict) or isinstance(action_input, list):
|
||||
tool_args = action_input
|
||||
action_input_str = json.dumps(action_input, ensure_ascii=False)
|
||||
except json.JSONDecodeError:
|
||||
|
@ -138,6 +138,10 @@ async def run_tool(
|
||||
if parsed_args and isinstance(parsed_args, tuple):
|
||||
args = parsed_args[1]
|
||||
|
||||
if args is not None and isinstance(args, list) and len(args) == 0:
|
||||
# Input args is empty list, just use default args
|
||||
args = {}
|
||||
|
||||
try:
|
||||
tool_result = await tool_pack.async_execute(resource_name=name, **args)
|
||||
status = Status.COMPLETE.value
|
||||
|
@ -2,13 +2,14 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, cast
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from dbgpt.util.json_utils import parse_or_raise_error
|
||||
|
||||
from ...util.mcp_utils import sse_client
|
||||
from ..base import EXECUTE_ARGS_TYPE, PARSE_EXECUTE_ARGS_FUNCTION, ResourceType, T
|
||||
from ..pack import Resource, ResourcePack
|
||||
from .base import DB_GPT_TOOL_IDENTIFIER, BaseTool, FunctionTool, ToolFunc
|
||||
@ -66,6 +67,8 @@ def json_parse_execute_args_func(input_str: str) -> Optional[EXECUTE_ARGS_TYPE]:
|
||||
# The position arguments is empty
|
||||
args = ()
|
||||
kwargs = parse_or_raise_error(input_str)
|
||||
if kwargs is not None and isinstance(kwargs, list) and len(kwargs) == 0:
|
||||
kwargs = {}
|
||||
return args, kwargs
|
||||
|
||||
|
||||
@ -303,6 +306,37 @@ class MCPToolPack(ToolPack):
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
If you want to set the ssl verify, you can use the ssl_verify parameter:
|
||||
.. code-block:: python
|
||||
|
||||
# Default ssl_verify is True
|
||||
tools = MCPToolPack(
|
||||
"https://your_ssl_domain/sse",
|
||||
)
|
||||
|
||||
# Set the default ssl_verify to False to disable ssl verify
|
||||
tools2 = MCPToolPack(
|
||||
"https://your_ssl_domain/sse", default_ssl_verify=False
|
||||
)
|
||||
|
||||
# With Custom CA file
|
||||
tools3 = MCPToolPack(
|
||||
"https://your_ssl_domain/sse", default_ssl_cafile="/path/to/your/ca.crt"
|
||||
)
|
||||
|
||||
# Set the ssl_verify for each server
|
||||
import ssl
|
||||
|
||||
tools4 = MCPToolPack(
|
||||
"https://your_ssl_domain/sse",
|
||||
ssl_verify={
|
||||
"https://your_ssl_domain/sse": ssl.create_default_context(
|
||||
cafile="/path/to/your/ca.crt"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -310,6 +344,9 @@ class MCPToolPack(ToolPack):
|
||||
mcp_servers: Union[str, List[str]],
|
||||
headers: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
default_headers: Optional[Dict[str, Any]] = None,
|
||||
ssl_verify: Optional[Dict[str, Union[ssl.SSLContext, str, bool]]] = None,
|
||||
default_ssl_verify: Union[ssl.SSLContext, str, bool] = True,
|
||||
default_ssl_cafile: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create an Auto-GPT plugin tool pack."""
|
||||
@ -320,6 +357,12 @@ class MCPToolPack(ToolPack):
|
||||
self._default_headers = default_headers or {}
|
||||
self._headers_map = headers or {}
|
||||
self.server_headers_map = {}
|
||||
if default_ssl_cafile and not ssl_verify and default_ssl_verify:
|
||||
default_ssl_verify = ssl.create_default_context(cafile=default_ssl_cafile)
|
||||
|
||||
self._default_ssl_verify = default_ssl_verify
|
||||
self._ssl_verify_map = ssl_verify or {}
|
||||
self.server_ssl_verify_map = {}
|
||||
|
||||
def switch_mcp_input_schema(self, input_schema: dict):
|
||||
args = {}
|
||||
@ -362,8 +405,14 @@ class MCPToolPack(ToolPack):
|
||||
for server in server_list:
|
||||
server_headers = self._headers_map.get(server, self._default_headers)
|
||||
self.server_headers_map[server] = server_headers
|
||||
server_ssl_verify = self._ssl_verify_map.get(
|
||||
server, self._default_ssl_verify
|
||||
)
|
||||
self.server_ssl_verify_map[server] = server_ssl_verify
|
||||
|
||||
async with sse_client(url=server, headers=server_headers) as (read, write):
|
||||
async with sse_client(
|
||||
url=server, headers=server_headers, verify=server_ssl_verify
|
||||
) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
@ -378,8 +427,13 @@ class MCPToolPack(ToolPack):
|
||||
):
|
||||
try:
|
||||
headers_to_use = self.server_headers_map.get(server, {})
|
||||
ssl_verify_to_use = self.server_ssl_verify_map.get(
|
||||
server, True
|
||||
)
|
||||
async with sse_client(
|
||||
url=server, headers=headers_to_use
|
||||
url=server,
|
||||
headers=headers_to_use,
|
||||
verify=ssl_verify_to_use,
|
||||
) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
|
147
packages/dbgpt-core/src/dbgpt/agent/util/mcp_utils.py
Normal file
147
packages/dbgpt-core/src/dbgpt/agent/util/mcp_utils.py
Normal file
@ -0,0 +1,147 @@
|
||||
import logging
|
||||
import ssl
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
import mcp.types as types
|
||||
from anyio.abc import TaskStatus
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import aconnect_sse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5,
|
||||
sse_read_timeout: float = 60 * 5,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
):
|
||||
"""
|
||||
Client transport for SSE.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
try:
|
||||
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||
async with httpx.AsyncClient(headers=headers, verify=verify) as client:
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
url,
|
||||
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("SSE connection established")
|
||||
|
||||
async def sse_reader(
|
||||
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
|
||||
):
|
||||
try:
|
||||
async for sse in event_source.aiter_sse():
|
||||
logger.debug(f"Received SSE event: {sse.event}")
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
endpoint_url = urljoin(url, sse.data)
|
||||
logger.info(
|
||||
f"Received endpoint URL: {endpoint_url}"
|
||||
)
|
||||
|
||||
url_parsed = urlparse(url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
if (
|
||||
url_parsed.netloc != endpoint_parsed.netloc
|
||||
or url_parsed.scheme
|
||||
!= endpoint_parsed.scheme
|
||||
):
|
||||
error_msg = (
|
||||
"Endpoint origin does not match "
|
||||
f"connection origin: {endpoint_url}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
task_status.started(endpoint_url)
|
||||
|
||||
case "message":
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
|
||||
sse.data
|
||||
)
|
||||
logger.debug(
|
||||
f"Received server message: {message}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error parsing server message: {exc}"
|
||||
)
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
case _:
|
||||
logger.warning(
|
||||
f"Unknown SSE event: {sse.event}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in sse_reader: {exc}")
|
||||
await read_stream_writer.send(exc)
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
|
||||
async def post_writer(endpoint_url: str):
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
response = await client.post(
|
||||
endpoint_url,
|
||||
json=message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(
|
||||
"Client message sent successfully: "
|
||||
f"{response.status_code}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in post_writer: {exc}")
|
||||
finally:
|
||||
await write_stream.aclose()
|
||||
|
||||
endpoint_url = await tg.start(sse_reader)
|
||||
logger.info(
|
||||
f"Starting post writer with endpoint URL: {endpoint_url}"
|
||||
)
|
||||
tg.start_soon(post_writer, endpoint_url)
|
||||
|
||||
try:
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
tg.cancel_scope.cancel()
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream.aclose()
|
@ -23,7 +23,7 @@ from dbgpt.model.utils.chatgpt_utils import OpenAIParameters
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from httpx._types import ProxiesTypes
|
||||
from httpx._types import ProxiesTypes, ProxyTypes
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
|
||||
ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]
|
||||
@ -139,6 +139,7 @@ class OpenAILLMClient(ProxyLLMClient):
|
||||
api_version: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
proxies: Optional["ProxiesTypes"] = None,
|
||||
proxy: Optional["ProxyTypes"] = None,
|
||||
timeout: Optional[int] = 240,
|
||||
model_alias: Optional[str] = "gpt-4o-mini",
|
||||
context_length: Optional[int] = 8192,
|
||||
@ -160,6 +161,7 @@ class OpenAILLMClient(ProxyLLMClient):
|
||||
api_key=self._resolve_env_vars(api_key),
|
||||
api_version=self._resolve_env_vars(api_version),
|
||||
proxies=proxies,
|
||||
proxy=proxy,
|
||||
full_url=kwargs.get("full_url"),
|
||||
)
|
||||
|
||||
@ -203,7 +205,7 @@ class OpenAILLMClient(ProxyLLMClient):
|
||||
api_type=model_params.api_type,
|
||||
api_version=model_params.api_version,
|
||||
model=model_params.real_provider_model_name,
|
||||
proxies=model_params.http_proxy,
|
||||
proxy=model_params.http_proxy,
|
||||
model_alias=model_params.real_provider_model_name,
|
||||
context_length=max(model_params.context_length or 8192, 8192),
|
||||
# full_url=model_params.proxy_server_url,
|
||||
|
@ -50,12 +50,14 @@ class MCPPackResourceParameters(PackResourceParameters):
|
||||
class MCPSSEToolPack(MCPToolPack):
|
||||
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
|
||||
"""Initialize the MCPSSEToolPack with the given MCP servers."""
|
||||
import ssl
|
||||
|
||||
headers = {}
|
||||
# token is not supported in sse mode
|
||||
servers = (
|
||||
mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers
|
||||
)
|
||||
if "token" in kwargs and kwargs["token"]:
|
||||
# token is not supported in sse mode
|
||||
servers = (
|
||||
mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers
|
||||
)
|
||||
tokens = (
|
||||
kwargs["token"].split(";")
|
||||
if isinstance(kwargs["token"], str)
|
||||
@ -69,7 +71,33 @@ class MCPSSEToolPack(MCPToolPack):
|
||||
for server in servers:
|
||||
headers[server] = {"Authorization": f"Bearer {token}"}
|
||||
kwargs.pop("token")
|
||||
super().__init__(mcp_servers=mcp_servers, headers=headers, **kwargs)
|
||||
ssl_verify = True
|
||||
ssl_verify_map = {}
|
||||
if "no_ssl_verify" in kwargs:
|
||||
if kwargs["no_ssl_verify"] is True:
|
||||
ssl_verify = False
|
||||
kwargs.pop("no_ssl_verify")
|
||||
if ssl_verify is True and "ssl_ca_cert" in kwargs:
|
||||
ssl_ca_certs = (
|
||||
kwargs["ssl_ca_cert"].split(";")
|
||||
if isinstance(kwargs["ssl_ca_cert"], str)
|
||||
else kwargs["ssl_ca_cert"]
|
||||
)
|
||||
if len(servers) == len(ssl_ca_certs):
|
||||
for i, ssl_ca_cert in enumerate(ssl_ca_certs):
|
||||
ssl_verify_map[servers[i]] = ssl.create_default_context(
|
||||
cafile=ssl_ca_cert
|
||||
)
|
||||
else:
|
||||
ssl_ca_cert = ssl_ca_certs[0]
|
||||
for server in servers:
|
||||
ssl_verify_map[server] = ssl.create_default_context(
|
||||
cafile=ssl_ca_cert
|
||||
)
|
||||
verify = ssl_verify_map if ssl_verify_map else ssl_verify
|
||||
super().__init__(
|
||||
mcp_servers=mcp_servers, headers=headers, ssl_verify=verify, **kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def type_alias(cls) -> str:
|
||||
@ -97,5 +125,20 @@ class MCPSSEToolPack(MCPToolPack):
|
||||
"tags": "privacy",
|
||||
},
|
||||
)
|
||||
no_ssl_verify: bool = dataclasses.field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": _(
|
||||
"Disable SSL verification. "
|
||||
"This is not recommended for production use."
|
||||
),
|
||||
},
|
||||
)
|
||||
ssl_ca_cert: Optional[str] = dataclasses.field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _("Path to the CA certificate file. split by ';' "),
|
||||
},
|
||||
)
|
||||
|
||||
return _DynMCPSSEPackResourceParameters
|
||||
|
Loading…
Reference in New Issue
Block a user