mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 14:47:02 +00:00
style(core): fix style in langchain_core/_security (#38189)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
0fdcdad393
commit
fc956c8680
@@ -114,6 +114,9 @@ class SSRFPolicy:
|
||||
] = ()
|
||||
|
||||
|
||||
DEFAULT_SSRF_POLICY = SSRFPolicy()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -196,7 +199,8 @@ def validate_resolved_ip(ip_str: str, policy: SSRFPolicy) -> None:
|
||||
try:
|
||||
addr = ipaddress.ip_address(ip_str)
|
||||
except ValueError as exc:
|
||||
raise SSRFBlockedError("invalid IP address") from exc
|
||||
msg = "invalid IP address"
|
||||
raise SSRFBlockedError(msg) from exc
|
||||
|
||||
if isinstance(addr, ipaddress.IPv6Address):
|
||||
inner = _extract_embedded_ipv4(addr)
|
||||
@@ -216,13 +220,16 @@ def validate_hostname(hostname: str, policy: SSRFPolicy) -> None:
|
||||
lower = hostname.lower()
|
||||
|
||||
if policy.block_localhost and lower in _LOCALHOST_NAMES:
|
||||
raise SSRFBlockedError("localhost address")
|
||||
msg = "localhost address"
|
||||
raise SSRFBlockedError(msg)
|
||||
|
||||
if policy.block_cloud_metadata and lower in _CLOUD_METADATA_HOSTNAMES:
|
||||
raise SSRFBlockedError("cloud metadata endpoint")
|
||||
msg = "cloud metadata endpoint"
|
||||
raise SSRFBlockedError(msg)
|
||||
|
||||
if policy.block_k8s_internal and lower.endswith(_K8S_SUFFIX):
|
||||
raise SSRFBlockedError("Kubernetes internal DNS")
|
||||
msg = "Kubernetes internal DNS"
|
||||
raise SSRFBlockedError(msg)
|
||||
|
||||
|
||||
def _effective_allowed_hosts(policy: SSRFPolicy) -> frozenset[str]:
|
||||
@@ -235,7 +242,7 @@ def _effective_allowed_hosts(policy: SSRFPolicy) -> frozenset[str]:
|
||||
return policy.allowed_hosts
|
||||
|
||||
|
||||
async def validate_url(url: str, policy: SSRFPolicy = SSRFPolicy()) -> None:
|
||||
async def validate_url(url: str, policy: SSRFPolicy = DEFAULT_SSRF_POLICY) -> None:
|
||||
"""Validate a URL against the SSRF policy, including DNS resolution.
|
||||
|
||||
This is the primary entry-point for async code paths. It delegates
|
||||
@@ -268,7 +275,7 @@ async def validate_url(url: str, policy: SSRFPolicy = SSRFPolicy()) -> None:
|
||||
validate_resolved_ip(str(sockaddr[0]), policy)
|
||||
|
||||
|
||||
def validate_url_sync(url: str, policy: SSRFPolicy = SSRFPolicy()) -> None:
|
||||
def validate_url_sync(url: str, policy: SSRFPolicy = DEFAULT_SSRF_POLICY) -> None:
|
||||
"""Synchronous URL validation (no DNS resolution).
|
||||
|
||||
Suitable for Pydantic validators and other sync contexts. Checks scheme
|
||||
|
||||
@@ -7,6 +7,7 @@ import httpx
|
||||
|
||||
from langchain_core._security._exceptions import SSRFBlockedError
|
||||
from langchain_core._security._policy import (
|
||||
DEFAULT_SSRF_POLICY,
|
||||
SSRFPolicy,
|
||||
_effective_allowed_hosts,
|
||||
validate_resolved_ip,
|
||||
@@ -44,7 +45,7 @@ class SSRFSafeTransport(httpx.AsyncBaseTransport):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: SSRFPolicy = SSRFPolicy(),
|
||||
policy: SSRFPolicy = DEFAULT_SSRF_POLICY,
|
||||
**transport_kwargs: object,
|
||||
) -> None:
|
||||
self._policy = policy
|
||||
@@ -62,10 +63,7 @@ class SSRFSafeTransport(httpx.AsyncBaseTransport):
|
||||
scheme = request.url.scheme.lower()
|
||||
|
||||
# 1-3. Scheme, hostname, and pattern checks (reuse sync validator).
|
||||
try:
|
||||
validate_url_sync(str(request.url), self._policy)
|
||||
except SSRFBlockedError:
|
||||
raise
|
||||
validate_url_sync(str(request.url), self._policy)
|
||||
|
||||
# Allowed-hosts bypass - skip DNS/IP validation entirely.
|
||||
allowed = {h.lower() for h in _effective_allowed_hosts(self._policy)}
|
||||
@@ -82,10 +80,12 @@ class SSRFSafeTransport(httpx.AsyncBaseTransport):
|
||||
type=socket.SOCK_STREAM,
|
||||
)
|
||||
except socket.gaierror as exc:
|
||||
raise SSRFBlockedError("DNS resolution failed") from exc
|
||||
msg = "DNS resolution failed"
|
||||
raise SSRFBlockedError(msg) from exc
|
||||
|
||||
if not addrinfo:
|
||||
raise SSRFBlockedError("DNS resolution returned no results")
|
||||
msg = "DNS resolution returned no results"
|
||||
raise SSRFBlockedError(msg)
|
||||
|
||||
# 5. Validate ALL resolved IPs - any blocked means reject.
|
||||
for _family, _type, _proto, _canonname, sockaddr in addrinfo:
|
||||
@@ -135,7 +135,7 @@ class SSRFSafeSyncTransport(httpx.BaseTransport):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: SSRFPolicy = SSRFPolicy(),
|
||||
policy: SSRFPolicy = DEFAULT_SSRF_POLICY,
|
||||
**transport_kwargs: object,
|
||||
) -> None:
|
||||
self._policy = policy
|
||||
@@ -162,10 +162,12 @@ class SSRFSafeSyncTransport(httpx.BaseTransport):
|
||||
type=socket.SOCK_STREAM,
|
||||
)
|
||||
except socket.gaierror as exc:
|
||||
raise SSRFBlockedError("DNS resolution failed") from exc
|
||||
msg = "DNS resolution failed"
|
||||
raise SSRFBlockedError(msg) from exc
|
||||
|
||||
if not addrinfo:
|
||||
raise SSRFBlockedError("DNS resolution returned no results")
|
||||
msg = "DNS resolution returned no results"
|
||||
raise SSRFBlockedError(msg)
|
||||
|
||||
for _family, _type, _proto, _canonname, sockaddr in addrinfo:
|
||||
ip_str: str = sockaddr[0] # type: ignore[assignment]
|
||||
@@ -198,7 +200,7 @@ class SSRFSafeSyncTransport(httpx.BaseTransport):
|
||||
|
||||
|
||||
def ssrf_safe_client(
|
||||
policy: SSRFPolicy = SSRFPolicy(),
|
||||
policy: SSRFPolicy = DEFAULT_SSRF_POLICY,
|
||||
**kwargs: object,
|
||||
) -> httpx.Client:
|
||||
"""Create an `httpx.Client` with SSRF protection."""
|
||||
@@ -222,7 +224,7 @@ def ssrf_safe_client(
|
||||
|
||||
|
||||
def ssrf_safe_async_client(
|
||||
policy: SSRFPolicy = SSRFPolicy(),
|
||||
policy: SSRFPolicy = DEFAULT_SSRF_POLICY,
|
||||
**kwargs: object,
|
||||
) -> httpx.AsyncClient:
|
||||
"""Create an `httpx.AsyncClient` with SSRF protection.
|
||||
|
||||
@@ -133,10 +133,8 @@ ignore-var-parameters = true # ignore missing documentation for *args and **kwa
|
||||
"langchain_core/utils/mustache.py" = [ "PLW0603",]
|
||||
"langchain_core/sys_info.py" = [ "T201",]
|
||||
"tests/unit_tests/test_tools.py" = [ "ARG",]
|
||||
"tests/**" = [ "ARG", "D1", "PLR2004", "S", "SLF",]
|
||||
"tests/**" = [ "D1", "PLR2004", "S", "SLF",]
|
||||
"scripts/**" = [ "INP", "S", "T201",]
|
||||
"langchain_core/_security/_policy.py" = [ "EM101", "EM102", "TRY003", "B008", "TRY300",]
|
||||
"langchain_core/_security/_transport.py" = [ "EM101", "EM102", "TRY003", "TRY203", "B008",]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [ "tests/*",]
|
||||
|
||||
@@ -2851,7 +2851,7 @@ async def test_tool_error_event_includes_tool_call_id() -> None:
|
||||
"""Test that on_tool_error event includes tool_call_id when provided."""
|
||||
|
||||
@tool
|
||||
def failing_tool(x: int) -> str:
|
||||
def failing_tool(x: int) -> str: # noqa: ARG001
|
||||
"""A tool that always fails."""
|
||||
msg = "Tool execution failed"
|
||||
raise ValueError(msg)
|
||||
@@ -2891,7 +2891,7 @@ async def test_tool_error_event_tool_call_id_is_none_when_not_provided() -> None
|
||||
"""Test that on_tool_error event has tool_call_id=None when not provided."""
|
||||
|
||||
@tool
|
||||
def failing_tool_no_id(x: int) -> str:
|
||||
def failing_tool_no_id(x: int) -> str: # noqa: ARG001
|
||||
"""A tool that always fails."""
|
||||
msg = "Tool execution failed"
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -26,7 +26,7 @@ def _fake_addrinfo_v6(ip: str, port: int = 80) -> list[Any]:
|
||||
return [(socket.AF_INET6, socket.SOCK_STREAM, 6, "", (ip, port, 0, 0))]
|
||||
|
||||
|
||||
def _ok_response(request: httpx.Request) -> httpx.Response:
|
||||
def _ok_response(_request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(200, text="ok")
|
||||
|
||||
|
||||
@@ -239,7 +239,7 @@ def test_link_local_range_blocked_as_cloud_metadata_when_private_ips_allowed(
|
||||
async def test_redirect_to_private_ip_blocked(monkeypatch: Any) -> None:
|
||||
call_count = 0
|
||||
|
||||
def _routing_addrinfo(*args: Any, **kwargs: Any) -> list[Any]:
|
||||
def _routing_addrinfo(*_args: Any, **_kwargs: Any) -> list[Any]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
@@ -248,7 +248,7 @@ async def test_redirect_to_private_ip_blocked(monkeypatch: Any) -> None:
|
||||
|
||||
monkeypatch.setattr(socket, "getaddrinfo", _routing_addrinfo)
|
||||
|
||||
def _redirect_responder(request: httpx.Request) -> httpx.Response:
|
||||
def _redirect_responder(_request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(
|
||||
302,
|
||||
headers={"Location": "http://evil.com/pwned"},
|
||||
@@ -279,7 +279,7 @@ async def test_ipv6_mapped_ipv4_blocked(monkeypatch: Any) -> None:
|
||||
monkeypatch.setattr(
|
||||
socket,
|
||||
"getaddrinfo",
|
||||
lambda *a, **kw: _fake_addrinfo_v6("::ffff:127.0.0.1"),
|
||||
lambda *_, **__: _fake_addrinfo_v6("::ffff:127.0.0.1"),
|
||||
)
|
||||
|
||||
transport = SSRFSafeTransport()
|
||||
@@ -301,7 +301,7 @@ async def test_unresolvable_host_blocked(monkeypatch: Any) -> None:
|
||||
monkeypatch.setattr(
|
||||
socket,
|
||||
"getaddrinfo",
|
||||
lambda *a, **kw: (_ for _ in ()).throw(
|
||||
lambda *_, **__: (_ for _ in ()).throw(
|
||||
socket.gaierror("Name or service not known")
|
||||
),
|
||||
)
|
||||
@@ -388,7 +388,7 @@ def test_sync_transport_blocks_private_resolution() -> None:
|
||||
def test_sync_transport_redirect_to_private_blocked(monkeypatch: Any) -> None:
|
||||
call_count = 0
|
||||
|
||||
def _routing_addrinfo(*args: Any, **kwargs: Any) -> list[Any]:
|
||||
def _routing_addrinfo(*_args: Any, **__kwargs: Any) -> list[Any]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
@@ -397,7 +397,7 @@ def test_sync_transport_redirect_to_private_blocked(monkeypatch: Any) -> None:
|
||||
|
||||
monkeypatch.setattr(socket, "getaddrinfo", _routing_addrinfo)
|
||||
|
||||
def _redirect_responder(request: httpx.Request) -> httpx.Response:
|
||||
def _redirect_responder(_request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(
|
||||
302,
|
||||
headers={"Location": "http://evil.com/pwned"},
|
||||
|
||||
Reference in New Issue
Block a user