mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
fix(openai): sanitize urls when counting tokens in images (#35143)
This commit is contained in:
0
libs/core/langchain_core/_security/__init__.py
Normal file
0
libs/core/langchain_core/_security/__init__.py
Normal file
361
libs/core/langchain_core/_security/_ssrf_protection.py
Normal file
361
libs/core/langchain_core/_security/_ssrf_protection.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""SSRF Protection for validating URLs against Server-Side Request Forgery attacks.
|
||||
|
||||
This module provides utilities to validate user-provided URLs and prevent SSRF attacks
|
||||
by blocking requests to:
|
||||
- Private IP ranges (RFC 1918, loopback, link-local)
|
||||
- Cloud metadata endpoints (AWS, GCP, Azure, etc.)
|
||||
- Localhost addresses
|
||||
- Invalid URL schemes
|
||||
|
||||
Usage:
|
||||
from langchain_core._security._ssrf_protection import validate_safe_url, is_safe_url
|
||||
|
||||
# Validate a URL (raises ValueError if unsafe)
|
||||
safe_url = validate_safe_url("https://example.com/webhook")
|
||||
|
||||
# Check if URL is safe (returns bool)
|
||||
if is_safe_url("https://example.com/webhook"):
|
||||
# URL is safe
|
||||
pass
|
||||
|
||||
# Allow private IPs for development/testing (still blocks cloud metadata)
|
||||
safe_url = validate_safe_url("http://localhost:8080", allow_private=True)
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import os
|
||||
import socket
|
||||
from typing import Annotated, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import (
|
||||
AnyHttpUrl,
|
||||
BeforeValidator,
|
||||
HttpUrl,
|
||||
)
|
||||
|
||||
# Networks that count as "private" for SSRF purposes: RFC 1918 space,
# loopback, link-local and the IPv4 "this network" block, plus the IPv6
# loopback, unique-local, link-local and multicast ranges.
PRIVATE_IP_RANGES = [
    ipaddress.ip_network(cidr)
    for cidr in (
        "10.0.0.0/8",  # Private Class A
        "172.16.0.0/12",  # Private Class B
        "192.168.0.0/16",  # Private Class C
        "127.0.0.0/8",  # Loopback
        "169.254.0.0/16",  # Link-local (includes cloud metadata)
        "0.0.0.0/8",  # Current network
        "::1/128",  # IPv6 loopback
        "fc00::/7",  # IPv6 unique local
        "fe80::/10",  # IPv6 link-local
        "ff00::/8",  # IPv6 multicast
    )
]

# IP addresses on which cloud providers expose instance metadata.
CLOUD_METADATA_IPS = [
    "169.254.169.254",  # AWS, GCP, Azure, DigitalOcean, Oracle Cloud
    "169.254.170.2",  # AWS ECS task metadata
    "100.100.100.200",  # Alibaba Cloud metadata
]

# Hostnames that refer to cloud metadata services (compared lower-cased).
CLOUD_METADATA_HOSTNAMES = [
    "metadata.google.internal",  # GCP
    "metadata",  # Generic
    "instance-data",  # AWS EC2
]

# Hostnames that always mean "this machine" (compared lower-cased).
LOCALHOST_NAMES = [
    "localhost",
    "localhost.localdomain",
]
|
||||
|
||||
|
||||
def is_private_ip(ip_str: str) -> bool:
    """Check if an IP address is in a private range.

    IPv4-mapped IPv6 addresses (e.g. `::ffff:10.0.0.1`) are unwrapped to the
    embedded IPv4 address before the lookup, so they are classified exactly
    like their plain IPv4 form.

    Args:
        ip_str: IP address as a string (e.g., "192.168.1.1")

    Returns:
        True if the IP is inside one of `PRIVATE_IP_RANGES`, False otherwise
        (including when `ip_str` is not a valid IP address).
    """
    try:
        ip = ipaddress.ip_address(ip_str)
    except ValueError:
        return False

    # An IPv6 address never matches an IPv4 network, so without unwrapping,
    # "::ffff:a.b.c.d" would slip past every IPv4 entry in the list.
    if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
        ip = ip.ipv4_mapped

    return any(ip in range_ for range_ in PRIVATE_IP_RANGES)
|
||||
|
||||
|
||||
def is_cloud_metadata(hostname: str, ip_str: str | None = None) -> bool:
    """Check if hostname or IP is a cloud metadata endpoint.

    The hostname is compared case-insensitively and without a trailing dot.
    The IP is parsed and compared in canonical form, with IPv4-mapped IPv6
    addresses (e.g. `::ffff:169.254.169.254`) reduced to their IPv4 form, so
    alternative spellings of a metadata address are still recognised.

    Args:
        hostname: Hostname to check
        ip_str: Optional IP address to check

    Returns:
        True if hostname or IP is a known cloud metadata endpoint
    """
    # A trailing dot only marks the name as fully qualified.
    if hostname.lower().rstrip(".") in CLOUD_METADATA_HOSTNAMES:
        return True

    if not ip_str:
        return False

    try:
        ip = ipaddress.ip_address(ip_str)
    except ValueError:
        # Not parseable as an address: keep the plain string comparison.
        return ip_str in CLOUD_METADATA_IPS

    if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
        ip = ip.ipv4_mapped

    return str(ip) in CLOUD_METADATA_IPS
|
||||
|
||||
|
||||
def is_localhost(hostname: str, ip_str: str | None = None) -> bool:
    """Check if hostname or IP is localhost.

    Besides loopback addresses, the unspecified addresses `0.0.0.0` and `::`
    are treated as localhost. IPv4-mapped IPv6 addresses (e.g.
    `::ffff:127.0.0.1`) are unwrapped to their IPv4 form before the check.

    Args:
        hostname: Hostname to check
        ip_str: Optional IP address to check

    Returns:
        True if hostname or IP is localhost; False otherwise, including when
        `ip_str` is not a valid IP address.
    """
    # A trailing dot only marks the name as fully qualified.
    if hostname.lower().rstrip(".") in LOCALHOST_NAMES:
        return True

    if not ip_str:
        return False

    try:
        ip = ipaddress.ip_address(ip_str)
    except ValueError:
        return False

    # Interpreters before 3.13 do not look inside IPv4-mapped addresses when
    # evaluating is_loopback, so unwrap explicitly.
    if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
        ip = ip.ipv4_mapped

    return ip.is_loopback or ip.is_unspecified
|
||||
|
||||
|
||||
def validate_safe_url(
    url: str | AnyHttpUrl,
    *,
    allow_private: bool = False,
    allow_http: bool = True,
) -> str:
    """Validate a URL for SSRF protection.

    The URL's scheme and hostname are checked first, then the hostname is
    resolved and every returned address is checked. Resolution happens at
    validation time only; the value returned is the original URL string.

    Args:
        url: The URL to validate (string or Pydantic URL object).
        allow_private: If True, allows private IPs and localhost (for
            development). Cloud metadata endpoints are ALWAYS blocked.
        allow_http: If True, allows both HTTP and HTTPS. If False, only HTTPS.

    Returns:
        The validated URL as a string

    Raises:
        ValueError: If the URL is invalid, cannot be resolved, or points at a
            blocked destination.

    Examples:
        >>> validate_safe_url("https://hooks.slack.com/services/xxx")
        'https://hooks.slack.com/services/xxx'

        >>> validate_safe_url("http://127.0.0.1:8080")
        Traceback (most recent call last):
            ...
        ValueError: URL resolves to localhost IP: 127.0.0.1

        >>> validate_safe_url("http://192.168.1.1")
        Traceback (most recent call last):
            ...
        ValueError: URL resolves to private IP address: 192.168.1.1

        >>> validate_safe_url("http://169.254.169.254/latest/meta-data/")
        Traceback (most recent call last):
            ...
        ValueError: URL resolves to cloud metadata IP: 169.254.169.254

        >>> validate_safe_url("http://localhost:8080", allow_private=True)
        'http://localhost:8080'
    """
    url_str = str(url)
    parts = urlparse(url_str)
    scheme = parts.scheme

    # HTTPS-only mode is enforced before the general HTTP/HTTPS allow-list.
    if scheme != "https" and not allow_http:
        msg = "Only HTTPS URLs are allowed"
        raise ValueError(msg)
    if scheme not in ("http", "https"):
        msg = f"Only HTTP/HTTPS URLs are allowed, got scheme: {scheme}"
        raise ValueError(msg)

    hostname = parts.hostname
    if not hostname:
        msg = "URL must have a valid hostname"
        raise ValueError(msg)

    # Test-only escape hatch: hostnames such as "testserver" (used by
    # FastAPI/Starlette test clients) do not resolve via DNS. Only active when
    # LANGCHAIN_ENV=local_test.
    running_local_tests = os.environ.get("LANGCHAIN_ENV") == "local_test"
    if running_local_tests and hostname.startswith("test") and "server" in hostname:
        return url_str

    # Metadata hostnames are rejected regardless of allow_private.
    if is_cloud_metadata(hostname):
        msg = f"Cloud metadata endpoints are not allowed: {hostname}"
        raise ValueError(msg)

    if not allow_private and is_localhost(hostname):
        msg = f"Localhost URLs are not allowed: {hostname}"
        raise ValueError(msg)

    # Resolve the hostname (IPv4 and IPv6) and fail closed on any error.
    try:
        resolved = socket.getaddrinfo(
            hostname,
            parts.port or (443 if scheme == "https" else 80),
            socket.AF_UNSPEC,
            socket.SOCK_STREAM,
        )
    except socket.gaierror as e:
        msg = f"Failed to resolve hostname '{hostname}': {e}"
        raise ValueError(msg) from e
    except OSError as e:
        msg = f"Network error while validating URL: {e}"
        raise ValueError(msg) from e

    # Every resolved address must pass, not just the first one.
    for entry in resolved:
        ip_str: str = entry[4][0]  # type: ignore[assignment]

        # Metadata IPs are rejected regardless of allow_private.
        if is_cloud_metadata(hostname, ip_str):
            msg = f"URL resolves to cloud metadata IP: {ip_str}"
            raise ValueError(msg)

        if allow_private:
            continue

        if is_localhost(hostname, ip_str):
            msg = f"URL resolves to localhost IP: {ip_str}"
            raise ValueError(msg)

        if is_private_ip(ip_str):
            msg = f"URL resolves to private IP address: {ip_str}"
            raise ValueError(msg)

    return url_str
|
||||
|
||||
|
||||
def is_safe_url(
    url: str | AnyHttpUrl,
    *,
    allow_private: bool = False,
    allow_http: bool = True,
) -> bool:
    """Report whether a URL passes `validate_safe_url` without raising.

    Args:
        url: The URL to check
        allow_private: If True, allows private IPs and localhost
        allow_http: If True, allows both HTTP and HTTPS

    Returns:
        True if `validate_safe_url` accepts the URL, False if it raises
        `ValueError`.

    Examples:
        >>> is_safe_url("https://example.com")
        True

        >>> is_safe_url("http://127.0.0.1:8080")
        False

        >>> is_safe_url("http://localhost:8080", allow_private=True)
        True
    """
    try:
        validate_safe_url(url, allow_private=allow_private, allow_http=allow_http)
    except ValueError:
        return False
    return True
|
||||
|
||||
|
||||
def _validate_url_ssrf_strict(v: Any) -> Any:
    """Run strict SSRF validation on string input and return it unchanged.

    Strict means private IPs and localhost are rejected while HTTP is allowed.
    Non-string values are returned without any check.
    """
    if not isinstance(v, str):
        return v
    validate_safe_url(v, allow_private=False, allow_http=True)
    return v
|
||||
|
||||
|
||||
def _validate_url_ssrf_https_only(v: Any) -> Any:
    """Run strict, HTTPS-only SSRF validation on string input.

    Private IPs, localhost and plain HTTP are all rejected. Non-string values
    are returned without any check.
    """
    if not isinstance(v, str):
        return v
    validate_safe_url(v, allow_private=False, allow_http=False)
    return v
|
||||
|
||||
|
||||
def _validate_url_ssrf_relaxed(v: Any) -> Any:
    """Run relaxed SSRF validation on string input and return it unchanged.

    Relaxed means private IPs and localhost are allowed, as is HTTP.
    Non-string values are returned without any check.
    """
    if not isinstance(v, str):
        return v
    validate_safe_url(v, allow_private=True, allow_http=True)
    return v
|
||||
|
||||
|
||||
# Annotated types with SSRF protection.
# Each alias attaches one of the `_validate_url_ssrf_*` functions as a
# `BeforeValidator`; those functions only inspect `str` input.
SSRFProtectedUrl = Annotated[HttpUrl, BeforeValidator(_validate_url_ssrf_strict)]
"""A Pydantic HttpUrl type with built-in SSRF protection.

This blocks private IPs, localhost, and cloud metadata endpoints.

Example:
    class WebhookSchema(BaseModel):
        url: SSRFProtectedUrl  # Automatically validated for SSRF
        headers: dict[str, str] | None = None
"""

SSRFProtectedUrlRelaxed = Annotated[
    HttpUrl,
    BeforeValidator(_validate_url_ssrf_relaxed),
]
"""A Pydantic HttpUrl with relaxed SSRF protection (allows private IPs).

Use this for development/testing webhooks where localhost/private IPs are needed.
Cloud metadata endpoints are still blocked.

Example:
    class DevWebhookSchema(BaseModel):
        url: SSRFProtectedUrlRelaxed  # Allows localhost, blocks cloud metadata
"""

SSRFProtectedHttpsUrl = Annotated[
    HttpUrl,
    BeforeValidator(_validate_url_ssrf_https_only),
]
"""A Pydantic HttpUrl with SSRF protection that only allows HTTPS.

This blocks private IPs, localhost, cloud metadata endpoints, and HTTP URLs.

Example:
    class SecureWebhookSchema(BaseModel):
        url: SSRFProtectedHttpsUrl  # Only HTTPS, blocks private IPs
"""

SSRFProtectedHttpsUrlStr = Annotated[
    str,
    BeforeValidator(_validate_url_ssrf_https_only),
]
"""A string type with SSRF protection that only allows HTTPS URLs.

Same as SSRFProtectedHttpsUrl but returns a string instead of HttpUrl.
Useful for FastAPI query parameters where you need a string URL.

Example:
    @router.get("/proxy")
    async def proxy_get(url: SSRFProtectedHttpsUrlStr):
        async with httpx.AsyncClient() as client:
            resp = await client.get(url)
"""
|
||||
297
libs/core/tests/unit_tests/test_ssrf_protection.py
Normal file
297
libs/core/tests/unit_tests/test_ssrf_protection.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""Tests for SSRF protection utilities."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from langchain_core._security._ssrf_protection import (
|
||||
SSRFProtectedUrl,
|
||||
SSRFProtectedUrlRelaxed,
|
||||
is_cloud_metadata,
|
||||
is_localhost,
|
||||
is_private_ip,
|
||||
is_safe_url,
|
||||
validate_safe_url,
|
||||
)
|
||||
|
||||
|
||||
class TestIPValidation:
    """Unit tests for the IP/hostname classification helpers."""

    def test_is_private_ip_ipv4(self) -> None:
        """Private, loopback, link-local and 0.0.0.0/8 IPv4 are private."""
        private_v4 = (
            "10.0.0.1",
            "172.16.0.1",
            "192.168.1.1",
            "127.0.0.1",
            "169.254.169.254",
            "0.0.0.1",
        )
        for address in private_v4:
            assert is_private_ip(address) is True

    def test_is_private_ip_ipv6(self) -> None:
        """Loopback, unique-local, link-local and multicast IPv6 are private."""
        private_v6 = (
            "::1",  # Loopback
            "fc00::1",  # Unique local
            "fe80::1",  # Link-local
            "ff00::1",  # Multicast
        )
        for address in private_v6:
            assert is_private_ip(address) is True

    def test_is_private_ip_public(self) -> None:
        """Public addresses must not be classified as private."""
        for address in ("8.8.8.8", "1.1.1.1", "151.101.1.140"):
            assert is_private_ip(address) is False

    def test_is_private_ip_invalid(self) -> None:
        """Strings that are not IP addresses are reported as not private."""
        for bogus in ("not-an-ip", "999.999.999.999"):
            assert is_private_ip(bogus) is False

    def test_is_cloud_metadata_ips(self) -> None:
        """Known metadata IPs are detected even under an ordinary hostname."""
        metadata_ips = ("169.254.169.254", "169.254.170.2", "100.100.100.200")
        for address in metadata_ips:
            assert is_cloud_metadata("example.com", address) is True

    def test_is_cloud_metadata_hostnames(self) -> None:
        """Known metadata hostnames are detected, ignoring case."""
        metadata_hosts = (
            "metadata.google.internal",
            "metadata",
            "instance-data",
            "METADATA.GOOGLE.INTERNAL",  # Case insensitive
        )
        for host in metadata_hosts:
            assert is_cloud_metadata(host) is True

    def test_is_cloud_metadata_safe(self) -> None:
        """Ordinary hosts and IPs are not flagged as cloud metadata."""
        assert is_cloud_metadata("example.com", "8.8.8.8") is False
        assert is_cloud_metadata("google.com") is False

    def test_is_localhost_hostnames(self) -> None:
        """Localhost names are detected, ignoring case."""
        for host in ("localhost", "LOCALHOST", "localhost.localdomain"):
            assert is_localhost(host) is True

    def test_is_localhost_ips(self) -> None:
        """Loopback and 0.0.0.0 addresses are detected as localhost."""
        for address in ("127.0.0.1", "::1", "0.0.0.0"):
            assert is_localhost("example.com", address) is True

    def test_is_localhost_safe(self) -> None:
        """Ordinary hosts and IPs are not flagged as localhost."""
        assert is_localhost("example.com", "8.8.8.8") is False
        assert is_localhost("google.com") is False
|
||||
|
||||
|
||||
class TestValidateSafeUrl:
    """Tests for validate_safe_url function."""

    def test_valid_public_https_url(self) -> None:
        """Test that valid public HTTPS URLs are accepted."""
        url = "https://hooks.slack.com/services/xxx"
        result = validate_safe_url(url)
        assert result == url

    def test_valid_public_http_url(self) -> None:
        """Test that valid public HTTP URLs are accepted."""
        url = "http://example.com/webhook"
        result = validate_safe_url(url)
        assert result == url

    def test_localhost_blocked_by_default(self) -> None:
        """Test that localhost URLs are blocked by default."""
        # The hostname form is rejected by name ("Localhost URLs ...").
        with pytest.raises(ValueError, match="Localhost"):
            validate_safe_url("http://localhost:8080/webhook")

        # The IP form is rejected after resolution ("... localhost IP ...").
        with pytest.raises(ValueError, match="localhost"):
            validate_safe_url("http://127.0.0.1:8080/webhook")

    def test_localhost_allowed_with_flag(self) -> None:
        """Test that localhost is allowed with allow_private=True."""
        url = "http://localhost:8080/webhook"
        result = validate_safe_url(url, allow_private=True)
        assert result == url

        url = "http://127.0.0.1:8080/webhook"
        result = validate_safe_url(url, allow_private=True)
        assert result == url

    def test_private_ip_blocked_by_default(self) -> None:
        """Test that private IPs are blocked by default."""
        with pytest.raises(ValueError, match="private IP"):
            validate_safe_url("http://192.168.1.1/webhook")

        with pytest.raises(ValueError, match="private IP"):
            validate_safe_url("http://10.0.0.1/webhook")

        with pytest.raises(ValueError, match="private IP"):
            validate_safe_url("http://172.16.0.1/webhook")

    def test_private_ip_allowed_with_flag(self) -> None:
        """Test that private IPs are allowed with allow_private=True."""
        # IP literals are parsed by getaddrinfo without a DNS query, and
        # validation never opens a connection, so this runs offline.
        private_urls = (
            "http://192.168.1.1/webhook",
            "http://10.0.0.1/webhook",
            "http://172.16.0.1/webhook",
        )
        for url in private_urls:
            assert validate_safe_url(url, allow_private=True) == url

    def test_cloud_metadata_always_blocked(self) -> None:
        """Test that cloud metadata endpoints are always blocked."""
        with pytest.raises(ValueError, match="metadata"):
            validate_safe_url("http://169.254.169.254/latest/meta-data/")

        # Even with allow_private=True
        with pytest.raises(ValueError, match="metadata"):
            validate_safe_url(
                "http://169.254.169.254/latest/meta-data/",
                allow_private=True,
            )

    def test_invalid_scheme_blocked(self) -> None:
        """Test that non-HTTP(S) schemes are blocked."""
        with pytest.raises(ValueError, match="scheme"):
            validate_safe_url("ftp://example.com/file")

        with pytest.raises(ValueError, match="scheme"):
            validate_safe_url("file:///etc/passwd")

        with pytest.raises(ValueError, match="scheme"):
            validate_safe_url("javascript:alert(1)")

    def test_https_only_mode(self) -> None:
        """Test that HTTP is blocked when allow_http=False."""
        with pytest.raises(ValueError, match="HTTPS"):
            validate_safe_url("http://example.com/webhook", allow_http=False)

        # HTTPS should still work
        url = "https://example.com/webhook"
        result = validate_safe_url(url, allow_http=False)
        assert result == url

    def test_url_without_hostname(self) -> None:
        """Test that URLs without hostname are rejected."""
        with pytest.raises(ValueError, match="hostname"):
            validate_safe_url("http:///path")

    def test_dns_resolution_failure(self) -> None:
        """Test handling of DNS resolution failures."""
        with pytest.raises(ValueError, match="resolve"):
            validate_safe_url("http://this-domain-definitely-does-not-exist-12345.com")

    def test_testserver_allowed(self, monkeypatch: Any) -> None:
        """Test that testserver hostname is allowed for test environments."""
        # testserver is used by FastAPI/Starlette test clients; the bypass is
        # only active while LANGCHAIN_ENV=local_test.
        monkeypatch.setenv("LANGCHAIN_ENV", "local_test")
        url = "http://testserver/webhook"
        result = validate_safe_url(url)
        assert result == url
|
||||
|
||||
|
||||
class TestIsSafeUrl:
    """Tests for is_safe_url, the boolean wrapper around validate_safe_url."""

    def test_safe_url_returns_true(self) -> None:
        """Public HTTP and HTTPS URLs are reported as safe."""
        safe_urls = (
            "https://example.com/webhook",
            "http://hooks.slack.com/services/xxx",
        )
        for candidate in safe_urls:
            assert is_safe_url(candidate) is True

    def test_unsafe_url_returns_false(self) -> None:
        """Localhost, private and metadata URLs are reported as unsafe."""
        unsafe_urls = (
            "http://localhost:8080",
            "http://127.0.0.1:8080",
            "http://192.168.1.1",
            "http://169.254.169.254",
        )
        for candidate in unsafe_urls:
            assert is_safe_url(candidate) is False

    def test_unsafe_url_safe_with_allow_private(self) -> None:
        """Localhost URLs become safe when allow_private=True."""
        for candidate in ("http://localhost:8080", "http://127.0.0.1:8080"):
            assert is_safe_url(candidate, allow_private=True) is True

    def test_cloud_metadata_always_unsafe(self) -> None:
        """The metadata IP is unsafe with and without allow_private."""
        metadata_url = "http://169.254.169.254"
        assert is_safe_url(metadata_url) is False
        assert is_safe_url(metadata_url, allow_private=True) is False
|
||||
|
||||
|
||||
class TestSSRFProtectedUrlType:
    """Tests for the strict SSRFProtectedUrl Pydantic type."""

    def test_valid_url_accepted(self) -> None:
        """A public HTTPS URL passes model validation."""

        class WebhookSchema(BaseModel):
            url: SSRFProtectedUrl

        model = WebhookSchema(url="https://hooks.slack.com/services/xxx")
        rendered = str(model.url)
        assert rendered.startswith("https://hooks.slack.com/")

    def test_localhost_rejected(self) -> None:
        """A localhost URL fails model validation."""

        class WebhookSchema(BaseModel):
            url: SSRFProtectedUrl

        with pytest.raises(ValidationError):
            WebhookSchema(url="http://localhost:8080")

    def test_private_ip_rejected(self) -> None:
        """A private-IP URL fails model validation."""

        class WebhookSchema(BaseModel):
            url: SSRFProtectedUrl

        with pytest.raises(ValidationError):
            WebhookSchema(url="http://192.168.1.1")

    def test_cloud_metadata_rejected(self) -> None:
        """A cloud metadata URL fails model validation."""

        class WebhookSchema(BaseModel):
            url: SSRFProtectedUrl

        with pytest.raises(ValidationError):
            WebhookSchema(url="http://169.254.169.254/latest/meta-data/")
|
||||
|
||||
|
||||
class TestSSRFProtectedUrlRelaxedType:
    """Tests for the relaxed SSRFProtectedUrlRelaxed Pydantic type."""

    def test_localhost_accepted(self) -> None:
        """A localhost URL passes validation on the relaxed type."""

        class WebhookSchema(BaseModel):
            url: SSRFProtectedUrlRelaxed

        model = WebhookSchema(url="http://localhost:8080")
        rendered = str(model.url)
        assert rendered.startswith("http://localhost")

    def test_cloud_metadata_still_rejected(self) -> None:
        """A cloud metadata URL fails validation even on the relaxed type."""

        class WebhookSchema(BaseModel):
            url: SSRFProtectedUrlRelaxed

        with pytest.raises(ValidationError):
            WebhookSchema(url="http://169.254.169.254/latest/meta-data/")
|
||||
|
||||
|
||||
class TestRealWorldURLs:
    """Checks that typical public webhook URLs are reported as safe."""

    def test_slack_webhook(self) -> None:
        """Slack incoming-webhook URL."""
        slack_url = (
            "https://hooks.slack.com/services/T00000000/B00000000/XXXXXXXXXXXXXXXXXXXX"
        )
        assert is_safe_url(slack_url) is True

    def test_discord_webhook(self) -> None:
        """Discord webhook URL."""
        discord_url = "https://discord.com/api/webhooks/123456789012345678/abcdefghijklmnopqrstuvwxyz"
        assert is_safe_url(discord_url) is True

    def test_webhook_site(self) -> None:
        """webhook.site URL."""
        assert is_safe_url("https://webhook.site/unique-id") is True

    def test_ngrok_url(self) -> None:
        """Ngrok tunnel URL (a public hostname, so it should be safe)."""
        assert is_safe_url("https://abc123.ngrok.io/webhook") is True
|
||||
@@ -1768,6 +1768,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
self,
|
||||
messages: Sequence[BaseMessage],
|
||||
tools: Sequence[dict[str, Any] | type | Callable | BaseTool] | None = None,
|
||||
*,
|
||||
allow_fetching_images: bool = True,
|
||||
) -> int:
|
||||
"""Calculate num tokens for `gpt-3.5-turbo` and `gpt-4` with `tiktoken` package.
|
||||
|
||||
@@ -1783,6 +1785,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
messages: The message inputs to tokenize.
|
||||
tools: If provided, sequence of `dict`, `BaseModel`, function, or `BaseTool`
|
||||
to be converted to tool schemas.
|
||||
allow_fetching_images: Whether to allow fetching images for token counting.
|
||||
"""
|
||||
# TODO: Count bound tools as part of input.
|
||||
if tools is not None:
|
||||
@@ -1827,11 +1830,13 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
elif val["type"] == "image_url":
|
||||
if val["image_url"].get("detail") == "low":
|
||||
num_tokens += 85
|
||||
else:
|
||||
elif allow_fetching_images:
|
||||
image_size = _url_to_size(val["image_url"]["url"])
|
||||
if not image_size:
|
||||
continue
|
||||
num_tokens += _count_image_tokens(*image_size)
|
||||
else:
|
||||
pass
|
||||
# Tool/function call token counting is not documented by OpenAI.
|
||||
# This is an approximation.
|
||||
elif val["type"] == "function":
|
||||
@@ -3528,10 +3533,62 @@ def _url_to_size(image_source: str) -> tuple[int, int] | None:
|
||||
"`pip install -U httpx`."
|
||||
)
|
||||
return None
|
||||
response = httpx.get(image_source)
|
||||
response.raise_for_status()
|
||||
width, height = Image.open(BytesIO(response.content)).size
|
||||
return width, height
|
||||
|
||||
# Validate URL for SSRF protection
|
||||
try:
|
||||
from langchain_core._security._ssrf_protection import validate_safe_url
|
||||
|
||||
validate_safe_url(image_source, allow_private=False, allow_http=True)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"SSRF protection not available. "
|
||||
"Update langchain-core to get SSRF protection."
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning("Image URL failed SSRF validation: %s", e)
|
||||
return None
|
||||
|
||||
# Set reasonable limits to prevent resource exhaustion
|
||||
# Timeout prevents indefinite hangs on slow/malicious servers
|
||||
timeout = 5.0 # seconds
|
||||
# Max size matches OpenAI's 50 MB payload limit
|
||||
max_size = 50 * 1024 * 1024 # 50 MB
|
||||
|
||||
try:
|
||||
response = httpx.get(
|
||||
image_source,
|
||||
timeout=timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# Check response size before loading into memory
|
||||
content_length = response.headers.get("content-length")
|
||||
if content_length and int(content_length) > max_size:
|
||||
logger.warning(
|
||||
"Image URL exceeds maximum size limit of %d bytes", max_size
|
||||
)
|
||||
return None
|
||||
|
||||
# Also check actual content size
|
||||
if len(response.content) > max_size:
|
||||
logger.warning(
|
||||
"Image URL exceeds maximum size limit of %d bytes", max_size
|
||||
)
|
||||
return None
|
||||
|
||||
# close things (context managers)
|
||||
width, height = Image.open(BytesIO(response.content)).size
|
||||
return width, height
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("Image URL request timed out after %s seconds", timeout)
|
||||
return None
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.warning("Image URL returned HTTP error: %s", e)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch or process image from URL: %s", e)
|
||||
return None
|
||||
|
||||
if _is_b64(image_source):
|
||||
_, encoded = image_source.split(",", 1)
|
||||
data = base64.b64decode(encoded)
|
||||
|
||||
Reference in New Issue
Block a user