From 7d601dc2c6fd58b5dc97b013ba93c287e09490db Mon Sep 17 00:00:00 2001 From: ccurme Date: Wed, 15 Apr 2026 16:13:20 -0400 Subject: [PATCH] chore(core): harden private SSRF utilities (#36768) --- .../core/langchain_core/_security/__init__.py | 36 ++ .../langchain_core/_security/_exceptions.py | 9 + libs/core/langchain_core/_security/_policy.py | 290 +++++++++++++ .../_security/_ssrf_protection.py | 329 ++------------- .../langchain_core/_security/_transport.py | 252 ++++++++++++ libs/core/pyproject.toml | 4 +- .../runnables/test_runnable_events_v2.py | 4 +- .../unit_tests/test_ssrf_policy_transport.py | 387 ++++++++++++++++++ .../tests/unit_tests/test_ssrf_protection.py | 92 +---- 9 files changed, 1029 insertions(+), 374 deletions(-) create mode 100644 libs/core/langchain_core/_security/_exceptions.py create mode 100644 libs/core/langchain_core/_security/_policy.py create mode 100644 libs/core/langchain_core/_security/_transport.py create mode 100644 libs/core/tests/unit_tests/test_ssrf_policy_transport.py diff --git a/libs/core/langchain_core/_security/__init__.py b/libs/core/langchain_core/_security/__init__.py index e69de29bb2d..4a6fd41939f 100644 --- a/libs/core/langchain_core/_security/__init__.py +++ b/libs/core/langchain_core/_security/__init__.py @@ -0,0 +1,36 @@ +"""SSRF protection and security utilities. + +This is an **internal** module (note the `_security` prefix). It is NOT part of +the public `langchain-core` API and may change or be removed at any time without +notice. External code should not import from or depend on anything in this +module. Any vulnerability reports should target the public APIs that use these +utilities, not this internal module directly. +""" + +from langchain_core._security._exceptions import SSRFBlockedError +from langchain_core._security._policy import ( + SSRFPolicy, + validate_hostname, + validate_resolved_ip, + validate_url, + validate_url_sync, +) +from langchain_core._security._transport import ( + SSRFSafeSyncTransport, + SSRFSafeTransport, + ssrf_safe_async_client, + ssrf_safe_client, +) + +__all__ = [ + "SSRFBlockedError", + "SSRFPolicy", + "SSRFSafeSyncTransport", + "SSRFSafeTransport", + "ssrf_safe_async_client", + "ssrf_safe_client", + "validate_hostname", + "validate_resolved_ip", + "validate_url", + "validate_url_sync", +] diff --git a/libs/core/langchain_core/_security/_exceptions.py b/libs/core/langchain_core/_security/_exceptions.py new file mode 100644 index 00000000000..6046a4829a8 --- /dev/null +++ b/libs/core/langchain_core/_security/_exceptions.py @@ -0,0 +1,9 @@ +"""SSRF protection exceptions.""" + + +class SSRFBlockedError(Exception): + """Raised when a request is blocked by SSRF protection policy.""" + + def __init__(self, reason: str) -> None: + self.reason = reason + super().__init__(f"SSRF blocked: {reason}") diff --git a/libs/core/langchain_core/_security/_policy.py b/libs/core/langchain_core/_security/_policy.py new file mode 100644 index 00000000000..0440918bbc5 --- /dev/null +++ b/libs/core/langchain_core/_security/_policy.py @@ -0,0 +1,290 @@ +"""SSRF protection policy with IP validation and DNS-aware URL checking.""" + +import asyncio +import dataclasses +import ipaddress +import os +import socket +import urllib.parse + +from langchain_core._security._exceptions import SSRFBlockedError + +# --------------------------------------------------------------------------- +# Blocklist constants +# --------------------------------------------------------------------------- + +_BLOCKED_IPV4_NETWORKS: tuple[ipaddress.IPv4Network, ...] = tuple( + ipaddress.IPv4Network(n) + for n in ( + "10.0.0.0/8", # RFC 1918 - private class A + "172.16.0.0/12", # RFC 1918 - private class B + "192.168.0.0/16", # RFC 1918 - private class C + "127.0.0.0/8", # RFC 1122 - loopback + "169.254.0.0/16", # RFC 3927 - link-local + "0.0.0.0/8", # RFC 1122 - "this network" + "100.64.0.0/10", # RFC 6598 - shared/CGN address space + "192.0.0.0/24", # RFC 6890 - IETF protocol assignments + "192.0.2.0/24", # RFC 5737 - TEST-NET-1 (documentation) + "198.18.0.0/15", # RFC 2544 - benchmarking + "198.51.100.0/24", # RFC 5737 - TEST-NET-2 (documentation) + "203.0.113.0/24", # RFC 5737 - TEST-NET-3 (documentation) + "224.0.0.0/4", # RFC 5771 - multicast + "240.0.0.0/4", # RFC 1112 - reserved for future use + "255.255.255.255/32", # RFC 919 - limited broadcast + ) +) + +_BLOCKED_IPV6_NETWORKS: tuple[ipaddress.IPv6Network, ...] = tuple( + ipaddress.IPv6Network(n) + for n in ( + "::1/128", # RFC 4291 - loopback + "fc00::/7", # RFC 4193 - unique local addresses (ULA) + "fe80::/10", # RFC 4291 - link-local + "ff00::/8", # RFC 4291 - multicast + "::ffff:0:0/96", # RFC 4291 - IPv4-mapped IPv6 addresses + "::0.0.0.0/96", # RFC 4291 - IPv4-compatible IPv6 (deprecated) + "64:ff9b::/96", # RFC 6052 - NAT64 well-known prefix + "64:ff9b:1::/48", # RFC 8215 - NAT64 discovery prefix + ) +) + +_CLOUD_METADATA_IPS: frozenset[str] = frozenset( + { + "169.254.169.254", + "169.254.170.2", + "100.100.100.200", + "fd00:ec2::254", + } +) + +_CLOUD_METADATA_HOSTNAMES: frozenset[str] = frozenset( + { + "metadata.google.internal", + "metadata.amazonaws.com", + "metadata", + "instance-data", + } +) + +_LOCALHOST_NAMES: frozenset[str] = frozenset( + { + "localhost", + "localhost.localdomain", + "host.docker.internal", + } +) + +_K8S_SUFFIX = ".svc.cluster.local" + +_LOOPBACK_IPV4 = ipaddress.IPv4Network("127.0.0.0/8") +_LOOPBACK_IPV6 = ipaddress.IPv6Address("::1") + +# NAT64 well-known prefixes +_NAT64_PREFIX = ipaddress.IPv6Network("64:ff9b::/96") +_NAT64_DISCOVERY_PREFIX = ipaddress.IPv6Network("64:ff9b:1::/48") + + +# --------------------------------------------------------------------------- +# SSRFPolicy +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True) +class SSRFPolicy: + """Immutable policy controlling which URLs/IPs are considered safe.""" + + allowed_schemes: frozenset[str] = frozenset({"http", "https"}) + block_private_ips: bool = True + block_localhost: bool = True + block_cloud_metadata: bool = True + block_k8s_internal: bool = True + allowed_hosts: frozenset[str] = frozenset() + additional_blocked_cidrs: tuple[ + ipaddress.IPv4Network | ipaddress.IPv6Network, ... + ] = () + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _extract_embedded_ipv4( + addr: ipaddress.IPv6Address, +) -> ipaddress.IPv4Address | None: + """Extract an embedded IPv4 from IPv4-mapped or NAT64 IPv6 addresses.""" + # Check ipv4_mapped first (covers ::ffff:x.x.x.x) + if addr.ipv4_mapped is not None: + return addr.ipv4_mapped + + # Check NAT64 prefixes — embedded IPv4 is in the last 4 bytes + if addr in _NAT64_PREFIX or addr in _NAT64_DISCOVERY_PREFIX: + raw = addr.packed + return ipaddress.IPv4Address(raw[-4:]) + + return None + + +def _ip_in_blocked_networks( + addr: ipaddress.IPv4Address | ipaddress.IPv6Address, + policy: SSRFPolicy, +) -> str | None: + """Return a reason string if *addr* falls in a blocked range, else None.""" + # NOTE: if profiling shows this is a hot path, consider memoising with + # @functools.lru_cache (key on (addr, id(policy))). + if isinstance(addr, ipaddress.IPv4Address): + if policy.block_private_ips: + for net in _BLOCKED_IPV4_NETWORKS: + if addr in net: + return "private IP range" + for net in policy.additional_blocked_cidrs: # type: ignore[assignment] + if isinstance(net, ipaddress.IPv4Network) and addr in net: + return "blocked CIDR" + else: + if policy.block_private_ips: + for net in _BLOCKED_IPV6_NETWORKS: # type: ignore[assignment] + if addr in net: + return "private IP range" + for net in policy.additional_blocked_cidrs: # type: ignore[assignment] + if isinstance(net, ipaddress.IPv6Network) and addr in net: + return "blocked CIDR" + + # Loopback check — independent of block_private_ips so that + # block_localhost=True still catches 127.x.x.x / ::1 even when + # private IPs are allowed. + if policy.block_localhost: + if isinstance(addr, ipaddress.IPv4Address) and ( + addr in _LOOPBACK_IPV4 or addr in ipaddress.IPv4Network("0.0.0.0/8") + ): + return "localhost address" + if isinstance(addr, ipaddress.IPv6Address) and addr == _LOOPBACK_IPV6: + return "localhost address" + + # Cloud metadata IP check + if policy.block_cloud_metadata and str(addr) in _CLOUD_METADATA_IPS: + return "cloud metadata endpoint" + + return None + + +# --------------------------------------------------------------------------- +# Public validation functions +# --------------------------------------------------------------------------- + + +def validate_resolved_ip(ip_str: str, policy: SSRFPolicy) -> None: + """Validate a resolved IP address against the SSRF policy. + + Raises SSRFBlockedError if the IP is blocked. + """ + try: + addr = ipaddress.ip_address(ip_str) + except ValueError as exc: + raise SSRFBlockedError("invalid IP address") from exc + + if isinstance(addr, ipaddress.IPv6Address): + inner = _extract_embedded_ipv4(addr) + if inner is not None: + addr = inner + + reason = _ip_in_blocked_networks(addr, policy) + if reason is not None: + raise SSRFBlockedError(reason) + + +def validate_hostname(hostname: str, policy: SSRFPolicy) -> None: + """Validate a hostname against the SSRF policy. + + Raises SSRFBlockedError if the hostname is blocked. + """ + lower = hostname.lower() + + if policy.block_localhost and lower in _LOCALHOST_NAMES: + raise SSRFBlockedError("localhost address") + + if policy.block_cloud_metadata and lower in _CLOUD_METADATA_HOSTNAMES: + raise SSRFBlockedError("cloud metadata endpoint") + + if policy.block_k8s_internal and lower.endswith(_K8S_SUFFIX): + raise SSRFBlockedError("Kubernetes internal DNS") + + +def _effective_allowed_hosts(policy: SSRFPolicy) -> frozenset[str]: + """Return allowed_hosts, augmented for local environments.""" + extra: set[str] = set() + if os.environ.get("LANGCHAIN_ENV", "").startswith("local"): + extra.update({"localhost", "testserver"}) + if extra: + return policy.allowed_hosts | frozenset(extra) + return policy.allowed_hosts + + +async def validate_url(url: str, policy: SSRFPolicy = SSRFPolicy()) -> None: + """Validate a URL against the SSRF policy, including DNS resolution. + + This is the primary entry-point for async code paths. It delegates + scheme/hostname/allowed-hosts checks to ``validate_url_sync``, then + resolves DNS and validates every resolved IP. + + Raises: + SSRFBlockedError: If the URL violates the policy. + """ + parsed = urllib.parse.urlparse(url) + hostname = parsed.hostname or "" + + validate_url_sync(url, policy) + + allowed = {h.lower() for h in _effective_allowed_hosts(policy)} + if hostname.lower() in allowed: + return + + scheme = (parsed.scheme or "").lower() + port = parsed.port or (443 if scheme == "https" else 80) + try: + addrinfo = await asyncio.to_thread( + socket.getaddrinfo, hostname, port, type=socket.SOCK_STREAM + ) + except socket.gaierror as exc: + msg = "DNS resolution failed" + raise SSRFBlockedError(msg) from exc + + for _family, _type, _proto, _canonname, sockaddr in addrinfo: + validate_resolved_ip(str(sockaddr[0]), policy) + + +def validate_url_sync(url: str, policy: SSRFPolicy = SSRFPolicy()) -> None: + """Synchronous URL validation (no DNS resolution). + + Suitable for Pydantic validators and other sync contexts. Checks scheme + and hostname patterns only - use ``validate_url`` for full DNS-aware checking. + + Raises: + SSRFBlockedError: If the URL violates the policy. + """ + parsed = urllib.parse.urlparse(url) + + scheme = (parsed.scheme or "").lower() + if scheme not in policy.allowed_schemes: + msg = f"scheme '{scheme}' not allowed" + raise SSRFBlockedError(msg) + + hostname = parsed.hostname + if not hostname: + msg = "missing hostname" + raise SSRFBlockedError(msg) + + allowed = _effective_allowed_hosts(policy) + if hostname.lower() in {h.lower() for h in allowed}: + return + + try: + ipaddress.ip_address(hostname) + validate_resolved_ip(hostname, policy) + except SSRFBlockedError: + raise + except ValueError: + pass + else: + return + + validate_hostname(hostname, policy) diff --git a/libs/core/langchain_core/_security/_ssrf_protection.py b/libs/core/langchain_core/_security/_ssrf_protection.py index 5c3df6994c9..e1f7b482587 100644 --- a/libs/core/langchain_core/_security/_ssrf_protection.py +++ b/libs/core/langchain_core/_security/_ssrf_protection.py @@ -1,28 +1,8 @@ -"""SSRF Protection for validating URLs against Server-Side Request Forgery attacks. +"""SSRF Protection - thin wrapper raising ValueError for internal callers. -This module provides utilities to validate user-provided URLs and prevent SSRF attacks -by blocking requests to: -- Private IP ranges (RFC 1918, loopback, link-local) -- Cloud metadata endpoints (AWS, GCP, Azure, etc.) -- Localhost addresses -- Invalid URL schemes - -Usage: - from lc_security.ssrf_protection import validate_safe_url, is_safe_url - - # Validate a URL (raises ValueError if unsafe) - safe_url = validate_safe_url("https://example.com/webhook") - - # Check if URL is safe (returns bool) - if is_safe_url("http://192.168.1.1"): - # URL is safe - pass - - # Allow private IPs for development/testing (still blocks cloud metadata) - safe_url = validate_safe_url("http://localhost:8080", allow_private=True) +Delegates all validation to `langchain_core._security._policy`. """ -import ipaddress import os import socket from typing import Annotated, Any @@ -34,141 +14,28 @@ from pydantic import ( HttpUrl, ) -# Private IP ranges (RFC 1918, RFC 4193, RFC 3927, loopback) -PRIVATE_IP_RANGES = [ - ipaddress.ip_network("10.0.0.0/8"), # Private Class A - ipaddress.ip_network("172.16.0.0/12"), # Private Class B - ipaddress.ip_network("192.168.0.0/16"), # Private Class C - ipaddress.ip_network("127.0.0.0/8"), # Loopback - ipaddress.ip_network("169.254.0.0/16"), # Link-local (includes cloud metadata) - ipaddress.ip_network("0.0.0.0/8"), # Current network - ipaddress.ip_network("::1/128"), # IPv6 loopback - ipaddress.ip_network("fc00::/7"), # IPv6 unique local - ipaddress.ip_network("fe80::/10"), # IPv6 link-local - ipaddress.ip_network("ff00::/8"), # IPv6 multicast -] - -# Cloud provider metadata endpoints -CLOUD_METADATA_RANGES = [ - ipaddress.ip_network( - "169.254.0.0/16" - ), # IPv4 link-local (used by metadata services) -] - -CLOUD_METADATA_IPS = [ - "169.254.169.254", # AWS, GCP, Azure, DigitalOcean, Oracle Cloud - "169.254.170.2", # AWS ECS task metadata - "169.254.170.23", # AWS EKS Pod Identity Agent - "100.100.100.200", # Alibaba Cloud metadata - "fd00:ec2::254", # AWS EC2 IMDSv2 over IPv6 (Nitro instances) - "fd00:ec2::23", # AWS EKS Pod Identity Agent (IPv6) - "fe80::a9fe:a9fe", # OpenStack Nova metadata (IPv6 link-local equiv of - # 169.254.169.254) -] - -CLOUD_METADATA_HOSTNAMES = [ - "metadata.google.internal", # GCP - "metadata", # Generic - "instance-data", # AWS EC2 -] - -# Localhost variations -LOCALHOST_NAMES = [ - "localhost", - "localhost.localdomain", -] +from langchain_core._security._exceptions import SSRFBlockedError +from langchain_core._security._policy import ( + SSRFPolicy, +) +from langchain_core._security._policy import ( + validate_resolved_ip as _validate_resolved_ip, +) +from langchain_core._security._policy import ( + validate_url_sync as _validate_url_sync, +) -def _normalize_ip(ip_str: str) -> str: - """Normalize IP strings for consistent SSRF checks. - - Args: - ip_str: IP address as a string. - - Returns: - Canonical string form, converting IPv6-mapped IPv4 to plain IPv4. - """ - ip = ipaddress.ip_address(ip_str) - if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None: - return str(ip.ipv4_mapped) - return str(ip) - - -def is_private_ip(ip_str: str) -> bool: - """Check if an IP address is in a private range. - - Args: - ip_str: IP address as a string (e.g., "192.168.1.1") - - Returns: - True if IP is in a private range, False otherwise - """ - try: - ip = ipaddress.ip_address(_normalize_ip(ip_str)) - return any(ip in range_ for range_ in PRIVATE_IP_RANGES) - except ValueError: - return False - - -def is_cloud_metadata(hostname: str, ip_str: str | None = None) -> bool: - """Check if hostname or IP is a cloud metadata endpoint. - - Args: - hostname: Hostname to check - ip_str: Optional IP address to check - - Returns: - True if hostname or IP is a known cloud metadata endpoint - """ - # Check hostname - if hostname.lower() in CLOUD_METADATA_HOSTNAMES: - return True - - # Check IP - if ip_str: - try: - normalized_ip = _normalize_ip(ip_str) - if normalized_ip in CLOUD_METADATA_IPS: - return True - - ip = ipaddress.ip_address(normalized_ip) - if any(ip in range_ for range_ in CLOUD_METADATA_RANGES): - return True - except ValueError: - pass - - return False - - -def is_localhost(hostname: str, ip_str: str | None = None) -> bool: - """Check if hostname or IP is localhost. - - Args: - hostname: Hostname to check - ip_str: Optional IP address to check - - Returns: - True if hostname or IP is localhost - """ - # Check hostname - if hostname.lower() in LOCALHOST_NAMES: - return True - - # Check IP - if ip_str: - try: - normalized_ip = _normalize_ip(ip_str) - ip = ipaddress.ip_address(normalized_ip) - # Check if loopback - if ip.is_loopback: - return True - # Also check common localhost IPs - if normalized_ip in ("127.0.0.1", "::1", "0.0.0.0"): # noqa: S104 - return True - except ValueError: - pass - - return False +def _policy_for(*, allow_private: bool, allow_http: bool) -> SSRFPolicy: + """Build an `SSRFPolicy` from the legacy flag interface.""" + schemes = frozenset({"http", "https"}) if allow_http else frozenset({"https"}) + return SSRFPolicy( + allowed_schemes=schemes, + block_private_ips=not allow_private, + block_localhost=not allow_private, + block_cloud_metadata=True, + block_k8s_internal=True, + ) def validate_safe_url( @@ -183,54 +50,22 @@ def validate_safe_url( by blocking requests to private networks and cloud metadata endpoints. Args: - url: The URL to validate (string or Pydantic HttpUrl) - allow_private: If True, allows private IPs and localhost (for development). + url: The URL to validate (string or Pydantic HttpUrl). + allow_private: If ``True``, allows private IPs and localhost (for development). Cloud metadata endpoints are ALWAYS blocked. - allow_http: If True, allows both HTTP and HTTPS. If False, only HTTPS. + allow_http: If ``True``, allows both HTTP and HTTPS. If ``False``, only HTTPS. Returns: - The validated URL as a string + The validated URL as a string. Raises: - ValueError: If URL is invalid or potentially dangerous - - Examples: - >>> validate_safe_url("https://hooks.slack.com/services/xxx") - 'https://hooks.slack.com/services/xxx' - - >>> validate_safe_url("http://127.0.0.1:8080") - ValueError: Localhost URLs are not allowed - - >>> validate_safe_url("http://192.168.1.1") - ValueError: URL resolves to private IP: 192.168.1.1 - - >>> validate_safe_url("http://169.254.169.254/latest/meta-data/") - ValueError: URL resolves to cloud metadata IP: 169.254.169.254 - - >>> validate_safe_url("http://localhost:8080", allow_private=True) - 'http://localhost:8080' + ValueError: If URL is invalid or potentially dangerous. """ url_str = str(url) parsed = urlparse(url_str) + hostname = parsed.hostname or "" - # Validate URL scheme - if not allow_http and parsed.scheme != "https": - msg = "Only HTTPS URLs are allowed" - raise ValueError(msg) - - if parsed.scheme not in ("http", "https"): - msg = f"Only HTTP/HTTPS URLs are allowed, got scheme: {parsed.scheme}" - raise ValueError(msg) - - # Extract hostname - hostname = parsed.hostname - if not hostname: - msg = "URL must have a valid hostname" - raise ValueError(msg) - - # Special handling for test environments - allow test server hostnames - # testserver is used by FastAPI/Starlette test clients and doesn't resolve via DNS - # Only enabled when LANGCHAIN_ENV=local_test (set in conftest.py) + # Test-environment bypass (preserved from original implementation) if ( os.environ.get("LANGCHAIN_ENV") == "local_test" and hostname.startswith("test") @@ -238,52 +73,34 @@ def validate_safe_url( ): return url_str - # ALWAYS block cloud metadata endpoints (even with allow_private=True) - if is_cloud_metadata(hostname): - msg = f"Cloud metadata endpoints are not allowed: {hostname}" - raise ValueError(msg) + policy = _policy_for(allow_private=allow_private, allow_http=allow_http) - # Check for localhost - if is_localhost(hostname) and not allow_private: - msg = f"Localhost URLs are not allowed: {hostname}" - raise ValueError(msg) - - # Resolve hostname to IP addresses and validate each one. - # Note: DNS resolution results are cached by the OS, so repeated calls are fast. + # Synchronous scheme + hostname checks + try: + _validate_url_sync(url_str, policy) + except SSRFBlockedError as exc: + raise ValueError(str(exc)) from exc + + # DNS resolution and IP validation try: - # Get all IP addresses for this hostname addr_info = socket.getaddrinfo( hostname, parsed.port or (443 if parsed.scheme == "https" else 80), - socket.AF_UNSPEC, # Allow both IPv4 and IPv6 + socket.AF_UNSPEC, socket.SOCK_STREAM, ) for result in addr_info: ip_str: str = result[4][0] # type: ignore[assignment] - normalized_ip = _normalize_ip(ip_str) - - # ALWAYS block cloud metadata IPs - if is_cloud_metadata(hostname, normalized_ip): - msg = f"URL resolves to cloud metadata IP: {normalized_ip}" - raise ValueError(msg) - - # Check for localhost IPs - if is_localhost(hostname, normalized_ip) and not allow_private: - msg = f"URL resolves to localhost IP: {normalized_ip}" - raise ValueError(msg) - - # Check for private IPs - if not allow_private and is_private_ip(normalized_ip): - msg = f"URL resolves to private IP address: {normalized_ip}" - raise ValueError(msg) + try: + _validate_resolved_ip(ip_str, policy) + except SSRFBlockedError as exc: + raise ValueError(str(exc)) from exc except socket.gaierror as e: - # DNS resolution failed - fail closed for security msg = f"Failed to resolve hostname '{hostname}': {e}" raise ValueError(msg) from e except OSError as e: - # Other network errors - fail closed msg = f"Network error while validating URL: {e}" raise ValueError(msg) from e @@ -296,26 +113,7 @@ def is_safe_url( allow_private: bool = False, allow_http: bool = True, ) -> bool: - """Check if a URL is safe (non-throwing version of validate_safe_url). - - Args: - url: The URL to check - allow_private: If True, allows private IPs and localhost - allow_http: If True, allows both HTTP and HTTPS - - Returns: - True if URL is safe, False otherwise - - Examples: - >>> is_safe_url("https://example.com") - True - - >>> is_safe_url("http://127.0.0.1:8080") - False - - >>> is_safe_url("http://localhost:8080", allow_private=True) - True - """ + """Non-throwing version of `validate_safe_url`.""" try: validate_safe_url(url, allow_private=allow_private, allow_http=allow_http) except ValueError: @@ -332,7 +130,6 @@ def _validate_url_ssrf_strict(v: Any) -> Any: def _validate_url_ssrf_https_only(v: Any) -> Any: - """Validate URL for SSRF protection (HTTPS only, strict mode).""" if isinstance(v, str): validate_safe_url(v, allow_private=False, allow_http=False) return v @@ -347,52 +144,12 @@ def _validate_url_ssrf_relaxed(v: Any) -> Any: # Annotated types with SSRF protection SSRFProtectedUrl = Annotated[HttpUrl, BeforeValidator(_validate_url_ssrf_strict)] -"""A Pydantic HttpUrl type with built-in SSRF protection. - -This blocks private IPs, localhost, and cloud metadata endpoints. - -Example: - class WebhookSchema(BaseModel): - url: SSRFProtectedUrl # Automatically validated for SSRF - headers: dict[str, str] | None = None -""" - SSRFProtectedUrlRelaxed = Annotated[ HttpUrl, BeforeValidator(_validate_url_ssrf_relaxed) ] -"""A Pydantic HttpUrl with relaxed SSRF protection (allows private IPs). - -Use this for development/testing webhooks where localhost/private IPs are needed. -Cloud metadata endpoints are still blocked. - -Example: - class DevWebhookSchema(BaseModel): - url: SSRFProtectedUrlRelaxed # Allows localhost, blocks cloud metadata -""" - SSRFProtectedHttpsUrl = Annotated[ HttpUrl, BeforeValidator(_validate_url_ssrf_https_only) ] -"""A Pydantic HttpUrl with SSRF protection that only allows HTTPS. - -This blocks private IPs, localhost, cloud metadata endpoints, and HTTP URLs. - -Example: - class SecureWebhookSchema(BaseModel): - url: SSRFProtectedHttpsUrl # Only HTTPS, blocks private IPs -""" - SSRFProtectedHttpsUrlStr = Annotated[ str, BeforeValidator(_validate_url_ssrf_https_only) ] -"""A string type with SSRF protection that only allows HTTPS URLs. - -Same as SSRFProtectedHttpsUrl but returns a string instead of HttpUrl. -Useful for FastAPI query parameters where you need a string URL. - -Example: - @router.get("/proxy") - async def proxy_get(url: SSRFProtectedHttpsUrlStr): - async with httpx.AsyncClient() as client: - resp = await client.get(url) -""" diff --git a/libs/core/langchain_core/_security/_transport.py b/libs/core/langchain_core/_security/_transport.py new file mode 100644 index 00000000000..b6efc9ec017 --- /dev/null +++ b/libs/core/langchain_core/_security/_transport.py @@ -0,0 +1,252 @@ +"""SSRF-safe httpx transport with DNS resolution and IP pinning.""" + +import asyncio +import socket + +import httpx + +from langchain_core._security._exceptions import SSRFBlockedError +from langchain_core._security._policy import ( + SSRFPolicy, + _effective_allowed_hosts, + validate_resolved_ip, + validate_url_sync, +) + +# Keys that AsyncHTTPTransport accepts (forwarded from factory kwargs). +_TRANSPORT_KWARGS = frozenset( + { + "verify", + "cert", + "trust_env", + "http1", + "http2", + "limits", + "retries", + } +) + + +class SSRFSafeTransport(httpx.AsyncBaseTransport): + """httpx async transport that validates DNS results against an SSRF policy. + + For every outgoing request the transport: + 1. Checks the URL scheme against ``policy.allowed_schemes``. + 2. Validates the hostname against blocked patterns. + 3. Resolves DNS and validates **all** returned IPs. + 4. Rewrites the request to connect to the first valid IP while + preserving the original ``Host`` header and TLS SNI hostname. + + Redirects are re-validated on each hop because ``follow_redirects`` + is set on the *client*, causing ``handle_async_request`` to be called + again for each redirect target. + """ + + def __init__( + self, + policy: SSRFPolicy = SSRFPolicy(), + **transport_kwargs: object, + ) -> None: + self._policy = policy + self._inner = httpx.AsyncHTTPTransport(**transport_kwargs) # type: ignore[arg-type] + + # ------------------------------------------------------------------ # + # Core request handler + # ------------------------------------------------------------------ # + + async def handle_async_request( + self, + request: httpx.Request, + ) -> httpx.Response: + hostname = request.url.host or "" + scheme = request.url.scheme.lower() + + # 1-3. Scheme, hostname, and pattern checks (reuse sync validator). + try: + validate_url_sync(str(request.url), self._policy) + except SSRFBlockedError: + raise + + # Allowed-hosts bypass - skip DNS/IP validation entirely. + allowed = {h.lower() for h in _effective_allowed_hosts(self._policy)} + if hostname.lower() in allowed: + return await self._inner.handle_async_request(request) + + # 4. DNS resolution + port = request.url.port or (443 if scheme == "https" else 80) + try: + addrinfo = await asyncio.to_thread( + socket.getaddrinfo, + hostname, + port, + type=socket.SOCK_STREAM, + ) + except socket.gaierror as exc: + raise SSRFBlockedError("DNS resolution failed") from exc + + if not addrinfo: + raise SSRFBlockedError("DNS resolution returned no results") + + # 5. Validate ALL resolved IPs - any blocked means reject. + for _family, _type, _proto, _canonname, sockaddr in addrinfo: + ip_str: str = sockaddr[0] # type: ignore[assignment] + validate_resolved_ip(ip_str, self._policy) + + # 6. Pin to first resolved IP. + pinned_ip = addrinfo[0][4][0] + + # 7. Rewrite URL to use pinned IP, preserving Host header and SNI. + pinned_url = request.url.copy_with(host=pinned_ip) + + # Build extensions dict, adding sni_hostname for HTTPS so TLS + # certificate validation uses the original hostname. + extensions = dict(request.extensions) + if scheme == "https": + extensions["sni_hostname"] = hostname.encode("ascii") + + pinned_request = httpx.Request( + method=request.method, + url=pinned_url, + headers=request.headers, # Host header already set to original + content=request.content, + extensions=extensions, + ) + + return await self._inner.handle_async_request(pinned_request) + + # ------------------------------------------------------------------ # + # Lifecycle + # ------------------------------------------------------------------ # + + async def aclose(self) -> None: + await self._inner.aclose() + + +# ---------------------------------------------------------------------- # +# Factory +# ---------------------------------------------------------------------- # + + +class SSRFSafeSyncTransport(httpx.BaseTransport): + """httpx sync transport that validates DNS results against an SSRF policy. + + Sync mirror of `SSRFSafeTransport`. See that class for full documentation. + """ + + def __init__( + self, + policy: SSRFPolicy = SSRFPolicy(), + **transport_kwargs: object, + ) -> None: + self._policy = policy + self._inner = httpx.HTTPTransport(**transport_kwargs) # type: ignore[arg-type] + + def handle_request( + self, + request: httpx.Request, + ) -> httpx.Response: + hostname = request.url.host or "" + scheme = request.url.scheme.lower() + + validate_url_sync(str(request.url), self._policy) + + allowed = {h.lower() for h in _effective_allowed_hosts(self._policy)} + if hostname.lower() in allowed: + return self._inner.handle_request(request) + + port = request.url.port or (443 if scheme == "https" else 80) + try: + addrinfo = socket.getaddrinfo( + hostname, + port, + type=socket.SOCK_STREAM, + ) + except socket.gaierror as exc: + raise SSRFBlockedError("DNS resolution failed") from exc + + if not addrinfo: + raise SSRFBlockedError("DNS resolution returned no results") + + for _family, _type, _proto, _canonname, sockaddr in addrinfo: + ip_str: str = sockaddr[0] # type: ignore[assignment] + validate_resolved_ip(ip_str, self._policy) + + pinned_ip = addrinfo[0][4][0] + pinned_url = request.url.copy_with(host=pinned_ip) + + extensions = dict(request.extensions) + if scheme == "https": + extensions["sni_hostname"] = hostname.encode("ascii") + + pinned_request = httpx.Request( + method=request.method, + url=pinned_url, + headers=request.headers, + content=request.content, + extensions=extensions, + ) + + return self._inner.handle_request(pinned_request) + + def close(self) -> None: + self._inner.close() + + +# ---------------------------------------------------------------------- # +# Factories +# ---------------------------------------------------------------------- # + + +def ssrf_safe_client( + policy: SSRFPolicy = SSRFPolicy(), + **kwargs: object, +) -> httpx.Client: + """Create an `httpx.Client` with SSRF protection.""" + transport_kwargs: dict[str, object] = {} + client_kwargs: dict[str, object] = {} + for key, value in kwargs.items(): + if key in _TRANSPORT_KWARGS: + transport_kwargs[key] = value + else: + client_kwargs[key] = value + + transport = SSRFSafeSyncTransport(policy=policy, **transport_kwargs) + + client_kwargs.setdefault("follow_redirects", True) + client_kwargs.setdefault("max_redirects", 10) + + return httpx.Client( + transport=transport, + **client_kwargs, # type: ignore[arg-type] + ) + + +def ssrf_safe_async_client( + policy: SSRFPolicy = SSRFPolicy(), + **kwargs: object, +) -> httpx.AsyncClient: + """Create an ``httpx.AsyncClient`` with SSRF protection. + + Drop-in replacement for ``httpx.AsyncClient(...)`` - callers just swap + the constructor call. Transport-specific kwargs (``verify``, ``cert``, + ``retries``, etc.) are forwarded to the inner ``AsyncHTTPTransport``; + everything else goes to the ``AsyncClient``. + """ + transport_kwargs: dict[str, object] = {} + client_kwargs: dict[str, object] = {} + for key, value in kwargs.items(): + if key in _TRANSPORT_KWARGS: + transport_kwargs[key] = value + else: + client_kwargs[key] = value + + transport = SSRFSafeTransport(policy=policy, **transport_kwargs) + + # Apply defaults only if not overridden by caller. + client_kwargs.setdefault("follow_redirects", True) + client_kwargs.setdefault("max_redirects", 10) + + return httpx.AsyncClient( + transport=transport, + **client_kwargs, # type: ignore[arg-type] + ) diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 933d912fc65..38be808b783 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -135,8 +135,10 @@ ignore-var-parameters = true # ignore missing documentation for *args and **kwa "langchain_core/utils/mustache.py" = [ "PLW0603",] "langchain_core/sys_info.py" = [ "T201",] "tests/unit_tests/test_tools.py" = [ "ARG",] -"tests/**" = [ "D1", "PLR2004", "S", "SLF",] +"tests/**" = [ "ARG", "D1", "PLR2004", "S", "SLF",] "scripts/**" = [ "INP", "S", "T201",] +"langchain_core/_security/_policy.py" = [ "EM101", "EM102", "TRY003", "B008", "TRY300",] +"langchain_core/_security/_transport.py" = [ "EM101", "EM102", "TRY003", "TRY203", "B008",] [tool.coverage.run] omit = [ "tests/*",] diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 49342d2fc17..f36dc87cbd7 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -2843,7 +2843,7 @@ async def test_tool_error_event_includes_tool_call_id() -> None: """Test that on_tool_error event includes tool_call_id when provided.""" @tool - def failing_tool(x: int) -> str: # noqa: ARG001 + def failing_tool(x: int) -> str: """A tool that always fails.""" msg = "Tool execution failed" raise ValueError(msg) @@ -2883,7 +2883,7 @@ async def test_tool_error_event_tool_call_id_is_none_when_not_provided() -> None """Test that on_tool_error event has tool_call_id=None when not provided.""" @tool - def failing_tool_no_id(x: int) -> str: # noqa: ARG001 + def failing_tool_no_id(x: int) -> str: """A tool that always fails.""" msg = "Tool execution failed" raise ValueError(msg) diff --git a/libs/core/tests/unit_tests/test_ssrf_policy_transport.py b/libs/core/tests/unit_tests/test_ssrf_policy_transport.py new file mode 100644 index 00000000000..2a905c1d2bf --- /dev/null +++ b/libs/core/tests/unit_tests/test_ssrf_policy_transport.py @@ -0,0 +1,387 @@ +import socket +from typing import Any +from unittest.mock import patch + +import httpx +import pytest + +from langchain_core._security import ( + SSRFBlockedError, + SSRFPolicy, + SSRFSafeSyncTransport, + SSRFSafeTransport, + ssrf_safe_async_client, + ssrf_safe_client, + validate_hostname, + validate_resolved_ip, + validate_url_sync, +) + + +def _fake_addrinfo(ip: str, port: int = 80) -> list[Any]: + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, port))] + + +def _fake_addrinfo_v6(ip: str, port: int = 80) -> list[Any]: + return [(socket.AF_INET6, socket.SOCK_STREAM, 6, "", (ip, port, 0, 0))] + + +def _ok_response(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, text="ok") + + +def test_validate_resolved_ip_blocks_nat64_embedded_private_ip() -> None: + policy = SSRFPolicy() + + with pytest.raises(SSRFBlockedError, match="private IP range"): + validate_resolved_ip("64:ff9b::c0a8:101", policy) + + +def test_validate_resolved_ip_blocks_cgnat() -> None: + policy = SSRFPolicy() + + with pytest.raises(SSRFBlockedError, match="private IP range"): + validate_resolved_ip("100.64.0.1", policy) + + +def test_validate_hostname_blocks_kubernetes_internal_dns() -> None: + policy = SSRFPolicy() + + with pytest.raises(SSRFBlockedError, match="Kubernetes internal DNS"): + validate_hostname("api.default.svc.cluster.local", policy) + + +def test_validate_url_sync_allows_explicit_allowed_host() -> None: + policy = SSRFPolicy(allowed_hosts=frozenset({"metadata.google.internal"})) + + validate_url_sync("http://metadata.google.internal/path", policy) + + +def test_validate_url_sync_blocks_metadata_without_allowlist() -> None: + policy = SSRFPolicy() + + with pytest.raises(SSRFBlockedError, match="cloud metadata endpoint"): + validate_url_sync("http://metadata.google.internal/path", policy) + + +class _RecordingAsyncTransport(httpx.AsyncBaseTransport): + def __init__(self) -> None: + self.requests: list[httpx.Request] = [] + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + self.requests.append(request) + return httpx.Response(200, request=request, text="ok") + + async def aclose(self) -> None: + return None + + +@pytest.mark.asyncio +async def test_ssrf_safe_transport_pins_ip_and_sets_sni() -> None: + transport = SSRFSafeTransport() + recorder = _RecordingAsyncTransport() + transport._inner = recorder # type: ignore[assignment] + + addrinfo = [ + ( + socket.AF_INET, + socket.SOCK_STREAM, + 6, + "", + ("93.184.216.34", 443), + ) + ] + + with patch( + "langchain_core._security._transport.socket.getaddrinfo", + return_value=addrinfo, + ): + request = httpx.Request("GET", "https://example.com/resource") + response = await transport.handle_async_request(request) + + assert response.status_code == 200 + assert len(recorder.requests) == 1 + pinned_request = recorder.requests[0] + assert pinned_request.url.host == "93.184.216.34" + assert pinned_request.headers["host"] == "example.com" + assert pinned_request.extensions["sni_hostname"] == b"example.com" + + +@pytest.mark.asyncio +async def test_ssrf_safe_transport_blocks_private_resolution() -> None: + transport = SSRFSafeTransport() + + addrinfo = [ + ( + socket.AF_INET, + socket.SOCK_STREAM, + 6, + "", + ("127.0.0.1", 443), + ) + ] + + with patch( + "langchain_core._security._transport.socket.getaddrinfo", + return_value=addrinfo, + ): + request = httpx.Request("GET", "https://example.com/resource") + with pytest.raises(SSRFBlockedError, match="private IP range"): + await transport.handle_async_request(request) + + +@pytest.mark.asyncio +async def test_ssrf_safe_async_client_sets_redirect_defaults() -> None: + client = ssrf_safe_async_client() + try: + assert client.follow_redirects is True + assert client.max_redirects == 10 + finally: + await client.aclose() + + +# --------------------------------------------------------------------------- +# Policy toggle: block_private_ips=False still blocks loopback/metadata/k8s +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "url", + [ + "http://10.0.0.1:8080/api", + "http://172.16.0.1:3000/", + "http://192.168.1.100/webhook", + ], +) +def test_private_ip_allowed_when_block_disabled(url: str) -> None: + policy = SSRFPolicy(block_private_ips=False) + validate_url_sync(url, policy) + + +@pytest.mark.parametrize( + "url", + [ + "http://127.0.0.1:8080/", + "http://127.0.0.2/", + "http://[::1]:8080/", + ], +) +def test_loopback_still_blocked_when_private_ips_allowed(url: str) -> None: + policy = SSRFPolicy(block_private_ips=False) + with pytest.raises(SSRFBlockedError): + validate_url_sync(url, policy) + + +def test_docker_internal_blocked() -> None: + policy = SSRFPolicy() + with pytest.raises(SSRFBlockedError, match="localhost"): + validate_url_sync("http://host.docker.internal:8080/", policy) + + +def test_metadata_still_blocked_when_private_ips_allowed() -> None: + policy = SSRFPolicy(block_private_ips=False) + with pytest.raises(SSRFBlockedError): + validate_url_sync("http://metadata.google.internal/", policy) + + +def test_k8s_still_blocked_when_private_ips_allowed() -> None: + policy = SSRFPolicy(block_private_ips=False) + with pytest.raises(SSRFBlockedError): + validate_url_sync("http://myservice.default.svc.cluster.local/", policy) + + +# --------------------------------------------------------------------------- +# Transport: redirect to private IP blocked +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_redirect_to_private_ip_blocked(monkeypatch: Any) -> None: + call_count = 0 + + def _routing_addrinfo(*args: Any, **kwargs: Any) -> list[Any]: + nonlocal call_count + call_count += 1 + if call_count == 1: + return _fake_addrinfo("93.184.216.34") + return _fake_addrinfo("127.0.0.1") + + monkeypatch.setattr(socket, "getaddrinfo", _routing_addrinfo) + + def _redirect_responder(request: httpx.Request) -> httpx.Response: + return httpx.Response( + 302, + headers={"Location": "http://evil.com/pwned"}, + ) + + transport = SSRFSafeTransport() + transport._inner = httpx.MockTransport(_redirect_responder) # type: ignore[assignment] + + client = httpx.AsyncClient( + transport=transport, + follow_redirects=True, + max_redirects=5, + ) + + with pytest.raises(SSRFBlockedError): + await client.get("http://safe.com/start") + + await client.aclose() + + +# --------------------------------------------------------------------------- +# Transport: IPv6-mapped IPv4, scheme rejection, DNS fail-closed +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ipv6_mapped_ipv4_blocked(monkeypatch: Any) -> None: + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda *a, **kw: _fake_addrinfo_v6("::ffff:127.0.0.1"), + ) + + transport = SSRFSafeTransport() + request = httpx.Request("GET", "http://evil.com/") + with pytest.raises(SSRFBlockedError): + await transport.handle_async_request(request) + + +@pytest.mark.asyncio +async def test_scheme_blocked() -> None: + transport = SSRFSafeTransport() + request = httpx.Request("GET", "ftp://evil.com/file") + with pytest.raises(SSRFBlockedError, match="scheme"): + await transport.handle_async_request(request) + + +@pytest.mark.asyncio +async def test_unresolvable_host_blocked(monkeypatch: Any) -> None: + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda *a, **kw: (_ for _ in ()).throw( + socket.gaierror("Name or service not known") + ), + ) + + transport = SSRFSafeTransport() + request = httpx.Request("GET", "http://nonexistent.invalid/") + with pytest.raises(SSRFBlockedError, match="DNS resolution failed"): + await transport.handle_async_request(request) + + +# --------------------------------------------------------------------------- +# Transport: allowed_hosts bypass and local env behavior +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_allowed_host_bypass() -> None: + policy = SSRFPolicy(allowed_hosts=frozenset({"special.host"})) + transport = SSRFSafeTransport(policy=policy) + transport._inner = httpx.MockTransport(_ok_response) # type: ignore[assignment] + + request = httpx.Request("GET", "http://special.host/api") + response = await transport.handle_async_request(request) + assert response.status_code == 200 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("env", ["local_dev", "local_test", "local_docker"]) +async def test_localhost_allowed_in_local_env(monkeypatch: Any, env: str) -> None: + monkeypatch.setenv("LANGCHAIN_ENV", env) + transport = SSRFSafeTransport() + transport._inner = httpx.MockTransport(_ok_response) # type: ignore[assignment] + + request = httpx.Request("GET", "http://localhost:8084/mcp") + response = await transport.handle_async_request(request) + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_localhost_blocked_in_production(monkeypatch: Any) -> None: + monkeypatch.setenv("LANGCHAIN_ENV", "production") + transport = SSRFSafeTransport() + + request = httpx.Request("GET", "http://localhost:8084/mcp") + with pytest.raises(SSRFBlockedError): + await transport.handle_async_request(request) + + +# --------------------------------------------------------------------------- +# Sync transport tests +# --------------------------------------------------------------------------- + + +def test_sync_transport_pins_ip_and_sets_sni() -> None: + transport = SSRFSafeSyncTransport() + transport._inner = httpx.MockTransport(_ok_response) # type: ignore[assignment] + + addrinfo = [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", 443))] + + with patch( + "langchain_core._security._transport.socket.getaddrinfo", + return_value=addrinfo, + ): + request = httpx.Request("GET", "https://example.com/resource") + response = transport.handle_request(request) + + assert response.status_code == 200 + + +def test_sync_transport_blocks_private_resolution() -> None: + transport = SSRFSafeSyncTransport() + + addrinfo = [(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("127.0.0.1", 443))] + + with patch( + "langchain_core._security._transport.socket.getaddrinfo", + return_value=addrinfo, + ): + request = httpx.Request("GET", "https://example.com/resource") + with pytest.raises(SSRFBlockedError, match="private IP range"): + transport.handle_request(request) + + +def test_sync_transport_redirect_to_private_blocked(monkeypatch: Any) -> None: + call_count = 0 + + def _routing_addrinfo(*args: Any, **kwargs: Any) -> list[Any]: + nonlocal call_count + call_count += 1 + if call_count == 1: + return _fake_addrinfo("93.184.216.34") + return _fake_addrinfo("127.0.0.1") + + monkeypatch.setattr(socket, "getaddrinfo", _routing_addrinfo) + + def _redirect_responder(request: httpx.Request) -> httpx.Response: + return httpx.Response( + 302, + headers={"Location": "http://evil.com/pwned"}, + ) + + transport = SSRFSafeSyncTransport() + transport._inner = httpx.MockTransport(_redirect_responder) # type: ignore[assignment] + + client = httpx.Client( + transport=transport, + follow_redirects=True, + max_redirects=5, + ) + + with pytest.raises(SSRFBlockedError): + client.get("http://safe.com/start") + + client.close() + + +def test_ssrf_safe_client_sets_redirect_defaults() -> None: + client = ssrf_safe_client() + try: + assert client.follow_redirects is True + assert client.max_redirects == 10 + finally: + client.close() diff --git a/libs/core/tests/unit_tests/test_ssrf_protection.py b/libs/core/tests/unit_tests/test_ssrf_protection.py index ea0768025bb..6330034ff50 100644 --- a/libs/core/tests/unit_tests/test_ssrf_protection.py +++ b/libs/core/tests/unit_tests/test_ssrf_protection.py @@ -8,89 +8,11 @@ from pydantic import BaseModel, ValidationError from langchain_core._security._ssrf_protection import ( SSRFProtectedUrl, SSRFProtectedUrlRelaxed, - is_cloud_metadata, - is_localhost, - is_private_ip, is_safe_url, validate_safe_url, ) -class TestIPValidation: - """Tests for IP address validation functions.""" - - def test_is_private_ip_ipv4(self) -> None: - """Test private IPv4 address detection.""" - assert is_private_ip("10.0.0.1") is True - assert is_private_ip("172.16.0.1") is True - assert is_private_ip("192.168.1.1") is True - assert is_private_ip("127.0.0.1") is True - assert is_private_ip("169.254.169.254") is True - assert is_private_ip("0.0.0.1") is True - - def test_is_private_ip_ipv6(self) -> None: - """Test private IPv6 address detection.""" - assert is_private_ip("::1") is True # Loopback - assert is_private_ip("fc00::1") is True # Unique local - assert is_private_ip("fe80::1") is True # Link-local - assert is_private_ip("ff00::1") is True # Multicast - - def test_is_private_ip_public(self) -> None: - """Test that public IPs are not flagged as private.""" - assert is_private_ip("8.8.8.8") is False - assert is_private_ip("1.1.1.1") is False - assert is_private_ip("151.101.1.140") is False - - def test_is_private_ip_invalid(self) -> None: - """Test handling of invalid IP addresses.""" - assert is_private_ip("not-an-ip") is False - assert is_private_ip("999.999.999.999") is False - - def test_is_cloud_metadata_ips(self) -> None: - """Test cloud metadata IP detection.""" - assert is_cloud_metadata("example.com", "169.254.169.254") is True - assert is_cloud_metadata("example.com", "169.254.170.2") is True - assert is_cloud_metadata("example.com", "169.254.170.23") is True - assert is_cloud_metadata("example.com", "100.100.100.200") is True - assert is_cloud_metadata("example.com", "fd00:ec2::254") is True - assert is_cloud_metadata("example.com", "fd00:ec2::23") is True - assert is_cloud_metadata("example.com", "fe80::a9fe:a9fe") is True - - def test_is_cloud_metadata_link_local_range(self) -> None: - """Test that IPv4 link-local is flagged as cloud metadata.""" - assert is_cloud_metadata("example.com", "169.254.1.2") is True - assert is_cloud_metadata("example.com", "169.254.255.254") is True - - def test_is_cloud_metadata_hostnames(self) -> None: - """Test cloud metadata hostname detection.""" - assert is_cloud_metadata("metadata.google.internal") is True - assert is_cloud_metadata("metadata") is True - assert is_cloud_metadata("instance-data") is True - assert is_cloud_metadata("METADATA.GOOGLE.INTERNAL") is True # Case insensitive - - def test_is_cloud_metadata_safe(self) -> None: - """Test that normal URLs are not flagged as cloud metadata.""" - assert is_cloud_metadata("example.com", "8.8.8.8") is False - assert is_cloud_metadata("google.com") is False - - def test_is_localhost_hostnames(self) -> None: - """Test localhost hostname detection.""" - assert is_localhost("localhost") is True - assert is_localhost("LOCALHOST") is True - assert is_localhost("localhost.localdomain") is True - - def test_is_localhost_ips(self) -> None: - """Test localhost IP detection.""" - assert is_localhost("example.com", "127.0.0.1") is True - assert is_localhost("example.com", "::1") is True - assert is_localhost("example.com", "0.0.0.0") is True - - def test_is_localhost_safe(self) -> None: - """Test that normal hosts are not flagged as localhost.""" - assert is_localhost("example.com", "8.8.8.8") is False - assert is_localhost("google.com") is False - - class TestValidateSafeUrl: """Tests for validate_safe_url function.""" @@ -108,10 +30,10 @@ class TestValidateSafeUrl: def test_localhost_blocked_by_default(self) -> None: """Test that localhost URLs are blocked by default.""" - with pytest.raises(ValueError, match="Localhost"): + with pytest.raises(ValueError, match="localhost"): validate_safe_url("http://localhost:8080/webhook") - with pytest.raises(ValueError, match="localhost"): + with pytest.raises(ValueError, match="private IP"): validate_safe_url("http://127.0.0.1:8080/webhook") def test_localhost_allowed_with_flag(self) -> None: @@ -142,11 +64,11 @@ class TestValidateSafeUrl: def test_cloud_metadata_always_blocked(self) -> None: """Test that cloud metadata endpoints are always blocked.""" - with pytest.raises(ValueError, match="metadata"): + with pytest.raises(ValueError, match="SSRF blocked"): validate_safe_url("http://169.254.169.254/latest/meta-data/") # Even with allow_private=True - with pytest.raises(ValueError, match="metadata"): + with pytest.raises(ValueError, match="SSRF blocked"): validate_safe_url( "http://169.254.169.254/latest/meta-data/", allow_private=True, @@ -154,12 +76,12 @@ class TestValidateSafeUrl: def test_ipv6_mapped_ipv4_localhost_blocked(self) -> None: """Test that IPv6-mapped IPv4 localhost is blocked.""" - with pytest.raises(ValueError, match="localhost"): + with pytest.raises(ValueError, match="SSRF blocked"): validate_safe_url("http://[::ffff:127.0.0.1]:8080/webhook") def test_ipv6_mapped_ipv4_cloud_metadata_blocked(self) -> None: """Test that IPv6-mapped IPv4 cloud metadata is blocked.""" - with pytest.raises(ValueError, match="metadata"): + with pytest.raises(ValueError, match="SSRF blocked"): validate_safe_url("http://[::ffff:169.254.169.254]/latest/meta-data/") def test_invalid_scheme_blocked(self) -> None: @@ -175,7 +97,7 @@ class TestValidateSafeUrl: def test_https_only_mode(self) -> None: """Test that HTTP is blocked when allow_http=False.""" - with pytest.raises(ValueError, match="HTTPS"): + with pytest.raises(ValueError, match="scheme"): validate_safe_url("http://example.com/webhook", allow_http=False) # HTTPS should still work