Compare commits

...

2 Commits

Author SHA1 Message Date
Chester Curme
06e2719384 format 2024-06-25 12:17:33 -04:00
Chester Curme
94e37e4964 fix regression and add test 2024-06-25 12:13:21 -04:00
2 changed files with 32 additions and 17 deletions

View File

@@ -760,27 +760,30 @@ class BaseChatOpenAI(BaseChatModel):
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
for val in value:
if isinstance(val, str) or val["type"] == "text":
text = val["text"] if isinstance(val, dict) else val
num_tokens += len(encoding.encode(text))
elif val["type"] == "image_url":
if val["image_url"].get("detail") == "low":
num_tokens += 85
try:
if isinstance(value, list):
for val in value:
if isinstance(val, str) or val["type"] == "text":
text = val["text"] if isinstance(val, dict) else val
num_tokens += len(encoding.encode(text))
elif val["type"] == "image_url":
if val["image_url"].get("detail") == "low":
num_tokens += 85
else:
image_size = _url_to_size(val["image_url"]["url"])
if not image_size:
continue
num_tokens += _count_image_tokens(*image_size)
else:
image_size = _url_to_size(val["image_url"]["url"])
if not image_size:
continue
num_tokens += _count_image_tokens(*image_size)
else:
raise ValueError(
f"Unrecognized content block type\n\n{val}"
)
raise ValueError(
f"Unrecognized content block type\n\n{val}"
)
except ValueError:
num_tokens += len(encoding.encode(str(value)))
else:
# Cast str(value) in case the message value is not a string
# This occurs with function messages
num_tokens += len(encoding.encode(value))
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant

View File

@@ -84,3 +84,15 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
]
actual = model.get_num_tokens_from_messages([message])
assert expected == actual
def test_token_counting(self, model: BaseChatModel) -> None:
user_message = HumanMessage("What is the value of magic_function(3)?")
assert model.get_num_tokens_from_messages([user_message]) > 0
tool_calls = [{"name": "magic_function", "args": {"input": 3}, "id": "test_id"}]
ai_message_1 = AIMessage(content="test", tool_calls=tool_calls)
assert model.get_num_tokens_from_messages([ai_message_1]) > 0
ai_message_2 = AIMessage(
content="", # blank content
tool_calls=tool_calls,
)
assert model.get_num_tokens_from_messages([ai_message_2]) > 0