openai[patch]: image token counting (#23147)
Resolves #23000

Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
Co-authored-by: ccurme <chester.curme@gmail.com>
@@ -2,10 +2,13 @@
 from __future__ import annotations

+import base64
 import json
 import logging
 import os
 import sys
+from io import BytesIO
+from math import ceil
 from operator import itemgetter
 from typing import (
     Any,
@@ -26,6 +29,7 @@ from typing import (
     cast,
     overload,
 )
+from urllib.parse import urlparse

 import openai
 import tiktoken
@@ -736,7 +740,13 @@ class BaseChatOpenAI(BaseChatModel):
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

-        Official documentation: https://github.com/openai/openai-cookbook/blob/
+        **Requirements**: You must have ``pillow`` installed if you want to count
+        image tokens when the image is specified as a base64 string, and you must
+        have both ``pillow`` and ``httpx`` installed if the image is specified as
+        a URL. If these aren't installed, image inputs will be ignored in token
+        counting.
+
+        OpenAI reference: https://github.com/openai/openai-cookbook/blob/
         main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
         if sys.version_info[1] <= 7:
             return super().get_num_tokens_from_messages(messages)
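For context, a minimal usage sketch of the method this hunk documents, assuming `langchain-openai` with `pillow` and `httpx` installed; the model name, API key, and image URL below are placeholders, not part of the change:

```python
# Sketch: counting tokens for a message that mixes text and an image.
# Assumes `pillow` and `httpx` are installed; model name, key, and URL are placeholders.
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o", api_key="sk-placeholder")  # no chat API call is made
message = HumanMessage(
    content=[
        {"type": "text", "text": "Describe this image."},
        {
            "type": "image_url",
            "image_url": {"url": "https://example.com/photo.png", "detail": "high"},
        },
    ]
)
# Text is counted with tiktoken; the image is fetched (or base64-decoded) to get
# its dimensions, then priced per the tiling rules added in this PR.
print(llm.get_num_tokens_from_messages([message]))
```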
@@ -753,7 +763,7 @@ class BaseChatOpenAI(BaseChatModel):
             raise NotImplementedError(
                 f"get_num_tokens_from_messages() is not presently implemented "
                 f"for model {model}. See "
-                "https://platform.openai.com/docs/guides/text-generation/managing-tokens"
+                "https://platform.openai.com/docs/guides/text-generation/managing-tokens"  # noqa: E501
                 " for information on how messages are converted to tokens."
             )
         num_tokens = 0
@@ -761,9 +771,27 @@ class BaseChatOpenAI(BaseChatModel):
         for message in messages_dict:
             num_tokens += tokens_per_message
             for key, value in message.items():
-                # Cast str(value) in case the message value is not a string
-                # This occurs with function messages
-                num_tokens += len(encoding.encode(str(value)))
+                if isinstance(value, list):
+                    for val in value:
+                        if isinstance(val, str) or val["type"] == "text":
+                            text = val["text"] if isinstance(val, dict) else val
+                            num_tokens += len(encoding.encode(text))
+                        elif val["type"] == "image_url":
+                            if val["image_url"].get("detail") == "low":
+                                num_tokens += 85
+                            else:
+                                image_size = _url_to_size(val["image_url"]["url"])
+                                if not image_size:
+                                    continue
+                                num_tokens += _count_image_tokens(*image_size)
+                        else:
+                            raise ValueError(
+                                f"Unrecognized content block type\n\n{val}"
+                            )
+                else:
+                    # Cast str(value) in case the message value is not a string
+                    # This occurs with function messages
+                    num_tokens += len(encoding.encode(str(value)))
                 if key == "name":
                     num_tokens += tokens_per_name
         # every reply is primed with <im_start>assistant
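To make the high-detail arithmetic above concrete, a worked example under the same rules (the 1024x1024 size is illustrative and matches the example in OpenAI's vision pricing docs):

```python
from math import ceil

# Worked example: a 1024x1024 image requested at "high" detail.
width, height = 1024, 1024
# _resize: the larger side is already <= 2048; the smaller side exceeds 768,
# so both sides scale down to 768.
width, height = 768, 768
tiles = ceil(width / 512) * ceil(height / 512)  # 2 * 2 = 4 tiles of 512px
print(170 * tiles + 85)  # 765 tokens
# A "low"-detail image costs a flat 85 tokens regardless of size.
```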
@@ -1541,3 +1569,75 @@ def _lc_invalid_tool_call_to_openai_tool_call(
             "arguments": invalid_tool_call["args"],
         },
     }
+
+
+def _url_to_size(image_source: str) -> Optional[Tuple[int, int]]:
+    try:
+        from PIL import Image  # type: ignore[import]
+    except ImportError:
+        logger.info(
+            "Unable to count image tokens. To count image tokens please install "
+            "`pip install -U pillow httpx`."
+        )
+        return None
+    if _is_url(image_source):
+        try:
+            import httpx
+        except ImportError:
+            logger.info(
+                "Unable to count image tokens. To count image tokens please install "
+                "`pip install -U httpx`."
+            )
+            return None
+        response = httpx.get(image_source)
+        response.raise_for_status()
+        width, height = Image.open(BytesIO(response.content)).size
+        return width, height
+    elif _is_b64(image_source):
+        _, encoded = image_source.split(",", 1)
+        data = base64.b64decode(encoded)
+        width, height = Image.open(BytesIO(data)).size
+        return width, height
+    else:
+        return None
+
+
+def _count_image_tokens(width: int, height: int) -> int:
+    # Reference: https://platform.openai.com/docs/guides/vision/calculating-costs
+    width, height = _resize(width, height)
+    h = ceil(height / 512)
+    w = ceil(width / 512)
+    return (170 * h * w) + 85
+
+
+def _is_url(s: str) -> bool:
+    try:
+        result = urlparse(s)
+        return all([result.scheme, result.netloc])
+    except Exception as e:
+        logger.debug(f"Unable to parse URL: {e}")
+        return False
+
+
+def _is_b64(s: str) -> bool:
+    return s.startswith("data:image")
+
+
+def _resize(width: int, height: int) -> Tuple[int, int]:
+    # larger side must be <= 2048
+    if width > 2048 or height > 2048:
+        if width > height:
+            height = (height * 2048) // width
+            width = 2048
+        else:
+            width = (width * 2048) // height
+            height = 2048
+    # smaller side must be <= 768
+    if width > 768 and height > 768:
+        if width > height:
+            width = (width * 768) // height
+            height = 768
+        else:
+            height = (height * 768) // width
+            width = 768
+    return width, height
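As a rough check of the helpers added here, a sketch that feeds an in-memory image through the base64 path; importing private helpers from `langchain_openai.chat_models.base` is for illustration only, and the synthetic 1500x500 PNG is an assumption (requires `pillow`):

```python
import base64
from io import BytesIO

from PIL import Image

# Private helpers; imported here only to illustrate the calculation.
from langchain_openai.chat_models.base import _count_image_tokens, _url_to_size

# Build a data URL for a synthetic 1500x500 PNG.
buffer = BytesIO()
Image.new("RGB", (1500, 500)).save(buffer, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode()

size = _url_to_size(data_url)      # (1500, 500), via the base64 branch
print(_count_image_tokens(*size))  # no resize needed; 3 x 1 tiles -> 170 * 3 + 85 = 595
```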