mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-14 14:05:37 +00:00
core[patch]: add util for approximate token counting (#30373)
This commit is contained in:
@@ -18,6 +18,7 @@ from langchain_core.messages import (
|
||||
from langchain_core.messages.utils import (
|
||||
convert_to_messages,
|
||||
convert_to_openai_messages,
|
||||
count_tokens_approximately,
|
||||
filter_messages,
|
||||
merge_message_runs,
|
||||
trim_messages,
|
||||
@@ -976,3 +977,130 @@ def test_convert_to_openai_messages_developer() -> None:
|
||||
]
|
||||
result = convert_to_openai_messages(messages)
|
||||
assert result == [{"role": "developer", "content": "a"}] * 2
|
||||
|
||||
|
||||
def test_count_tokens_approximately_empty_messages() -> None:
|
||||
# Test with empty message list
|
||||
assert count_tokens_approximately([]) == 0
|
||||
|
||||
# Test with empty content
|
||||
messages = [HumanMessage(content="")]
|
||||
# 4 role chars -> 1 + 3 = 4 tokens
|
||||
assert count_tokens_approximately(messages) == 4
|
||||
|
||||
|
||||
def test_count_tokens_approximately_with_names() -> None:
|
||||
messages = [
|
||||
# 5 chars + 4 role chars -> 3 + 3 = 6 tokens
|
||||
# (with name: extra 4 name chars, so total = 4 + 3 = 7 tokens)
|
||||
HumanMessage(content="Hello", name="user"),
|
||||
# 8 chars + 9 role chars -> 5 + 3 = 8 tokens
|
||||
# (with name: extra 9 name chars, so total = 7 + 3 = 10 tokens)
|
||||
AIMessage(content="Hi there", name="assistant"),
|
||||
]
|
||||
# With names included (default)
|
||||
assert count_tokens_approximately(messages) == 17
|
||||
|
||||
# Without names
|
||||
without_names = count_tokens_approximately(messages, count_name=False)
|
||||
assert without_names == 14
|
||||
|
||||
|
||||
def test_count_tokens_approximately_openai_format() -> None:
|
||||
# same as test_count_tokens_approximately_with_names, but in OpenAI format
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello", "name": "user"},
|
||||
{"role": "assistant", "content": "Hi there", "name": "assistant"},
|
||||
]
|
||||
# With names included (default)
|
||||
assert count_tokens_approximately(messages) == 17
|
||||
|
||||
# Without names
|
||||
without_names = count_tokens_approximately(messages, count_name=False)
|
||||
assert without_names == 14
|
||||
|
||||
|
||||
def test_count_tokens_approximately_string_content() -> None:
|
||||
messages = [
|
||||
# 5 chars + 4 role chars -> 3 + 3 = 6 tokens
|
||||
HumanMessage(content="Hello"),
|
||||
# 8 chars + 9 role chars -> 5 + 3 = 8 tokens
|
||||
AIMessage(content="Hi there"),
|
||||
# 12 chars + 4 role chars -> 4 + 3 = 7 tokens
|
||||
HumanMessage(content="How are you?"),
|
||||
]
|
||||
assert count_tokens_approximately(messages) == 21
|
||||
|
||||
|
||||
def test_count_tokens_approximately_list_content() -> None:
|
||||
messages = [
|
||||
# '[{"foo": "bar"}]' -> 16 chars + 4 role chars -> 5 + 3 = 8 tokens
|
||||
HumanMessage(content=[{"foo": "bar"}]),
|
||||
# '[{"test": 123}]' -> 15 chars + 9 role chars -> 6 + 3 = 9 tokens
|
||||
AIMessage(content=[{"test": 123}]),
|
||||
]
|
||||
assert count_tokens_approximately(messages) == 17
|
||||
|
||||
|
||||
def test_count_tokens_approximately_tool_calls() -> None:
|
||||
tool_calls = [{"name": "test_tool", "args": {"foo": "bar"}, "id": "1"}]
|
||||
messages = [
|
||||
# tool calls json -> 79 chars + 9 role chars -> 22 + 3 = 25 tokens
|
||||
AIMessage(content="", tool_calls=tool_calls),
|
||||
# 15 chars + 4 role chars -> 5 + 3 = 8 tokens
|
||||
HumanMessage(content="Regular message"),
|
||||
]
|
||||
assert count_tokens_approximately(messages) == 33
|
||||
# AI message w/ both content and tool calls
|
||||
# 94 chars + 9 role chars -> 26 + 3 = 29 tokens
|
||||
messages = [
|
||||
AIMessage(content="Regular message", tool_calls=tool_calls),
|
||||
]
|
||||
assert count_tokens_approximately(messages) == 29
|
||||
|
||||
|
||||
def test_count_tokens_approximately_custom_token_length() -> None:
|
||||
messages = [
|
||||
# 11 chars + 4 role chars -> (4 tokens of length 4 / 8 tokens of length 2) + 3
|
||||
HumanMessage(content="Hello world"),
|
||||
# 7 chars + 9 role chars -> (4 tokens of length 4 / 8 tokens of length 2) + 3
|
||||
AIMessage(content="Testing"),
|
||||
]
|
||||
assert count_tokens_approximately(messages, chars_per_token=4) == 14
|
||||
assert count_tokens_approximately(messages, chars_per_token=2) == 22
|
||||
|
||||
|
||||
def test_count_tokens_approximately_large_message_content() -> None:
|
||||
# Test with large content to ensure no issues
|
||||
large_text = "x" * 10000
|
||||
messages = [HumanMessage(content=large_text)]
|
||||
# 10,000 chars + 4 role chars -> 2501 + 3 = 2504 tokens
|
||||
assert count_tokens_approximately(messages) == 2504
|
||||
|
||||
|
||||
def test_count_tokens_approximately_large_number_of_messages() -> None:
|
||||
# Test with large content to ensure no issues
|
||||
messages = [HumanMessage(content="x")] * 1_000
|
||||
# 1 chars + 4 role chars -> 2 + 3 = 5 tokens
|
||||
assert count_tokens_approximately(messages) == 5_000
|
||||
|
||||
|
||||
def test_count_tokens_approximately_mixed_content_types() -> None:
|
||||
# Test with a variety of content types in the same message list
|
||||
tool_calls = [{"name": "test_tool", "args": {"foo": "bar"}, "id": "1"}]
|
||||
messages = [
|
||||
# 13 chars + 6 role chars -> 5 + 3 = 8 tokens
|
||||
SystemMessage(content="System prompt"),
|
||||
# '[{"foo": "bar"}]' -> 16 chars + 4 role chars -> 5 + 3 = 8 tokens
|
||||
HumanMessage(content=[{"foo": "bar"}]),
|
||||
# tool calls json -> 79 chars + 9 role chars -> 22 + 3 = 25 tokens
|
||||
AIMessage(content="", tool_calls=tool_calls),
|
||||
# 13 chars + 4 role chars + 9 name chars + 1 tool call ID char ->
|
||||
# 7 + 3 = 10 tokens
|
||||
ToolMessage(content="Tool response", name="test_tool", tool_call_id="1"),
|
||||
]
|
||||
token_count = count_tokens_approximately(messages)
|
||||
assert token_count == 51
|
||||
|
||||
# Ensure that count is consistent if we do one message at a time
|
||||
assert sum(count_tokens_approximately([m]) for m in messages) == token_count
|
||||
|
Reference in New Issue
Block a user