feat(core): harden anti-ssrf (#35960)

harden anti-ssrf

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Eugene Yurtsev
2026-03-18 10:41:43 -04:00
committed by GitHub
parent cf1f510d77
commit dd136337d7
2 changed files with 67 additions and 11 deletions

View File

@@ -49,10 +49,21 @@ PRIVATE_IP_RANGES = [
]
# Cloud provider metadata endpoints
CLOUD_METADATA_RANGES = [
ipaddress.ip_network(
"169.254.0.0/16"
), # IPv4 link-local (used by metadata services)
]
CLOUD_METADATA_IPS = [
"169.254.169.254", # AWS, GCP, Azure, DigitalOcean, Oracle Cloud
"169.254.170.2", # AWS ECS task metadata
"169.254.170.23", # AWS EKS Pod Identity Agent
"100.100.100.200", # Alibaba Cloud metadata
"fd00:ec2::254", # AWS EC2 IMDSv2 over IPv6 (Nitro instances)
"fd00:ec2::23", # AWS EKS Pod Identity Agent (IPv6)
"fe80::a9fe:a9fe", # OpenStack Nova metadata (IPv6 link-local equiv of
# 169.254.169.254)
]
CLOUD_METADATA_HOSTNAMES = [
@@ -68,6 +79,21 @@ LOCALHOST_NAMES = [
]
def _normalize_ip(ip_str: str) -> str:
"""Normalize IP strings for consistent SSRF checks.
Args:
ip_str: IP address as a string.
Returns:
Canonical string form, converting IPv6-mapped IPv4 to plain IPv4.
"""
ip = ipaddress.ip_address(ip_str)
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
return str(ip.ipv4_mapped)
return str(ip)
def is_private_ip(ip_str: str) -> bool:
"""Check if an IP address is in a private range.
@@ -78,7 +104,7 @@ def is_private_ip(ip_str: str) -> bool:
True if IP is in a private range, False otherwise
"""
try:
ip = ipaddress.ip_address(ip_str)
ip = ipaddress.ip_address(_normalize_ip(ip_str))
return any(ip in range_ for range_ in PRIVATE_IP_RANGES)
except ValueError:
return False
@@ -99,8 +125,17 @@ def is_cloud_metadata(hostname: str, ip_str: str | None = None) -> bool:
return True
# Check IP
if ip_str and ip_str in CLOUD_METADATA_IPS: # noqa: SIM103
return True
if ip_str:
try:
normalized_ip = _normalize_ip(ip_str)
if normalized_ip in CLOUD_METADATA_IPS:
return True
ip = ipaddress.ip_address(normalized_ip)
if any(ip in range_ for range_ in CLOUD_METADATA_RANGES):
return True
except ValueError:
pass
return False
@@ -122,12 +157,13 @@ def is_localhost(hostname: str, ip_str: str | None = None) -> bool:
# Check IP
if ip_str:
try:
ip = ipaddress.ip_address(ip_str)
normalized_ip = _normalize_ip(ip_str)
ip = ipaddress.ip_address(normalized_ip)
# Check if loopback
if ip.is_loopback:
return True
# Also check common localhost IPs
if ip_str in ("127.0.0.1", "::1", "0.0.0.0"): # noqa: S104
if normalized_ip in ("127.0.0.1", "::1", "0.0.0.0"): # noqa: S104
return True
except ValueError:
pass
@@ -225,20 +261,21 @@ def validate_safe_url(
for result in addr_info:
ip_str: str = result[4][0] # type: ignore[assignment]
normalized_ip = _normalize_ip(ip_str)
# ALWAYS block cloud metadata IPs
if is_cloud_metadata(hostname, ip_str):
msg = f"URL resolves to cloud metadata IP: {ip_str}"
if is_cloud_metadata(hostname, normalized_ip):
msg = f"URL resolves to cloud metadata IP: {normalized_ip}"
raise ValueError(msg)
# Check for localhost IPs
if is_localhost(hostname, ip_str) and not allow_private:
msg = f"URL resolves to localhost IP: {ip_str}"
if is_localhost(hostname, normalized_ip) and not allow_private:
msg = f"URL resolves to localhost IP: {normalized_ip}"
raise ValueError(msg)
# Check for private IPs
if not allow_private and is_private_ip(ip_str):
msg = f"URL resolves to private IP address: {ip_str}"
if not allow_private and is_private_ip(normalized_ip):
msg = f"URL resolves to private IP address: {normalized_ip}"
raise ValueError(msg)
except socket.gaierror as e:

View File

@@ -50,7 +50,16 @@ class TestIPValidation:
"""Test cloud metadata IP detection."""
assert is_cloud_metadata("example.com", "169.254.169.254") is True
assert is_cloud_metadata("example.com", "169.254.170.2") is True
assert is_cloud_metadata("example.com", "169.254.170.23") is True
assert is_cloud_metadata("example.com", "100.100.100.200") is True
assert is_cloud_metadata("example.com", "fd00:ec2::254") is True
assert is_cloud_metadata("example.com", "fd00:ec2::23") is True
assert is_cloud_metadata("example.com", "fe80::a9fe:a9fe") is True
def test_is_cloud_metadata_link_local_range(self) -> None:
"""Test that IPv4 link-local is flagged as cloud metadata."""
assert is_cloud_metadata("example.com", "169.254.1.2") is True
assert is_cloud_metadata("example.com", "169.254.255.254") is True
def test_is_cloud_metadata_hostnames(self) -> None:
"""Test cloud metadata hostname detection."""
@@ -143,6 +152,16 @@ class TestValidateSafeUrl:
allow_private=True,
)
def test_ipv6_mapped_ipv4_localhost_blocked(self) -> None:
"""Test that IPv6-mapped IPv4 localhost is blocked."""
with pytest.raises(ValueError, match="localhost"):
validate_safe_url("http://[::ffff:127.0.0.1]:8080/webhook")
def test_ipv6_mapped_ipv4_cloud_metadata_blocked(self) -> None:
"""Test that IPv6-mapped IPv4 cloud metadata is blocked."""
with pytest.raises(ValueError, match="metadata"):
validate_safe_url("http://[::ffff:169.254.169.254]/latest/meta-data/")
def test_invalid_scheme_blocked(self) -> None:
"""Test that non-HTTP(S) schemes are blocked."""
with pytest.raises(ValueError, match="scheme"):