mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(openai): sanitize urls when counting tokens in images (#35143)
This commit is contained in:
@@ -1768,6 +1768,8 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
self,
|
||||
messages: Sequence[BaseMessage],
|
||||
tools: Sequence[dict[str, Any] | type | Callable | BaseTool] | None = None,
|
||||
*,
|
||||
allow_fetching_images: bool = True,
|
||||
) -> int:
|
||||
"""Calculate num tokens for `gpt-3.5-turbo` and `gpt-4` with `tiktoken` package.
|
||||
|
||||
@@ -1783,6 +1785,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
messages: The message inputs to tokenize.
|
||||
tools: If provided, sequence of `dict`, `BaseModel`, function, or `BaseTool`
|
||||
to be converted to tool schemas.
|
||||
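allow_fetching_images: Whether to allow fetching images for token counting. When `False`, images whose `detail` is not `"low"` are skipped and contribute no tokens to the count.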
|
||||
"""
|
||||
# TODO: Count bound tools as part of input.
|
||||
if tools is not None:
|
||||
@@ -1827,11 +1830,13 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
elif val["type"] == "image_url":
|
||||
if val["image_url"].get("detail") == "low":
|
||||
num_tokens += 85
|
||||
else:
|
||||
elif allow_fetching_images:
|
||||
image_size = _url_to_size(val["image_url"]["url"])
|
||||
if not image_size:
|
||||
continue
|
||||
num_tokens += _count_image_tokens(*image_size)
|
||||
else:
|
||||
pass
|
||||
# Tool/function call token counting is not documented by OpenAI.
|
||||
# This is an approximation.
|
||||
elif val["type"] == "function":
|
||||
@@ -3528,10 +3533,62 @@ def _url_to_size(image_source: str) -> tuple[int, int] | None:
|
||||
"`pip install -U httpx`."
|
||||
)
|
||||
return None
|
||||
response = httpx.get(image_source)
|
||||
response.raise_for_status()
|
||||
width, height = Image.open(BytesIO(response.content)).size
|
||||
return width, height
|
||||
|
||||
# Validate URL for SSRF protection
|
||||
try:
|
||||
from langchain_core._security._ssrf_protection import validate_safe_url
|
||||
|
||||
validate_safe_url(image_source, allow_private=False, allow_http=True)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"SSRF protection not available. "
|
||||
"Update langchain-core to get SSRF protection."
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning("Image URL failed SSRF validation: %s", e)
|
||||
return None
|
||||
|
||||
# Set reasonable limits to prevent resource exhaustion
|
||||
# Timeout prevents indefinite hangs on slow/malicious servers
|
||||
timeout = 5.0 # seconds
|
||||
# Max size matches OpenAI's 50 MB payload limit
|
||||
max_size = 50 * 1024 * 1024 # 50 MB
|
||||
|
||||
try:
|
||||
response = httpx.get(
|
||||
image_source,
|
||||
timeout=timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
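# Reject based on the declared Content-Length header. Note: the non-streaming httpx.get above has already read the body into memory, so this bounds further processing rather than the download itself.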
|
||||
content_length = response.headers.get("content-length")
|
||||
if content_length and int(content_length) > max_size:
|
||||
logger.warning(
|
||||
"Image URL exceeds maximum size limit of %d bytes", max_size
|
||||
)
|
||||
return None
|
||||
|
||||
# Also check actual content size
|
||||
if len(response.content) > max_size:
|
||||
logger.warning(
|
||||
"Image URL exceeds maximum size limit of %d bytes", max_size
|
||||
)
|
||||
return None
|
||||
|
||||
# TODO: close things via context managers (e.g. the opened image)
|
||||
width, height = Image.open(BytesIO(response.content)).size
|
||||
return width, height
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("Image URL request timed out after %s seconds", timeout)
|
||||
return None
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.warning("Image URL returned HTTP error: %s", e)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch or process image from URL: %s", e)
|
||||
return None
|
||||
|
||||
if _is_b64(image_source):
|
||||
_, encoded = image_source.split(",", 1)
|
||||
data = base64.b64decode(encoded)
|
||||
|
||||
Reference in New Issue
Block a user