mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-22 20:01:46 +00:00
feat(agent): Support SSL/TLS for MCP (#2591)
This commit is contained in:
parent
0fd578cf87
commit
e9ce534ca1
@ -180,7 +180,7 @@ class ReActAction(ToolAction):
|
|||||||
# Try to parse the action input to dict
|
# Try to parse the action input to dict
|
||||||
if action_input and isinstance(action_input, str):
|
if action_input and isinstance(action_input, str):
|
||||||
tool_args = parse_or_raise_error(action_input)
|
tool_args = parse_or_raise_error(action_input)
|
||||||
elif isinstance(action_input, dict):
|
elif isinstance(action_input, dict) or isinstance(action_input, list):
|
||||||
tool_args = action_input
|
tool_args = action_input
|
||||||
action_input_str = json.dumps(action_input, ensure_ascii=False)
|
action_input_str = json.dumps(action_input, ensure_ascii=False)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
@ -138,6 +138,10 @@ async def run_tool(
|
|||||||
if parsed_args and isinstance(parsed_args, tuple):
|
if parsed_args and isinstance(parsed_args, tuple):
|
||||||
args = parsed_args[1]
|
args = parsed_args[1]
|
||||||
|
|
||||||
|
if args is not None and isinstance(args, list) and len(args) == 0:
|
||||||
|
# Input args is empty list, just use default args
|
||||||
|
args = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_result = await tool_pack.async_execute(resource_name=name, **args)
|
tool_result = await tool_pack.async_execute(resource_name=name, **args)
|
||||||
status = Status.COMPLETE.value
|
status = Status.COMPLETE.value
|
||||||
|
@ -2,13 +2,14 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import ssl
|
||||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, cast
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union, cast
|
||||||
|
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.client.sse import sse_client
|
|
||||||
|
|
||||||
from dbgpt.util.json_utils import parse_or_raise_error
|
from dbgpt.util.json_utils import parse_or_raise_error
|
||||||
|
|
||||||
|
from ...util.mcp_utils import sse_client
|
||||||
from ..base import EXECUTE_ARGS_TYPE, PARSE_EXECUTE_ARGS_FUNCTION, ResourceType, T
|
from ..base import EXECUTE_ARGS_TYPE, PARSE_EXECUTE_ARGS_FUNCTION, ResourceType, T
|
||||||
from ..pack import Resource, ResourcePack
|
from ..pack import Resource, ResourcePack
|
||||||
from .base import DB_GPT_TOOL_IDENTIFIER, BaseTool, FunctionTool, ToolFunc
|
from .base import DB_GPT_TOOL_IDENTIFIER, BaseTool, FunctionTool, ToolFunc
|
||||||
@ -66,6 +67,8 @@ def json_parse_execute_args_func(input_str: str) -> Optional[EXECUTE_ARGS_TYPE]:
|
|||||||
# The position arguments is empty
|
# The position arguments is empty
|
||||||
args = ()
|
args = ()
|
||||||
kwargs = parse_or_raise_error(input_str)
|
kwargs = parse_or_raise_error(input_str)
|
||||||
|
if kwargs is not None and isinstance(kwargs, list) and len(kwargs) == 0:
|
||||||
|
kwargs = {}
|
||||||
return args, kwargs
|
return args, kwargs
|
||||||
|
|
||||||
|
|
||||||
@ -303,6 +306,37 @@ class MCPToolPack(ToolPack):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
If you want to set the ssl verify, you can use the ssl_verify parameter:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# Default ssl_verify is True
|
||||||
|
tools = MCPToolPack(
|
||||||
|
"https://your_ssl_domain/sse",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the default ssl_verify to False to disable ssl verify
|
||||||
|
tools2 = MCPToolPack(
|
||||||
|
"https://your_ssl_domain/sse", default_ssl_verify=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# With Custom CA file
|
||||||
|
tools3 = MCPToolPack(
|
||||||
|
"https://your_ssl_domain/sse", default_ssl_cafile="/path/to/your/ca.crt"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the ssl_verify for each server
|
||||||
|
import ssl
|
||||||
|
|
||||||
|
tools4 = MCPToolPack(
|
||||||
|
"https://your_ssl_domain/sse",
|
||||||
|
ssl_verify={
|
||||||
|
"https://your_ssl_domain/sse": ssl.create_default_context(
|
||||||
|
cafile="/path/to/your/ca.crt"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -310,6 +344,9 @@ class MCPToolPack(ToolPack):
|
|||||||
mcp_servers: Union[str, List[str]],
|
mcp_servers: Union[str, List[str]],
|
||||||
headers: Optional[Dict[str, Dict[str, Any]]] = None,
|
headers: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||||
default_headers: Optional[Dict[str, Any]] = None,
|
default_headers: Optional[Dict[str, Any]] = None,
|
||||||
|
ssl_verify: Optional[Dict[str, Union[ssl.SSLContext, str, bool]]] = None,
|
||||||
|
default_ssl_verify: Union[ssl.SSLContext, str, bool] = True,
|
||||||
|
default_ssl_cafile: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Create an Auto-GPT plugin tool pack."""
|
"""Create an Auto-GPT plugin tool pack."""
|
||||||
@ -320,6 +357,12 @@ class MCPToolPack(ToolPack):
|
|||||||
self._default_headers = default_headers or {}
|
self._default_headers = default_headers or {}
|
||||||
self._headers_map = headers or {}
|
self._headers_map = headers or {}
|
||||||
self.server_headers_map = {}
|
self.server_headers_map = {}
|
||||||
|
if default_ssl_cafile and not ssl_verify and default_ssl_verify:
|
||||||
|
default_ssl_verify = ssl.create_default_context(cafile=default_ssl_cafile)
|
||||||
|
|
||||||
|
self._default_ssl_verify = default_ssl_verify
|
||||||
|
self._ssl_verify_map = ssl_verify or {}
|
||||||
|
self.server_ssl_verify_map = {}
|
||||||
|
|
||||||
def switch_mcp_input_schema(self, input_schema: dict):
|
def switch_mcp_input_schema(self, input_schema: dict):
|
||||||
args = {}
|
args = {}
|
||||||
@ -362,8 +405,14 @@ class MCPToolPack(ToolPack):
|
|||||||
for server in server_list:
|
for server in server_list:
|
||||||
server_headers = self._headers_map.get(server, self._default_headers)
|
server_headers = self._headers_map.get(server, self._default_headers)
|
||||||
self.server_headers_map[server] = server_headers
|
self.server_headers_map[server] = server_headers
|
||||||
|
server_ssl_verify = self._ssl_verify_map.get(
|
||||||
|
server, self._default_ssl_verify
|
||||||
|
)
|
||||||
|
self.server_ssl_verify_map[server] = server_ssl_verify
|
||||||
|
|
||||||
async with sse_client(url=server, headers=server_headers) as (read, write):
|
async with sse_client(
|
||||||
|
url=server, headers=server_headers, verify=server_ssl_verify
|
||||||
|
) as (read, write):
|
||||||
async with ClientSession(read, write) as session:
|
async with ClientSession(read, write) as session:
|
||||||
# Initialize the connection
|
# Initialize the connection
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
@ -378,8 +427,13 @@ class MCPToolPack(ToolPack):
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
headers_to_use = self.server_headers_map.get(server, {})
|
headers_to_use = self.server_headers_map.get(server, {})
|
||||||
|
ssl_verify_to_use = self.server_ssl_verify_map.get(
|
||||||
|
server, True
|
||||||
|
)
|
||||||
async with sse_client(
|
async with sse_client(
|
||||||
url=server, headers=headers_to_use
|
url=server,
|
||||||
|
headers=headers_to_use,
|
||||||
|
verify=ssl_verify_to_use,
|
||||||
) as (read, write):
|
) as (read, write):
|
||||||
async with ClientSession(read, write) as session:
|
async with ClientSession(read, write) as session:
|
||||||
# Initialize the connection
|
# Initialize the connection
|
||||||
|
147
packages/dbgpt-core/src/dbgpt/agent/util/mcp_utils.py
Normal file
147
packages/dbgpt-core/src/dbgpt/agent/util/mcp_utils.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
import logging
|
||||||
|
import ssl
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urljoin, urlparse
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
import httpx
|
||||||
|
import mcp.types as types
|
||||||
|
from anyio.abc import TaskStatus
|
||||||
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
|
from httpx_sse import aconnect_sse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_request_params(url: str) -> str:
|
||||||
|
return urljoin(url, urlparse(url).path)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def sse_client(
|
||||||
|
url: str,
|
||||||
|
headers: dict[str, Any] | None = None,
|
||||||
|
timeout: float = 5,
|
||||||
|
sse_read_timeout: float = 60 * 5,
|
||||||
|
verify: ssl.SSLContext | str | bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Client transport for SSE.
|
||||||
|
|
||||||
|
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||||
|
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||||
|
"""
|
||||||
|
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||||
|
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||||
|
|
||||||
|
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||||
|
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||||
|
|
||||||
|
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||||
|
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||||
|
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
try:
|
||||||
|
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||||
|
async with httpx.AsyncClient(headers=headers, verify=verify) as client:
|
||||||
|
async with aconnect_sse(
|
||||||
|
client,
|
||||||
|
"GET",
|
||||||
|
url,
|
||||||
|
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
|
||||||
|
) as event_source:
|
||||||
|
event_source.response.raise_for_status()
|
||||||
|
logger.debug("SSE connection established")
|
||||||
|
|
||||||
|
async def sse_reader(
|
||||||
|
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
async for sse in event_source.aiter_sse():
|
||||||
|
logger.debug(f"Received SSE event: {sse.event}")
|
||||||
|
match sse.event:
|
||||||
|
case "endpoint":
|
||||||
|
endpoint_url = urljoin(url, sse.data)
|
||||||
|
logger.info(
|
||||||
|
f"Received endpoint URL: {endpoint_url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
url_parsed = urlparse(url)
|
||||||
|
endpoint_parsed = urlparse(endpoint_url)
|
||||||
|
if (
|
||||||
|
url_parsed.netloc != endpoint_parsed.netloc
|
||||||
|
or url_parsed.scheme
|
||||||
|
!= endpoint_parsed.scheme
|
||||||
|
):
|
||||||
|
error_msg = (
|
||||||
|
"Endpoint origin does not match "
|
||||||
|
f"connection origin: {endpoint_url}"
|
||||||
|
)
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
task_status.started(endpoint_url)
|
||||||
|
|
||||||
|
case "message":
|
||||||
|
try:
|
||||||
|
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
|
||||||
|
sse.data
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Received server message: {message}"
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
f"Error parsing server message: {exc}"
|
||||||
|
)
|
||||||
|
await read_stream_writer.send(exc)
|
||||||
|
continue
|
||||||
|
|
||||||
|
await read_stream_writer.send(message)
|
||||||
|
case _:
|
||||||
|
logger.warning(
|
||||||
|
f"Unknown SSE event: {sse.event}"
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Error in sse_reader: {exc}")
|
||||||
|
await read_stream_writer.send(exc)
|
||||||
|
finally:
|
||||||
|
await read_stream_writer.aclose()
|
||||||
|
|
||||||
|
async def post_writer(endpoint_url: str):
|
||||||
|
try:
|
||||||
|
async with write_stream_reader:
|
||||||
|
async for message in write_stream_reader:
|
||||||
|
logger.debug(f"Sending client message: {message}")
|
||||||
|
response = await client.post(
|
||||||
|
endpoint_url,
|
||||||
|
json=message.model_dump(
|
||||||
|
by_alias=True,
|
||||||
|
mode="json",
|
||||||
|
exclude_none=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
logger.debug(
|
||||||
|
"Client message sent successfully: "
|
||||||
|
f"{response.status_code}"
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Error in post_writer: {exc}")
|
||||||
|
finally:
|
||||||
|
await write_stream.aclose()
|
||||||
|
|
||||||
|
endpoint_url = await tg.start(sse_reader)
|
||||||
|
logger.info(
|
||||||
|
f"Starting post writer with endpoint URL: {endpoint_url}"
|
||||||
|
)
|
||||||
|
tg.start_soon(post_writer, endpoint_url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield read_stream, write_stream
|
||||||
|
finally:
|
||||||
|
tg.cancel_scope.cancel()
|
||||||
|
finally:
|
||||||
|
await read_stream_writer.aclose()
|
||||||
|
await write_stream.aclose()
|
@ -23,7 +23,7 @@ from dbgpt.model.utils.chatgpt_utils import OpenAIParameters
|
|||||||
from dbgpt.util.i18n_utils import _
|
from dbgpt.util.i18n_utils import _
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from httpx._types import ProxiesTypes
|
from httpx._types import ProxiesTypes, ProxyTypes
|
||||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||||
|
|
||||||
ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]
|
ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]
|
||||||
@ -139,6 +139,7 @@ class OpenAILLMClient(ProxyLLMClient):
|
|||||||
api_version: Optional[str] = None,
|
api_version: Optional[str] = None,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
proxies: Optional["ProxiesTypes"] = None,
|
proxies: Optional["ProxiesTypes"] = None,
|
||||||
|
proxy: Optional["ProxyTypes"] = None,
|
||||||
timeout: Optional[int] = 240,
|
timeout: Optional[int] = 240,
|
||||||
model_alias: Optional[str] = "gpt-4o-mini",
|
model_alias: Optional[str] = "gpt-4o-mini",
|
||||||
context_length: Optional[int] = 8192,
|
context_length: Optional[int] = 8192,
|
||||||
@ -160,6 +161,7 @@ class OpenAILLMClient(ProxyLLMClient):
|
|||||||
api_key=self._resolve_env_vars(api_key),
|
api_key=self._resolve_env_vars(api_key),
|
||||||
api_version=self._resolve_env_vars(api_version),
|
api_version=self._resolve_env_vars(api_version),
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
|
proxy=proxy,
|
||||||
full_url=kwargs.get("full_url"),
|
full_url=kwargs.get("full_url"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -203,7 +205,7 @@ class OpenAILLMClient(ProxyLLMClient):
|
|||||||
api_type=model_params.api_type,
|
api_type=model_params.api_type,
|
||||||
api_version=model_params.api_version,
|
api_version=model_params.api_version,
|
||||||
model=model_params.real_provider_model_name,
|
model=model_params.real_provider_model_name,
|
||||||
proxies=model_params.http_proxy,
|
proxy=model_params.http_proxy,
|
||||||
model_alias=model_params.real_provider_model_name,
|
model_alias=model_params.real_provider_model_name,
|
||||||
context_length=max(model_params.context_length or 8192, 8192),
|
context_length=max(model_params.context_length or 8192, 8192),
|
||||||
# full_url=model_params.proxy_server_url,
|
# full_url=model_params.proxy_server_url,
|
||||||
|
@ -50,12 +50,14 @@ class MCPPackResourceParameters(PackResourceParameters):
|
|||||||
class MCPSSEToolPack(MCPToolPack):
|
class MCPSSEToolPack(MCPToolPack):
|
||||||
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
|
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
|
||||||
"""Initialize the MCPSSEToolPack with the given MCP servers."""
|
"""Initialize the MCPSSEToolPack with the given MCP servers."""
|
||||||
|
import ssl
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
|
# token is not supported in sse mode
|
||||||
|
servers = (
|
||||||
|
mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers
|
||||||
|
)
|
||||||
if "token" in kwargs and kwargs["token"]:
|
if "token" in kwargs and kwargs["token"]:
|
||||||
# token is not supported in sse mode
|
|
||||||
servers = (
|
|
||||||
mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers
|
|
||||||
)
|
|
||||||
tokens = (
|
tokens = (
|
||||||
kwargs["token"].split(";")
|
kwargs["token"].split(";")
|
||||||
if isinstance(kwargs["token"], str)
|
if isinstance(kwargs["token"], str)
|
||||||
@ -69,7 +71,33 @@ class MCPSSEToolPack(MCPToolPack):
|
|||||||
for server in servers:
|
for server in servers:
|
||||||
headers[server] = {"Authorization": f"Bearer {token}"}
|
headers[server] = {"Authorization": f"Bearer {token}"}
|
||||||
kwargs.pop("token")
|
kwargs.pop("token")
|
||||||
super().__init__(mcp_servers=mcp_servers, headers=headers, **kwargs)
|
ssl_verify = True
|
||||||
|
ssl_verify_map = {}
|
||||||
|
if "no_ssl_verify" in kwargs:
|
||||||
|
if kwargs["no_ssl_verify"] is True:
|
||||||
|
ssl_verify = False
|
||||||
|
kwargs.pop("no_ssl_verify")
|
||||||
|
if ssl_verify is True and "ssl_ca_cert" in kwargs:
|
||||||
|
ssl_ca_certs = (
|
||||||
|
kwargs["ssl_ca_cert"].split(";")
|
||||||
|
if isinstance(kwargs["ssl_ca_cert"], str)
|
||||||
|
else kwargs["ssl_ca_cert"]
|
||||||
|
)
|
||||||
|
if len(servers) == len(ssl_ca_certs):
|
||||||
|
for i, ssl_ca_cert in enumerate(ssl_ca_certs):
|
||||||
|
ssl_verify_map[servers[i]] = ssl.create_default_context(
|
||||||
|
cafile=ssl_ca_cert
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ssl_ca_cert = ssl_ca_certs[0]
|
||||||
|
for server in servers:
|
||||||
|
ssl_verify_map[server] = ssl.create_default_context(
|
||||||
|
cafile=ssl_ca_cert
|
||||||
|
)
|
||||||
|
verify = ssl_verify_map if ssl_verify_map else ssl_verify
|
||||||
|
super().__init__(
|
||||||
|
mcp_servers=mcp_servers, headers=headers, ssl_verify=verify, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def type_alias(cls) -> str:
|
def type_alias(cls) -> str:
|
||||||
@ -97,5 +125,20 @@ class MCPSSEToolPack(MCPToolPack):
|
|||||||
"tags": "privacy",
|
"tags": "privacy",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
no_ssl_verify: bool = dataclasses.field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": _(
|
||||||
|
"Disable SSL verification. "
|
||||||
|
"This is not recommended for production use."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ssl_ca_cert: Optional[str] = dataclasses.field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": _("Path to the CA certificate file. split by ';' "),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return _DynMCPSSEPackResourceParameters
|
return _DynMCPSSEPackResourceParameters
|
||||||
|
Loading…
Reference in New Issue
Block a user