From fc956c86806dbcb52bf0e36da569ff7175017c76 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Tue, 16 Jun 2026 17:40:50 +0200 Subject: [PATCH] style(core): fix style in `langchain_core`/`_security` (#38189) Co-authored-by: Mason Daugherty --- libs/core/langchain_core/_security/_policy.py | 19 +++++++++----- .../langchain_core/_security/_transport.py | 26 ++++++++++--------- libs/core/pyproject.toml | 4 +-- .../runnables/test_runnable_events_v2.py | 4 +-- .../unit_tests/test_ssrf_policy_transport.py | 14 +++++----- 5 files changed, 37 insertions(+), 30 deletions(-) diff --git a/libs/core/langchain_core/_security/_policy.py b/libs/core/langchain_core/_security/_policy.py index 9b4566cbe49..458101dbd85 100644 --- a/libs/core/langchain_core/_security/_policy.py +++ b/libs/core/langchain_core/_security/_policy.py @@ -114,6 +114,9 @@ class SSRFPolicy: ] = () +DEFAULT_SSRF_POLICY = SSRFPolicy() + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -196,7 +199,8 @@ def validate_resolved_ip(ip_str: str, policy: SSRFPolicy) -> None: try: addr = ipaddress.ip_address(ip_str) except ValueError as exc: - raise SSRFBlockedError("invalid IP address") from exc + msg = "invalid IP address" + raise SSRFBlockedError(msg) from exc if isinstance(addr, ipaddress.IPv6Address): inner = _extract_embedded_ipv4(addr) @@ -216,13 +220,16 @@ def validate_hostname(hostname: str, policy: SSRFPolicy) -> None: lower = hostname.lower() if policy.block_localhost and lower in _LOCALHOST_NAMES: - raise SSRFBlockedError("localhost address") + msg = "localhost address" + raise SSRFBlockedError(msg) if policy.block_cloud_metadata and lower in _CLOUD_METADATA_HOSTNAMES: - raise SSRFBlockedError("cloud metadata endpoint") + msg = "cloud metadata endpoint" + raise SSRFBlockedError(msg) if policy.block_k8s_internal and lower.endswith(_K8S_SUFFIX): - raise SSRFBlockedError("Kubernetes internal DNS") + msg = "Kubernetes internal DNS" + raise SSRFBlockedError(msg) def _effective_allowed_hosts(policy: SSRFPolicy) -> frozenset[str]: @@ -235,7 +242,7 @@ def _effective_allowed_hosts(policy: SSRFPolicy) -> frozenset[str]: return policy.allowed_hosts -async def validate_url(url: str, policy: SSRFPolicy = SSRFPolicy()) -> None: +async def validate_url(url: str, policy: SSRFPolicy = DEFAULT_SSRF_POLICY) -> None: """Validate a URL against the SSRF policy, including DNS resolution. This is the primary entry-point for async code paths. It delegates @@ -268,7 +275,7 @@ async def validate_url(url: str, policy: SSRFPolicy = SSRFPolicy()) -> None: validate_resolved_ip(str(sockaddr[0]), policy) -def validate_url_sync(url: str, policy: SSRFPolicy = SSRFPolicy()) -> None: +def validate_url_sync(url: str, policy: SSRFPolicy = DEFAULT_SSRF_POLICY) -> None: """Synchronous URL validation (no DNS resolution). Suitable for Pydantic validators and other sync contexts. Checks scheme diff --git a/libs/core/langchain_core/_security/_transport.py b/libs/core/langchain_core/_security/_transport.py index 2bbc8d8989f..c1ea1baa393 100644 --- a/libs/core/langchain_core/_security/_transport.py +++ b/libs/core/langchain_core/_security/_transport.py @@ -7,6 +7,7 @@ import httpx from langchain_core._security._exceptions import SSRFBlockedError from langchain_core._security._policy import ( + DEFAULT_SSRF_POLICY, SSRFPolicy, _effective_allowed_hosts, validate_resolved_ip, @@ -44,7 +45,7 @@ class SSRFSafeTransport(httpx.AsyncBaseTransport): def __init__( self, - policy: SSRFPolicy = SSRFPolicy(), + policy: SSRFPolicy = DEFAULT_SSRF_POLICY, **transport_kwargs: object, ) -> None: self._policy = policy @@ -62,10 +63,7 @@ class SSRFSafeTransport(httpx.AsyncBaseTransport): scheme = request.url.scheme.lower() # 1-3. Scheme, hostname, and pattern checks (reuse sync validator). - try: - validate_url_sync(str(request.url), self._policy) - except SSRFBlockedError: - raise + validate_url_sync(str(request.url), self._policy) # Allowed-hosts bypass - skip DNS/IP validation entirely. allowed = {h.lower() for h in _effective_allowed_hosts(self._policy)} @@ -82,10 +80,12 @@ class SSRFSafeTransport(httpx.AsyncBaseTransport): type=socket.SOCK_STREAM, ) except socket.gaierror as exc: - raise SSRFBlockedError("DNS resolution failed") from exc + msg = "DNS resolution failed" + raise SSRFBlockedError(msg) from exc if not addrinfo: - raise SSRFBlockedError("DNS resolution returned no results") + msg = "DNS resolution returned no results" + raise SSRFBlockedError(msg) # 5. Validate ALL resolved IPs - any blocked means reject. for _family, _type, _proto, _canonname, sockaddr in addrinfo: @@ -135,7 +135,7 @@ class SSRFSafeSyncTransport(httpx.BaseTransport): def __init__( self, - policy: SSRFPolicy = SSRFPolicy(), + policy: SSRFPolicy = DEFAULT_SSRF_POLICY, **transport_kwargs: object, ) -> None: self._policy = policy @@ -162,10 +162,12 @@ class SSRFSafeSyncTransport(httpx.BaseTransport): type=socket.SOCK_STREAM, ) except socket.gaierror as exc: - raise SSRFBlockedError("DNS resolution failed") from exc + msg = "DNS resolution failed" + raise SSRFBlockedError(msg) from exc if not addrinfo: - raise SSRFBlockedError("DNS resolution returned no results") + msg = "DNS resolution returned no results" + raise SSRFBlockedError(msg) for _family, _type, _proto, _canonname, sockaddr in addrinfo: ip_str: str = sockaddr[0] # type: ignore[assignment] @@ -198,7 +200,7 @@ class SSRFSafeSyncTransport(httpx.BaseTransport): def ssrf_safe_client( - policy: SSRFPolicy = SSRFPolicy(), + policy: SSRFPolicy = DEFAULT_SSRF_POLICY, **kwargs: object, ) -> httpx.Client: """Create an `httpx.Client` with SSRF protection.""" @@ -222,7 +224,7 @@ def ssrf_safe_client( def ssrf_safe_async_client( - policy: SSRFPolicy = SSRFPolicy(), + policy: SSRFPolicy = DEFAULT_SSRF_POLICY, **kwargs: object, ) -> httpx.AsyncClient: """Create an `httpx.AsyncClient` with SSRF protection. diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 5a38f87fabb..39663253134 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -133,10 +133,8 @@ ignore-var-parameters = true # ignore missing documentation for *args and **kwa "langchain_core/utils/mustache.py" = [ "PLW0603",] "langchain_core/sys_info.py" = [ "T201",] "tests/unit_tests/test_tools.py" = [ "ARG",] -"tests/**" = [ "ARG", "D1", "PLR2004", "S", "SLF",] +"tests/**" = [ "D1", "PLR2004", "S", "SLF",] "scripts/**" = [ "INP", "S", "T201",] -"langchain_core/_security/_policy.py" = [ "EM101", "EM102", "TRY003", "B008", "TRY300",] -"langchain_core/_security/_transport.py" = [ "EM101", "EM102", "TRY003", "TRY203", "B008",] [tool.coverage.run] omit = [ "tests/*",] diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 0871406e777..8b50022f66a 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -2851,7 +2851,7 @@ async def test_tool_error_event_includes_tool_call_id() -> None: """Test that on_tool_error event includes tool_call_id when provided.""" @tool - def failing_tool(x: int) -> str: + def failing_tool(x: int) -> str: # noqa: ARG001 """A tool that always fails.""" msg = "Tool execution failed" raise ValueError(msg) @@ -2891,7 +2891,7 @@ async def test_tool_error_event_tool_call_id_is_none_when_not_provided() -> None """Test that on_tool_error event has tool_call_id=None when not provided.""" @tool - def failing_tool_no_id(x: int) -> str: + def failing_tool_no_id(x: int) -> str: # noqa: ARG001 """A tool that always fails.""" msg = "Tool execution failed" raise ValueError(msg) diff --git a/libs/core/tests/unit_tests/test_ssrf_policy_transport.py b/libs/core/tests/unit_tests/test_ssrf_policy_transport.py index fe174c88da9..3f182422b5b 100644 --- a/libs/core/tests/unit_tests/test_ssrf_policy_transport.py +++ b/libs/core/tests/unit_tests/test_ssrf_policy_transport.py @@ -26,7 +26,7 @@ def _fake_addrinfo_v6(ip: str, port: int = 80) -> list[Any]: return [(socket.AF_INET6, socket.SOCK_STREAM, 6, "", (ip, port, 0, 0))] -def _ok_response(request: httpx.Request) -> httpx.Response: +def _ok_response(_request: httpx.Request) -> httpx.Response: return httpx.Response(200, text="ok") @@ -239,7 +239,7 @@ def test_link_local_range_blocked_as_cloud_metadata_when_private_ips_allowed( async def test_redirect_to_private_ip_blocked(monkeypatch: Any) -> None: call_count = 0 - def _routing_addrinfo(*args: Any, **kwargs: Any) -> list[Any]: + def _routing_addrinfo(*_args: Any, **_kwargs: Any) -> list[Any]: nonlocal call_count call_count += 1 if call_count == 1: @@ -248,7 +248,7 @@ async def test_redirect_to_private_ip_blocked(monkeypatch: Any) -> None: monkeypatch.setattr(socket, "getaddrinfo", _routing_addrinfo) - def _redirect_responder(request: httpx.Request) -> httpx.Response: + def _redirect_responder(_request: httpx.Request) -> httpx.Response: return httpx.Response( 302, headers={"Location": "http://evil.com/pwned"}, @@ -279,7 +279,7 @@ async def test_ipv6_mapped_ipv4_blocked(monkeypatch: Any) -> None: monkeypatch.setattr( socket, "getaddrinfo", - lambda *a, **kw: _fake_addrinfo_v6("::ffff:127.0.0.1"), + lambda *_, **__: _fake_addrinfo_v6("::ffff:127.0.0.1"), ) transport = SSRFSafeTransport() @@ -301,7 +301,7 @@ async def test_unresolvable_host_blocked(monkeypatch: Any) -> None: monkeypatch.setattr( socket, "getaddrinfo", - lambda *a, **kw: (_ for _ in ()).throw( + lambda *_, **__: (_ for _ in ()).throw( socket.gaierror("Name or service not known") ), ) @@ -388,7 +388,7 @@ def test_sync_transport_blocks_private_resolution() -> None: def test_sync_transport_redirect_to_private_blocked(monkeypatch: Any) -> None: call_count = 0 - def _routing_addrinfo(*args: Any, **kwargs: Any) -> list[Any]: + def _routing_addrinfo(*_args: Any, **__kwargs: Any) -> list[Any]: nonlocal call_count call_count += 1 if call_count == 1: @@ -397,7 +397,7 @@ def test_sync_transport_redirect_to_private_blocked(monkeypatch: Any) -> None: monkeypatch.setattr(socket, "getaddrinfo", _routing_addrinfo) - def _redirect_responder(request: httpx.Request) -> httpx.Response: + def _redirect_responder(_request: httpx.Request) -> httpx.Response: return httpx.Response( 302, headers={"Location": "http://evil.com/pwned"},