Compare commits

...

2 Commits

Author SHA1 Message Date
Chester Curme
06e2719384 format 2024-06-25 12:17:33 -04:00
Chester Curme
94e37e4964 fix regression and add test 2024-06-25 12:13:21 -04:00
2 changed files with 32 additions and 17 deletions

View File

@@ -760,27 +760,30 @@ class BaseChatOpenAI(BaseChatModel):
for message in messages_dict: for message in messages_dict:
num_tokens += tokens_per_message num_tokens += tokens_per_message
for key, value in message.items(): for key, value in message.items():
if isinstance(value, list): try:
for val in value: if isinstance(value, list):
if isinstance(val, str) or val["type"] == "text": for val in value:
text = val["text"] if isinstance(val, dict) else val if isinstance(val, str) or val["type"] == "text":
num_tokens += len(encoding.encode(text)) text = val["text"] if isinstance(val, dict) else val
elif val["type"] == "image_url": num_tokens += len(encoding.encode(text))
if val["image_url"].get("detail") == "low": elif val["type"] == "image_url":
num_tokens += 85 if val["image_url"].get("detail") == "low":
num_tokens += 85
else:
image_size = _url_to_size(val["image_url"]["url"])
if not image_size:
continue
num_tokens += _count_image_tokens(*image_size)
else: else:
image_size = _url_to_size(val["image_url"]["url"]) raise ValueError(
if not image_size: f"Unrecognized content block type\n\n{val}"
continue )
num_tokens += _count_image_tokens(*image_size) except ValueError:
else: num_tokens += len(encoding.encode(str(value)))
raise ValueError(
f"Unrecognized content block type\n\n{val}"
)
else: else:
# Cast str(value) in case the message value is not a string # Cast str(value) in case the message value is not a string
# This occurs with function messages # This occurs with function messages
num_tokens += len(encoding.encode(value)) num_tokens += len(encoding.encode(str(value)))
if key == "name": if key == "name":
num_tokens += tokens_per_name num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant # every reply is primed with <im_start>assistant

View File

@@ -84,3 +84,15 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
] ]
actual = model.get_num_tokens_from_messages([message]) actual = model.get_num_tokens_from_messages([message])
assert expected == actual assert expected == actual
def test_token_counting(self, model: BaseChatModel) -> None:
user_message = HumanMessage("What is the value of magic_function(3)?")
assert model.get_num_tokens_from_messages([user_message]) > 0
tool_calls = [{"name": "magic_function", "args": {"input": 3}, "id": "test_id"}]
ai_message_1 = AIMessage(content="test", tool_calls=tool_calls)
assert model.get_num_tokens_from_messages([ai_message_1]) > 0
ai_message_2 = AIMessage(
content="", # blank content
tool_calls=tool_calls,
)
assert model.get_num_tokens_from_messages([ai_message_2]) > 0