style(core): fix style in langchain_core/_security (#38189)

Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet
2026-06-16 17:40:50 +02:00
committed by GitHub
parent 0fdcdad393
commit fc956c8680
5 changed files with 37 additions and 30 deletions

View File

@@ -114,6 +114,9 @@ class SSRFPolicy:
] = ()
DEFAULT_SSRF_POLICY = SSRFPolicy()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@@ -196,7 +199,8 @@ def validate_resolved_ip(ip_str: str, policy: SSRFPolicy) -> None:
try:
addr = ipaddress.ip_address(ip_str)
except ValueError as exc:
raise SSRFBlockedError("invalid IP address") from exc
msg = "invalid IP address"
raise SSRFBlockedError(msg) from exc
if isinstance(addr, ipaddress.IPv6Address):
inner = _extract_embedded_ipv4(addr)
@@ -216,13 +220,16 @@ def validate_hostname(hostname: str, policy: SSRFPolicy) -> None:
lower = hostname.lower()
if policy.block_localhost and lower in _LOCALHOST_NAMES:
raise SSRFBlockedError("localhost address")
msg = "localhost address"
raise SSRFBlockedError(msg)
if policy.block_cloud_metadata and lower in _CLOUD_METADATA_HOSTNAMES:
raise SSRFBlockedError("cloud metadata endpoint")
msg = "cloud metadata endpoint"
raise SSRFBlockedError(msg)
if policy.block_k8s_internal and lower.endswith(_K8S_SUFFIX):
raise SSRFBlockedError("Kubernetes internal DNS")
msg = "Kubernetes internal DNS"
raise SSRFBlockedError(msg)
def _effective_allowed_hosts(policy: SSRFPolicy) -> frozenset[str]:
@@ -235,7 +242,7 @@ def _effective_allowed_hosts(policy: SSRFPolicy) -> frozenset[str]:
return policy.allowed_hosts
async def validate_url(url: str, policy: SSRFPolicy = SSRFPolicy()) -> None:
async def validate_url(url: str, policy: SSRFPolicy = DEFAULT_SSRF_POLICY) -> None:
"""Validate a URL against the SSRF policy, including DNS resolution.
This is the primary entry-point for async code paths. It delegates
@@ -268,7 +275,7 @@ async def validate_url(url: str, policy: SSRFPolicy = SSRFPolicy()) -> None:
validate_resolved_ip(str(sockaddr[0]), policy)
def validate_url_sync(url: str, policy: SSRFPolicy = SSRFPolicy()) -> None:
def validate_url_sync(url: str, policy: SSRFPolicy = DEFAULT_SSRF_POLICY) -> None:
"""Synchronous URL validation (no DNS resolution).
Suitable for Pydantic validators and other sync contexts. Checks scheme

View File

@@ -7,6 +7,7 @@ import httpx
from langchain_core._security._exceptions import SSRFBlockedError
from langchain_core._security._policy import (
DEFAULT_SSRF_POLICY,
SSRFPolicy,
_effective_allowed_hosts,
validate_resolved_ip,
@@ -44,7 +45,7 @@ class SSRFSafeTransport(httpx.AsyncBaseTransport):
def __init__(
self,
policy: SSRFPolicy = SSRFPolicy(),
policy: SSRFPolicy = DEFAULT_SSRF_POLICY,
**transport_kwargs: object,
) -> None:
self._policy = policy
@@ -62,10 +63,7 @@ class SSRFSafeTransport(httpx.AsyncBaseTransport):
scheme = request.url.scheme.lower()
# 1-3. Scheme, hostname, and pattern checks (reuse sync validator).
try:
validate_url_sync(str(request.url), self._policy)
except SSRFBlockedError:
raise
validate_url_sync(str(request.url), self._policy)
# Allowed-hosts bypass - skip DNS/IP validation entirely.
allowed = {h.lower() for h in _effective_allowed_hosts(self._policy)}
@@ -82,10 +80,12 @@ class SSRFSafeTransport(httpx.AsyncBaseTransport):
type=socket.SOCK_STREAM,
)
except socket.gaierror as exc:
raise SSRFBlockedError("DNS resolution failed") from exc
msg = "DNS resolution failed"
raise SSRFBlockedError(msg) from exc
if not addrinfo:
raise SSRFBlockedError("DNS resolution returned no results")
msg = "DNS resolution returned no results"
raise SSRFBlockedError(msg)
# 5. Validate ALL resolved IPs - any blocked means reject.
for _family, _type, _proto, _canonname, sockaddr in addrinfo:
@@ -135,7 +135,7 @@ class SSRFSafeSyncTransport(httpx.BaseTransport):
def __init__(
self,
policy: SSRFPolicy = SSRFPolicy(),
policy: SSRFPolicy = DEFAULT_SSRF_POLICY,
**transport_kwargs: object,
) -> None:
self._policy = policy
@@ -162,10 +162,12 @@ class SSRFSafeSyncTransport(httpx.BaseTransport):
type=socket.SOCK_STREAM,
)
except socket.gaierror as exc:
raise SSRFBlockedError("DNS resolution failed") from exc
msg = "DNS resolution failed"
raise SSRFBlockedError(msg) from exc
if not addrinfo:
raise SSRFBlockedError("DNS resolution returned no results")
msg = "DNS resolution returned no results"
raise SSRFBlockedError(msg)
for _family, _type, _proto, _canonname, sockaddr in addrinfo:
ip_str: str = sockaddr[0] # type: ignore[assignment]
@@ -198,7 +200,7 @@ class SSRFSafeSyncTransport(httpx.BaseTransport):
def ssrf_safe_client(
policy: SSRFPolicy = SSRFPolicy(),
policy: SSRFPolicy = DEFAULT_SSRF_POLICY,
**kwargs: object,
) -> httpx.Client:
"""Create an `httpx.Client` with SSRF protection."""
@@ -222,7 +224,7 @@ def ssrf_safe_client(
def ssrf_safe_async_client(
policy: SSRFPolicy = SSRFPolicy(),
policy: SSRFPolicy = DEFAULT_SSRF_POLICY,
**kwargs: object,
) -> httpx.AsyncClient:
"""Create an `httpx.AsyncClient` with SSRF protection.

View File

@@ -133,10 +133,8 @@ ignore-var-parameters = true # ignore missing documentation for *args and **kwa
"langchain_core/utils/mustache.py" = [ "PLW0603",]
"langchain_core/sys_info.py" = [ "T201",]
"tests/unit_tests/test_tools.py" = [ "ARG",]
"tests/**" = [ "ARG", "D1", "PLR2004", "S", "SLF",]
"tests/**" = [ "D1", "PLR2004", "S", "SLF",]
"scripts/**" = [ "INP", "S", "T201",]
"langchain_core/_security/_policy.py" = [ "EM101", "EM102", "TRY003", "B008", "TRY300",]
"langchain_core/_security/_transport.py" = [ "EM101", "EM102", "TRY003", "TRY203", "B008",]
[tool.coverage.run]
omit = [ "tests/*",]

View File

@@ -2851,7 +2851,7 @@ async def test_tool_error_event_includes_tool_call_id() -> None:
"""Test that on_tool_error event includes tool_call_id when provided."""
@tool
def failing_tool(x: int) -> str:
def failing_tool(x: int) -> str: # noqa: ARG001
"""A tool that always fails."""
msg = "Tool execution failed"
raise ValueError(msg)
@@ -2891,7 +2891,7 @@ async def test_tool_error_event_tool_call_id_is_none_when_not_provided() -> None
"""Test that on_tool_error event has tool_call_id=None when not provided."""
@tool
def failing_tool_no_id(x: int) -> str:
def failing_tool_no_id(x: int) -> str: # noqa: ARG001
"""A tool that always fails."""
msg = "Tool execution failed"
raise ValueError(msg)

View File

@@ -26,7 +26,7 @@ def _fake_addrinfo_v6(ip: str, port: int = 80) -> list[Any]:
return [(socket.AF_INET6, socket.SOCK_STREAM, 6, "", (ip, port, 0, 0))]
def _ok_response(request: httpx.Request) -> httpx.Response:
def _ok_response(_request: httpx.Request) -> httpx.Response:
return httpx.Response(200, text="ok")
@@ -239,7 +239,7 @@ def test_link_local_range_blocked_as_cloud_metadata_when_private_ips_allowed(
async def test_redirect_to_private_ip_blocked(monkeypatch: Any) -> None:
call_count = 0
def _routing_addrinfo(*args: Any, **kwargs: Any) -> list[Any]:
def _routing_addrinfo(*_args: Any, **_kwargs: Any) -> list[Any]:
nonlocal call_count
call_count += 1
if call_count == 1:
@@ -248,7 +248,7 @@ async def test_redirect_to_private_ip_blocked(monkeypatch: Any) -> None:
monkeypatch.setattr(socket, "getaddrinfo", _routing_addrinfo)
def _redirect_responder(request: httpx.Request) -> httpx.Response:
def _redirect_responder(_request: httpx.Request) -> httpx.Response:
return httpx.Response(
302,
headers={"Location": "http://evil.com/pwned"},
@@ -279,7 +279,7 @@ async def test_ipv6_mapped_ipv4_blocked(monkeypatch: Any) -> None:
monkeypatch.setattr(
socket,
"getaddrinfo",
lambda *a, **kw: _fake_addrinfo_v6("::ffff:127.0.0.1"),
lambda *_, **__: _fake_addrinfo_v6("::ffff:127.0.0.1"),
)
transport = SSRFSafeTransport()
@@ -301,7 +301,7 @@ async def test_unresolvable_host_blocked(monkeypatch: Any) -> None:
monkeypatch.setattr(
socket,
"getaddrinfo",
lambda *a, **kw: (_ for _ in ()).throw(
lambda *_, **__: (_ for _ in ()).throw(
socket.gaierror("Name or service not known")
),
)
@@ -388,7 +388,7 @@ def test_sync_transport_blocks_private_resolution() -> None:
def test_sync_transport_redirect_to_private_blocked(monkeypatch: Any) -> None:
call_count = 0
def _routing_addrinfo(*args: Any, **kwargs: Any) -> list[Any]:
def _routing_addrinfo(*_args: Any, **__kwargs: Any) -> list[Any]:
nonlocal call_count
call_count += 1
if call_count == 1:
@@ -397,7 +397,7 @@ def test_sync_transport_redirect_to_private_blocked(monkeypatch: Any) -> None:
monkeypatch.setattr(socket, "getaddrinfo", _routing_addrinfo)
def _redirect_responder(request: httpx.Request) -> httpx.Response:
def _redirect_responder(_request: httpx.Request) -> httpx.Response:
return httpx.Response(
302,
headers={"Location": "http://evil.com/pwned"},