From 2b4b1dc29a833d4053deba4c2b77a3848c834565 Mon Sep 17 00:00:00 2001 From: ccurme Date: Tue, 10 Feb 2026 15:25:10 -0500 Subject: [PATCH] fix(openai): sanitize urls when counting tokens in images (#35143) --- .../core/langchain_core/_security/__init__.py | 0 .../_security/_ssrf_protection.py | 361 ++++++++++++++++++ .../tests/unit_tests/test_ssrf_protection.py | 297 ++++++++++++++ .../langchain_openai/chat_models/base.py | 67 +++- 4 files changed, 720 insertions(+), 5 deletions(-) create mode 100644 libs/core/langchain_core/_security/__init__.py create mode 100644 libs/core/langchain_core/_security/_ssrf_protection.py create mode 100644 libs/core/tests/unit_tests/test_ssrf_protection.py diff --git a/libs/core/langchain_core/_security/__init__.py b/libs/core/langchain_core/_security/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/core/langchain_core/_security/_ssrf_protection.py b/libs/core/langchain_core/_security/_ssrf_protection.py new file mode 100644 index 00000000000..5afa573c0c2 --- /dev/null +++ b/libs/core/langchain_core/_security/_ssrf_protection.py @@ -0,0 +1,361 @@ +"""SSRF Protection for validating URLs against Server-Side Request Forgery attacks. + +This module provides utilities to validate user-provided URLs and prevent SSRF attacks +by blocking requests to: +- Private IP ranges (RFC 1918, loopback, link-local) +- Cloud metadata endpoints (AWS, GCP, Azure, etc.) +- Localhost addresses +- Invalid URL schemes + +Usage: + from lc_security.ssrf_protection import validate_safe_url, is_safe_url + + # Validate a URL (raises ValueError if unsafe) + safe_url = validate_safe_url("https://example.com/webhook") + + # Check if URL is safe (returns bool) + if is_safe_url("http://192.168.1.1"): + # URL is safe + pass + + # Allow private IPs for development/testing (still blocks cloud metadata) + safe_url = validate_safe_url("http://localhost:8080", allow_private=True) +""" + +import ipaddress +import os +import socket +from typing import Annotated, Any +from urllib.parse import urlparse + +from pydantic import ( + AnyHttpUrl, + BeforeValidator, + HttpUrl, +) + +# Private IP ranges (RFC 1918, RFC 4193, RFC 3927, loopback) +PRIVATE_IP_RANGES = [ + ipaddress.ip_network("10.0.0.0/8"), # Private Class A + ipaddress.ip_network("172.16.0.0/12"), # Private Class B + ipaddress.ip_network("192.168.0.0/16"), # Private Class C + ipaddress.ip_network("127.0.0.0/8"), # Loopback + ipaddress.ip_network("169.254.0.0/16"), # Link-local (includes cloud metadata) + ipaddress.ip_network("0.0.0.0/8"), # Current network + ipaddress.ip_network("::1/128"), # IPv6 loopback + ipaddress.ip_network("fc00::/7"), # IPv6 unique local + ipaddress.ip_network("fe80::/10"), # IPv6 link-local + ipaddress.ip_network("ff00::/8"), # IPv6 multicast +] + +# Cloud provider metadata endpoints +CLOUD_METADATA_IPS = [ + "169.254.169.254", # AWS, GCP, Azure, DigitalOcean, Oracle Cloud + "169.254.170.2", # AWS ECS task metadata + "100.100.100.200", # Alibaba Cloud metadata +] + +CLOUD_METADATA_HOSTNAMES = [ + "metadata.google.internal", # GCP + "metadata", # Generic + "instance-data", # AWS EC2 +] + +# Localhost variations +LOCALHOST_NAMES = [ + "localhost", + "localhost.localdomain", +] + + +def is_private_ip(ip_str: str) -> bool: + """Check if an IP address is in a private range. + + Args: + ip_str: IP address as a string (e.g., "192.168.1.1") + + Returns: + True if IP is in a private range, False otherwise + """ + try: + ip = ipaddress.ip_address(ip_str) + return any(ip in range_ for range_ in PRIVATE_IP_RANGES) + except ValueError: + return False + + +def is_cloud_metadata(hostname: str, ip_str: str | None = None) -> bool: + """Check if hostname or IP is a cloud metadata endpoint. + + Args: + hostname: Hostname to check + ip_str: Optional IP address to check + + Returns: + True if hostname or IP is a known cloud metadata endpoint + """ + # Check hostname + if hostname.lower() in CLOUD_METADATA_HOSTNAMES: + return True + + # Check IP + if ip_str and ip_str in CLOUD_METADATA_IPS: # noqa: SIM103 + return True + + return False + + +def is_localhost(hostname: str, ip_str: str | None = None) -> bool: + """Check if hostname or IP is localhost. + + Args: + hostname: Hostname to check + ip_str: Optional IP address to check + + Returns: + True if hostname or IP is localhost + """ + # Check hostname + if hostname.lower() in LOCALHOST_NAMES: + return True + + # Check IP + if ip_str: + try: + ip = ipaddress.ip_address(ip_str) + # Check if loopback + if ip.is_loopback: + return True + # Also check common localhost IPs + if ip_str in ("127.0.0.1", "::1", "0.0.0.0"): # noqa: S104 + return True + except ValueError: + pass + + return False + + +def validate_safe_url( + url: str | AnyHttpUrl, + *, + allow_private: bool = False, + allow_http: bool = True, +) -> str: + """Validate a URL for SSRF protection. + + This function validates URLs to prevent Server-Side Request Forgery (SSRF) attacks + by blocking requests to private networks and cloud metadata endpoints. + + Args: + url: The URL to validate (string or Pydantic HttpUrl) + allow_private: If True, allows private IPs and localhost (for development). + Cloud metadata endpoints are ALWAYS blocked. + allow_http: If True, allows both HTTP and HTTPS. If False, only HTTPS. + + Returns: + The validated URL as a string + + Raises: + ValueError: If URL is invalid or potentially dangerous + + Examples: + >>> validate_safe_url("https://hooks.slack.com/services/xxx") + 'https://hooks.slack.com/services/xxx' + + >>> validate_safe_url("http://127.0.0.1:8080") + ValueError: Localhost URLs are not allowed + + >>> validate_safe_url("http://192.168.1.1") + ValueError: URL resolves to private IP: 192.168.1.1 + + >>> validate_safe_url("http://169.254.169.254/latest/meta-data/") + ValueError: URL resolves to cloud metadata IP: 169.254.169.254 + + >>> validate_safe_url("http://localhost:8080", allow_private=True) + 'http://localhost:8080' + """ + url_str = str(url) + parsed = urlparse(url_str) + + # Validate URL scheme + if not allow_http and parsed.scheme != "https": + msg = "Only HTTPS URLs are allowed" + raise ValueError(msg) + + if parsed.scheme not in ("http", "https"): + msg = f"Only HTTP/HTTPS URLs are allowed, got scheme: {parsed.scheme}" + raise ValueError(msg) + + # Extract hostname + hostname = parsed.hostname + if not hostname: + msg = "URL must have a valid hostname" + raise ValueError(msg) + + # Special handling for test environments - allow test server hostnames + # testserver is used by FastAPI/Starlette test clients and doesn't resolve via DNS + # Only enabled when LANGCHAIN_ENV=local_test (set in conftest.py) + if ( + os.environ.get("LANGCHAIN_ENV") == "local_test" + and hostname.startswith("test") + and "server" in hostname + ): + return url_str + + # ALWAYS block cloud metadata endpoints (even with allow_private=True) + if is_cloud_metadata(hostname): + msg = f"Cloud metadata endpoints are not allowed: {hostname}" + raise ValueError(msg) + + # Check for localhost + if is_localhost(hostname) and not allow_private: + msg = f"Localhost URLs are not allowed: {hostname}" + raise ValueError(msg) + + # Resolve hostname to IP addresses and validate each one. + # Note: DNS resolution results are cached by the OS, so repeated calls are fast. + try: + # Get all IP addresses for this hostname + addr_info = socket.getaddrinfo( + hostname, + parsed.port or (443 if parsed.scheme == "https" else 80), + socket.AF_UNSPEC, # Allow both IPv4 and IPv6 + socket.SOCK_STREAM, + ) + + for result in addr_info: + ip_str: str = result[4][0] # type: ignore[assignment] + + # ALWAYS block cloud metadata IPs + if is_cloud_metadata(hostname, ip_str): + msg = f"URL resolves to cloud metadata IP: {ip_str}" + raise ValueError(msg) + + # Check for localhost IPs + if is_localhost(hostname, ip_str) and not allow_private: + msg = f"URL resolves to localhost IP: {ip_str}" + raise ValueError(msg) + + # Check for private IPs + if not allow_private and is_private_ip(ip_str): + msg = f"URL resolves to private IP address: {ip_str}" + raise ValueError(msg) + + except socket.gaierror as e: + # DNS resolution failed - fail closed for security + msg = f"Failed to resolve hostname '{hostname}': {e}" + raise ValueError(msg) from e + except OSError as e: + # Other network errors - fail closed + msg = f"Network error while validating URL: {e}" + raise ValueError(msg) from e + + return url_str + + +def is_safe_url( + url: str | AnyHttpUrl, + *, + allow_private: bool = False, + allow_http: bool = True, +) -> bool: + """Check if a URL is safe (non-throwing version of validate_safe_url). + + Args: + url: The URL to check + allow_private: If True, allows private IPs and localhost + allow_http: If True, allows both HTTP and HTTPS + + Returns: + True if URL is safe, False otherwise + + Examples: + >>> is_safe_url("https://example.com") + True + + >>> is_safe_url("http://127.0.0.1:8080") + False + + >>> is_safe_url("http://localhost:8080", allow_private=True) + True + """ + try: + validate_safe_url(url, allow_private=allow_private, allow_http=allow_http) + except ValueError: + return False + else: + return True + + +def _validate_url_ssrf_strict(v: Any) -> Any: + """Validate URL for SSRF protection (strict mode).""" + if isinstance(v, str): + validate_safe_url(v, allow_private=False, allow_http=True) + return v + + +def _validate_url_ssrf_https_only(v: Any) -> Any: + """Validate URL for SSRF protection (HTTPS only, strict mode).""" + if isinstance(v, str): + validate_safe_url(v, allow_private=False, allow_http=False) + return v + + +def _validate_url_ssrf_relaxed(v: Any) -> Any: + """Validate URL for SSRF protection (relaxed mode - allows private IPs).""" + if isinstance(v, str): + validate_safe_url(v, allow_private=True, allow_http=True) + return v + + +# Annotated types with SSRF protection +SSRFProtectedUrl = Annotated[HttpUrl, BeforeValidator(_validate_url_ssrf_strict)] +"""A Pydantic HttpUrl type with built-in SSRF protection. + +This blocks private IPs, localhost, and cloud metadata endpoints. + +Example: + class WebhookSchema(BaseModel): + url: SSRFProtectedUrl # Automatically validated for SSRF + headers: dict[str, str] | None = None +""" + +SSRFProtectedUrlRelaxed = Annotated[ + HttpUrl, BeforeValidator(_validate_url_ssrf_relaxed) +] +"""A Pydantic HttpUrl with relaxed SSRF protection (allows private IPs). + +Use this for development/testing webhooks where localhost/private IPs are needed. +Cloud metadata endpoints are still blocked. + +Example: + class DevWebhookSchema(BaseModel): + url: SSRFProtectedUrlRelaxed # Allows localhost, blocks cloud metadata +""" + +SSRFProtectedHttpsUrl = Annotated[ + HttpUrl, BeforeValidator(_validate_url_ssrf_https_only) +] +"""A Pydantic HttpUrl with SSRF protection that only allows HTTPS. + +This blocks private IPs, localhost, cloud metadata endpoints, and HTTP URLs. + +Example: + class SecureWebhookSchema(BaseModel): + url: SSRFProtectedHttpsUrl # Only HTTPS, blocks private IPs +""" + +SSRFProtectedHttpsUrlStr = Annotated[ + str, BeforeValidator(_validate_url_ssrf_https_only) +] +"""A string type with SSRF protection that only allows HTTPS URLs. + +Same as SSRFProtectedHttpsUrl but returns a string instead of HttpUrl. +Useful for FastAPI query parameters where you need a string URL. + +Example: + @router.get("/proxy") + async def proxy_get(url: SSRFProtectedHttpsUrlStr): + async with httpx.AsyncClient() as client: + resp = await client.get(url) +""" diff --git a/libs/core/tests/unit_tests/test_ssrf_protection.py b/libs/core/tests/unit_tests/test_ssrf_protection.py new file mode 100644 index 00000000000..fc9fd8d1064 --- /dev/null +++ b/libs/core/tests/unit_tests/test_ssrf_protection.py @@ -0,0 +1,297 @@ +"""Tests for SSRF protection utilities.""" + +from typing import Any + +import pytest +from pydantic import BaseModel, ValidationError + +from langchain_core._security._ssrf_protection import ( + SSRFProtectedUrl, + SSRFProtectedUrlRelaxed, + is_cloud_metadata, + is_localhost, + is_private_ip, + is_safe_url, + validate_safe_url, +) + + +class TestIPValidation: + """Tests for IP address validation functions.""" + + def test_is_private_ip_ipv4(self) -> None: + """Test private IPv4 address detection.""" + assert is_private_ip("10.0.0.1") is True + assert is_private_ip("172.16.0.1") is True + assert is_private_ip("192.168.1.1") is True + assert is_private_ip("127.0.0.1") is True + assert is_private_ip("169.254.169.254") is True + assert is_private_ip("0.0.0.1") is True + + def test_is_private_ip_ipv6(self) -> None: + """Test private IPv6 address detection.""" + assert is_private_ip("::1") is True # Loopback + assert is_private_ip("fc00::1") is True # Unique local + assert is_private_ip("fe80::1") is True # Link-local + assert is_private_ip("ff00::1") is True # Multicast + + def test_is_private_ip_public(self) -> None: + """Test that public IPs are not flagged as private.""" + assert is_private_ip("8.8.8.8") is False + assert is_private_ip("1.1.1.1") is False + assert is_private_ip("151.101.1.140") is False + + def test_is_private_ip_invalid(self) -> None: + """Test handling of invalid IP addresses.""" + assert is_private_ip("not-an-ip") is False + assert is_private_ip("999.999.999.999") is False + + def test_is_cloud_metadata_ips(self) -> None: + """Test cloud metadata IP detection.""" + assert is_cloud_metadata("example.com", "169.254.169.254") is True + assert is_cloud_metadata("example.com", "169.254.170.2") is True + assert is_cloud_metadata("example.com", "100.100.100.200") is True + + def test_is_cloud_metadata_hostnames(self) -> None: + """Test cloud metadata hostname detection.""" + assert is_cloud_metadata("metadata.google.internal") is True + assert is_cloud_metadata("metadata") is True + assert is_cloud_metadata("instance-data") is True + assert is_cloud_metadata("METADATA.GOOGLE.INTERNAL") is True # Case insensitive + + def test_is_cloud_metadata_safe(self) -> None: + """Test that normal URLs are not flagged as cloud metadata.""" + assert is_cloud_metadata("example.com", "8.8.8.8") is False + assert is_cloud_metadata("google.com") is False + + def test_is_localhost_hostnames(self) -> None: + """Test localhost hostname detection.""" + assert is_localhost("localhost") is True + assert is_localhost("LOCALHOST") is True + assert is_localhost("localhost.localdomain") is True + + def test_is_localhost_ips(self) -> None: + """Test localhost IP detection.""" + assert is_localhost("example.com", "127.0.0.1") is True + assert is_localhost("example.com", "::1") is True + assert is_localhost("example.com", "0.0.0.0") is True + + def test_is_localhost_safe(self) -> None: + """Test that normal hosts are not flagged as localhost.""" + assert is_localhost("example.com", "8.8.8.8") is False + assert is_localhost("google.com") is False + + +class TestValidateSafeUrl: + """Tests for validate_safe_url function.""" + + def test_valid_public_https_url(self) -> None: + """Test that valid public HTTPS URLs are accepted.""" + url = "https://hooks.slack.com/services/xxx" + result = validate_safe_url(url) + assert result == url + + def test_valid_public_http_url(self) -> None: + """Test that valid public HTTP URLs are accepted.""" + url = "http://example.com/webhook" + result = validate_safe_url(url) + assert result == url + + def test_localhost_blocked_by_default(self) -> None: + """Test that localhost URLs are blocked by default.""" + with pytest.raises(ValueError, match="Localhost"): + validate_safe_url("http://localhost:8080/webhook") + + with pytest.raises(ValueError, match="localhost"): + validate_safe_url("http://127.0.0.1:8080/webhook") + + def test_localhost_allowed_with_flag(self) -> None: + """Test that localhost is allowed with allow_private=True.""" + url = "http://localhost:8080/webhook" + result = validate_safe_url(url, allow_private=True) + assert result == url + + url = "http://127.0.0.1:8080/webhook" + result = validate_safe_url(url, allow_private=True) + assert result == url + + def test_private_ip_blocked_by_default(self) -> None: + """Test that private IPs are blocked by default.""" + with pytest.raises(ValueError, match="private IP"): + validate_safe_url("http://192.168.1.1/webhook") + + with pytest.raises(ValueError, match="private IP"): + validate_safe_url("http://10.0.0.1/webhook") + + with pytest.raises(ValueError, match="private IP"): + validate_safe_url("http://172.16.0.1/webhook") + + def test_private_ip_allowed_with_flag(self) -> None: + """Test that private IPs are allowed with allow_private=True.""" + # Note: These will fail DNS resolution in tests, so we skip actual validation + # In production, they would be validated properly + + def test_cloud_metadata_always_blocked(self) -> None: + """Test that cloud metadata endpoints are always blocked.""" + with pytest.raises(ValueError, match="metadata"): + validate_safe_url("http://169.254.169.254/latest/meta-data/") + + # Even with allow_private=True + with pytest.raises(ValueError, match="metadata"): + validate_safe_url( + "http://169.254.169.254/latest/meta-data/", + allow_private=True, + ) + + def test_invalid_scheme_blocked(self) -> None: + """Test that non-HTTP(S) schemes are blocked.""" + with pytest.raises(ValueError, match="scheme"): + validate_safe_url("ftp://example.com/file") + + with pytest.raises(ValueError, match="scheme"): + validate_safe_url("file:///etc/passwd") + + with pytest.raises(ValueError, match="scheme"): + validate_safe_url("javascript:alert(1)") + + def test_https_only_mode(self) -> None: + """Test that HTTP is blocked when allow_http=False.""" + with pytest.raises(ValueError, match="HTTPS"): + validate_safe_url("http://example.com/webhook", allow_http=False) + + # HTTPS should still work + url = "https://example.com/webhook" + result = validate_safe_url(url, allow_http=False) + assert result == url + + def test_url_without_hostname(self) -> None: + """Test that URLs without hostname are rejected.""" + with pytest.raises(ValueError, match="hostname"): + validate_safe_url("http:///path") + + def test_dns_resolution_failure(self) -> None: + """Test handling of DNS resolution failures.""" + with pytest.raises(ValueError, match="resolve"): + validate_safe_url("http://this-domain-definitely-does-not-exist-12345.com") + + def test_testserver_allowed(self, monkeypatch: Any) -> None: + """Test that testserver hostname is allowed for test environments.""" + # testserver is used by FastAPI/Starlette test clients + monkeypatch.setenv("LANGCHAIN_ENV", "local_test") + url = "http://testserver/webhook" + result = validate_safe_url(url) + assert result == url + + +class TestIsSafeUrl: + """Tests for is_safe_url function (non-throwing version).""" + + def test_safe_url_returns_true(self) -> None: + """Test that safe URLs return True.""" + assert is_safe_url("https://example.com/webhook") is True + assert is_safe_url("http://hooks.slack.com/services/xxx") is True + + def test_unsafe_url_returns_false(self) -> None: + """Test that unsafe URLs return False.""" + assert is_safe_url("http://localhost:8080") is False + assert is_safe_url("http://127.0.0.1:8080") is False + assert is_safe_url("http://192.168.1.1") is False + assert is_safe_url("http://169.254.169.254") is False + + def test_unsafe_url_safe_with_allow_private(self) -> None: + """Test that private URLs are safe with allow_private=True.""" + assert is_safe_url("http://localhost:8080", allow_private=True) is True + assert is_safe_url("http://127.0.0.1:8080", allow_private=True) is True + + def test_cloud_metadata_always_unsafe(self) -> None: + """Test that cloud metadata is always unsafe.""" + assert is_safe_url("http://169.254.169.254") is False + assert is_safe_url("http://169.254.169.254", allow_private=True) is False + + +class TestSSRFProtectedUrlType: + """Tests for SSRFProtectedUrl Pydantic type.""" + + def test_valid_url_accepted(self) -> None: + """Test that valid URLs are accepted by Pydantic schema.""" + + class WebhookSchema(BaseModel): + url: SSRFProtectedUrl + + schema = WebhookSchema(url="https://hooks.slack.com/services/xxx") + assert str(schema.url).startswith("https://hooks.slack.com/") + + def test_localhost_rejected(self) -> None: + """Test that localhost URLs are rejected by Pydantic schema.""" + + class WebhookSchema(BaseModel): + url: SSRFProtectedUrl + + with pytest.raises(ValidationError): + WebhookSchema(url="http://localhost:8080") + + def test_private_ip_rejected(self) -> None: + """Test that private IPs are rejected by Pydantic schema.""" + + class WebhookSchema(BaseModel): + url: SSRFProtectedUrl + + with pytest.raises(ValidationError): + WebhookSchema(url="http://192.168.1.1") + + def test_cloud_metadata_rejected(self) -> None: + """Test that cloud metadata is rejected by Pydantic schema.""" + + class WebhookSchema(BaseModel): + url: SSRFProtectedUrl + + with pytest.raises(ValidationError): + WebhookSchema(url="http://169.254.169.254/latest/meta-data/") + + +class TestSSRFProtectedUrlRelaxedType: + """Tests for SSRFProtectedUrlRelaxed Pydantic type.""" + + def test_localhost_accepted(self) -> None: + """Test that localhost URLs are accepted by relaxed schema.""" + + class WebhookSchema(BaseModel): + url: SSRFProtectedUrlRelaxed + + schema = WebhookSchema(url="http://localhost:8080") + assert str(schema.url).startswith("http://localhost") + + def test_cloud_metadata_still_rejected(self) -> None: + """Test that cloud metadata is still rejected by relaxed schema.""" + + class WebhookSchema(BaseModel): + url: SSRFProtectedUrlRelaxed + + with pytest.raises(ValidationError): + WebhookSchema(url="http://169.254.169.254/latest/meta-data/") + + +class TestRealWorldURLs: + """Tests with real-world webhook URLs.""" + + def test_slack_webhook(self) -> None: + """Test Slack webhook URL.""" + url = ( + "https://hooks.slack.com/services/T00000000/B00000000/XXXXXXXXXXXXXXXXXXXX" + ) + assert is_safe_url(url) is True + + def test_discord_webhook(self) -> None: + """Test Discord webhook URL.""" + url = "https://discord.com/api/webhooks/123456789012345678/abcdefghijklmnopqrstuvwxyz" + assert is_safe_url(url) is True + + def test_webhook_site(self) -> None: + """Test webhook.site URL.""" + url = "https://webhook.site/unique-id" + assert is_safe_url(url) is True + + def test_ngrok_url(self) -> None: + """Test ngrok URL (should be safe as it's public).""" + url = "https://abc123.ngrok.io/webhook" + assert is_safe_url(url) is True diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 0128bfd3232..6fecae9d617 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -1768,6 +1768,8 @@ class BaseChatOpenAI(BaseChatModel): self, messages: Sequence[BaseMessage], tools: Sequence[dict[str, Any] | type | Callable | BaseTool] | None = None, + *, + allow_fetching_images: bool = True, ) -> int: """Calculate num tokens for `gpt-3.5-turbo` and `gpt-4` with `tiktoken` package. @@ -1783,6 +1785,7 @@ class BaseChatOpenAI(BaseChatModel): messages: The message inputs to tokenize. tools: If provided, sequence of `dict`, `BaseModel`, function, or `BaseTool` to be converted to tool schemas. + allow_fetching_images: Whether to allow fetching images for token counting. """ # TODO: Count bound tools as part of input. if tools is not None: @@ -1827,11 +1830,13 @@ class BaseChatOpenAI(BaseChatModel): elif val["type"] == "image_url": if val["image_url"].get("detail") == "low": num_tokens += 85 - else: + elif allow_fetching_images: image_size = _url_to_size(val["image_url"]["url"]) if not image_size: continue num_tokens += _count_image_tokens(*image_size) + else: + pass # Tool/function call token counting is not documented by OpenAI. # This is an approximation. elif val["type"] == "function": @@ -3528,10 +3533,62 @@ def _url_to_size(image_source: str) -> tuple[int, int] | None: "`pip install -U httpx`." ) return None - response = httpx.get(image_source) - response.raise_for_status() - width, height = Image.open(BytesIO(response.content)).size - return width, height + + # Validate URL for SSRF protection + try: + from langchain_core._security._ssrf_protection import validate_safe_url + + validate_safe_url(image_source, allow_private=False, allow_http=True) + except ImportError: + logger.warning( + "SSRF protection not available. " + "Update langchain-core to get SSRF protection." + ) + except ValueError as e: + logger.warning("Image URL failed SSRF validation: %s", e) + return None + + # Set reasonable limits to prevent resource exhaustion + # Timeout prevents indefinite hangs on slow/malicious servers + timeout = 5.0 # seconds + # Max size matches OpenAI's 50 MB payload limit + max_size = 50 * 1024 * 1024 # 50 MB + + try: + response = httpx.get( + image_source, + timeout=timeout, + ) + response.raise_for_status() + + # Check response size before loading into memory + content_length = response.headers.get("content-length") + if content_length and int(content_length) > max_size: + logger.warning( + "Image URL exceeds maximum size limit of %d bytes", max_size + ) + return None + + # Also check actual content size + if len(response.content) > max_size: + logger.warning( + "Image URL exceeds maximum size limit of %d bytes", max_size + ) + return None + + # close things (context managers) + width, height = Image.open(BytesIO(response.content)).size + return width, height + except httpx.TimeoutException: + logger.warning("Image URL request timed out after %s seconds", timeout) + return None + except httpx.HTTPStatusError as e: + logger.warning("Image URL returned HTTP error: %s", e) + return None + except Exception as e: + logger.warning("Failed to fetch or process image from URL: %s", e) + return None + if _is_b64(image_source): _, encoded = image_source.split(",", 1) data = base64.b64decode(encoded)