openai[patch]: image token counting (#23147)

Resolves #23000

---------

Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
Bagatur
2024-06-19 10:41:47 -07:00
committed by GitHub
parent b3e53ffca0
commit 0a4ee864e9
6 changed files with 487 additions and 101 deletions

View File

@@ -2,10 +2,13 @@
from __future__ import annotations
import base64
import json
import logging
import os
import sys
from io import BytesIO
from math import ceil
from operator import itemgetter
from typing import (
Any,
@@ -26,6 +29,7 @@ from typing import (
cast,
overload,
)
from urllib.parse import urlparse
import openai
import tiktoken
@@ -736,7 +740,13 @@ class BaseChatOpenAI(BaseChatModel):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
**Requirements**: You must have ``pillow`` installed to count image tokens
for images specified as base64 strings, and both ``pillow`` and ``httpx``
installed for images specified as URLs. If these aren't installed, image
inputs will be ignored in token counting.
OpenAI reference: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if sys.version_info[1] <= 7:
return super().get_num_tokens_from_messages(messages)
@@ -753,7 +763,7 @@ class BaseChatOpenAI(BaseChatModel):
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}. See "
"https://platform.openai.com/docs/guides/text-generation/managing-tokens"
"https://platform.openai.com/docs/guides/text-generation/managing-tokens" # noqa: E501
" for information on how messages are converted to tokens."
)
num_tokens = 0
@@ -761,9 +771,27 @@ class BaseChatOpenAI(BaseChatModel):
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
num_tokens += len(encoding.encode(str(value)))
if isinstance(value, list):
for val in value:
if isinstance(val, str) or val["type"] == "text":
text = val["text"] if isinstance(val, dict) else val
num_tokens += len(encoding.encode(text))
elif val["type"] == "image_url":
if val["image_url"].get("detail") == "low":
num_tokens += 85
else:
image_size = _url_to_size(val["image_url"]["url"])
if not image_size:
continue
num_tokens += _count_image_tokens(*image_size)
else:
raise ValueError(
f"Unrecognized content block type\n\n{val}"
)
else:
# Cast str(value) in case the message value is not a string
# This occurs with function messages
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
@@ -1541,3 +1569,75 @@ def _lc_invalid_tool_call_to_openai_tool_call(
"arguments": invalid_tool_call["args"],
},
}
def _url_to_size(image_source: str) -> Optional[Tuple[int, int]]:
    """Return the ``(width, height)`` of an image given as a URL or data URI.

    Args:
        image_source: Either a URL (per ``_is_url``) or a string accepted by
            ``_is_b64`` whose payload follows the first comma.

    Returns:
        The image dimensions as reported by Pillow, or ``None`` when Pillow is
        not importable, when the source is a URL and ``httpx`` is not
        importable, or when the source is neither a URL nor a data URI.

    Note:
        Nothing here catches request, decoding, or image-parsing errors, so
        those propagate to the caller.
    """
    try:
        from PIL import Image  # type: ignore[import]
    except ImportError:
        logger.info(
            "Unable to count image tokens. To count image tokens please install "
            "`pip install -U pillow httpx`."
        )
        return None

    # Step 1: obtain the raw image bytes from whichever source form was given.
    if _is_url(image_source):
        try:
            import httpx
        except ImportError:
            logger.info(
                "Unable to count image tokens. To count image tokens please install "
                "`pip install -U httpx`."
            )
            return None
        reply = httpx.get(image_source)
        reply.raise_for_status()
        raw_bytes = reply.content
    elif _is_b64(image_source):
        # Everything after the first comma is the base64 payload.
        _header, payload = image_source.split(",", 1)
        raw_bytes = base64.b64decode(payload)
    else:
        return None

    # Step 2: let Pillow read the dimensions from the in-memory bytes.
    width, height = Image.open(BytesIO(raw_bytes)).size
    return width, height
def _count_image_tokens(width: int, height: int) -> int:
    """Compute the token count for an image of the given pixel dimensions.

    The dimensions are first passed through ``_resize``; the result is then
    covered with 512-pixel tiles, each tile counting 170 tokens on top of a
    fixed 85.

    Reference: https://platform.openai.com/docs/guides/vision/calculating-costs
    """
    scaled_width, scaled_height = _resize(width, height)
    tiles_down = ceil(scaled_height / 512)
    tiles_across = ceil(scaled_width / 512)
    return (170 * tiles_down * tiles_across) + 85
def _is_url(s: str) -> bool:
try:
result = urlparse(s)
return all([result.scheme, result.netloc])
except Exception as e:
logger.debug(f"Unable to parse URL: {e}")
return False
def _is_b64(s: str) -> bool:
return s.startswith("data:image")
def _resize(width: int, height: int) -> Tuple[int, int]:
# larger side must be <= 2048
if width > 2048 or height > 2048:
if width > height:
height = (height * 2048) // width
width = 2048
else:
width = (width * 2048) // height
height = 2048
# smaller side must be <= 768
if width > 768 and height > 768:
if width > height:
width = (width * 768) // height
height = 768
else:
height = (width * 768) // height
width = 768
return width, height