mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-03 19:04:23 +00:00
feat(core): harden anti-ssrf (#35960)
Harden the anti-SSRF checks. Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
@@ -49,10 +49,21 @@ PRIVATE_IP_RANGES = [
|
||||
]
|
||||
|
||||
# Cloud provider metadata endpoints, expressed as whole networks.
CLOUD_METADATA_RANGES = [
    # IPv4 link-local (used by metadata services)
    ipaddress.ip_network("169.254.0.0/16"),
]
|
||||
|
||||
# Individual metadata-service addresses, matched by exact string comparison.
CLOUD_METADATA_IPS = [
    # AWS, GCP, Azure, DigitalOcean, Oracle Cloud
    "169.254.169.254",
    # AWS ECS task metadata
    "169.254.170.2",
    # AWS EKS Pod Identity Agent
    "169.254.170.23",
    # Alibaba Cloud metadata
    "100.100.100.200",
    # AWS EC2 IMDSv2 over IPv6 (Nitro instances)
    "fd00:ec2::254",
    # AWS EKS Pod Identity Agent (IPv6)
    "fd00:ec2::23",
    # OpenStack Nova metadata (IPv6 link-local equiv of 169.254.169.254)
    "fe80::a9fe:a9fe",
]
|
||||
|
||||
CLOUD_METADATA_HOSTNAMES = [
|
||||
@@ -68,6 +79,21 @@ LOCALHOST_NAMES = [
|
||||
]
|
||||
|
||||
|
||||
def _normalize_ip(ip_str: str) -> str:
|
||||
"""Normalize IP strings for consistent SSRF checks.
|
||||
|
||||
Args:
|
||||
ip_str: IP address as a string.
|
||||
|
||||
Returns:
|
||||
Canonical string form, converting IPv6-mapped IPv4 to plain IPv4.
|
||||
"""
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
|
||||
return str(ip.ipv4_mapped)
|
||||
return str(ip)
|
||||
|
||||
|
||||
def is_private_ip(ip_str: str) -> bool:
|
||||
"""Check if an IP address is in a private range.
|
||||
|
||||
@@ -78,7 +104,7 @@ def is_private_ip(ip_str: str) -> bool:
|
||||
True if IP is in a private range, False otherwise
|
||||
"""
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
ip = ipaddress.ip_address(_normalize_ip(ip_str))
|
||||
return any(ip in range_ for range_ in PRIVATE_IP_RANGES)
|
||||
except ValueError:
|
||||
return False
|
||||
@@ -99,8 +125,17 @@ def is_cloud_metadata(hostname: str, ip_str: str | None = None) -> bool:
|
||||
return True
|
||||
|
||||
# Check IP
|
||||
if ip_str and ip_str in CLOUD_METADATA_IPS: # noqa: SIM103
|
||||
return True
|
||||
if ip_str:
|
||||
try:
|
||||
normalized_ip = _normalize_ip(ip_str)
|
||||
if normalized_ip in CLOUD_METADATA_IPS:
|
||||
return True
|
||||
|
||||
ip = ipaddress.ip_address(normalized_ip)
|
||||
if any(ip in range_ for range_ in CLOUD_METADATA_RANGES):
|
||||
return True
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
@@ -122,12 +157,13 @@ def is_localhost(hostname: str, ip_str: str | None = None) -> bool:
|
||||
# Check IP
|
||||
if ip_str:
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
normalized_ip = _normalize_ip(ip_str)
|
||||
ip = ipaddress.ip_address(normalized_ip)
|
||||
# Check if loopback
|
||||
if ip.is_loopback:
|
||||
return True
|
||||
# Also check common localhost IPs
|
||||
if ip_str in ("127.0.0.1", "::1", "0.0.0.0"): # noqa: S104
|
||||
if normalized_ip in ("127.0.0.1", "::1", "0.0.0.0"): # noqa: S104
|
||||
return True
|
||||
except ValueError:
|
||||
pass
|
||||
@@ -225,20 +261,21 @@ def validate_safe_url(
|
||||
|
||||
for result in addr_info:
|
||||
ip_str: str = result[4][0] # type: ignore[assignment]
|
||||
normalized_ip = _normalize_ip(ip_str)
|
||||
|
||||
# ALWAYS block cloud metadata IPs
|
||||
if is_cloud_metadata(hostname, ip_str):
|
||||
msg = f"URL resolves to cloud metadata IP: {ip_str}"
|
||||
if is_cloud_metadata(hostname, normalized_ip):
|
||||
msg = f"URL resolves to cloud metadata IP: {normalized_ip}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check for localhost IPs
|
||||
if is_localhost(hostname, ip_str) and not allow_private:
|
||||
msg = f"URL resolves to localhost IP: {ip_str}"
|
||||
if is_localhost(hostname, normalized_ip) and not allow_private:
|
||||
msg = f"URL resolves to localhost IP: {normalized_ip}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check for private IPs
|
||||
if not allow_private and is_private_ip(ip_str):
|
||||
msg = f"URL resolves to private IP address: {ip_str}"
|
||||
if not allow_private and is_private_ip(normalized_ip):
|
||||
msg = f"URL resolves to private IP address: {normalized_ip}"
|
||||
raise ValueError(msg)
|
||||
|
||||
except socket.gaierror as e:
|
||||
|
||||
@@ -50,7 +50,16 @@ class TestIPValidation:
|
||||
"""Test cloud metadata IP detection."""
|
||||
assert is_cloud_metadata("example.com", "169.254.169.254") is True
|
||||
assert is_cloud_metadata("example.com", "169.254.170.2") is True
|
||||
assert is_cloud_metadata("example.com", "169.254.170.23") is True
|
||||
assert is_cloud_metadata("example.com", "100.100.100.200") is True
|
||||
assert is_cloud_metadata("example.com", "fd00:ec2::254") is True
|
||||
assert is_cloud_metadata("example.com", "fd00:ec2::23") is True
|
||||
assert is_cloud_metadata("example.com", "fe80::a9fe:a9fe") is True
|
||||
|
||||
def test_is_cloud_metadata_link_local_range(self) -> None:
    """Addresses inside 169.254.0.0/16 are reported as cloud metadata."""
    for link_local_ip in ("169.254.1.2", "169.254.255.254"):
        assert is_cloud_metadata("example.com", link_local_ip) is True
|
||||
|
||||
def test_is_cloud_metadata_hostnames(self) -> None:
|
||||
"""Test cloud metadata hostname detection."""
|
||||
@@ -143,6 +152,16 @@ class TestValidateSafeUrl:
|
||||
allow_private=True,
|
||||
)
|
||||
|
||||
def test_ipv6_mapped_ipv4_localhost_blocked(self) -> None:
    """A loopback host written as an IPv4-mapped IPv6 literal is rejected."""
    mapped_loopback_url = "http://[::ffff:127.0.0.1]:8080/webhook"
    with pytest.raises(ValueError, match="localhost"):
        validate_safe_url(mapped_loopback_url)
|
||||
|
||||
def test_ipv6_mapped_ipv4_cloud_metadata_blocked(self) -> None:
    """A metadata IP written as an IPv4-mapped IPv6 literal is rejected."""
    mapped_metadata_url = "http://[::ffff:169.254.169.254]/latest/meta-data/"
    with pytest.raises(ValueError, match="metadata"):
        validate_safe_url(mapped_metadata_url)
|
||||
|
||||
def test_invalid_scheme_blocked(self) -> None:
|
||||
"""Test that non-HTTP(S) schemes are blocked."""
|
||||
with pytest.raises(ValueError, match="scheme"):
|
||||
|
||||
Reference in New Issue
Block a user