core[patch]: utils for adding/subtracting usage metadata (#27203)

This commit is contained in:
Bagatur
2024-10-08 13:15:33 -07:00
committed by GitHub
parent e3920f2320
commit e3e9ee8398
4 changed files with 296 additions and 11 deletions

View File

@@ -1,5 +1,13 @@
from langchain_core.load import dumpd, load
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.messages.ai import (
InputTokenDetails,
OutputTokenDetails,
UsageMetadata,
add_ai_message_chunks,
add_usage,
subtract_usage,
)
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
@@ -92,3 +100,99 @@ def test_serdes_message_chunk() -> None:
actual = dumpd(chunk)
assert actual == expected
assert load(actual) == chunk
def test_add_usage_both_none() -> None:
result = add_usage(None, None)
assert result == UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
def test_add_usage_one_none() -> None:
usage = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
result = add_usage(usage, None)
assert result == usage
def test_add_usage_both_present() -> None:
usage1 = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
usage2 = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
result = add_usage(usage1, usage2)
assert result == UsageMetadata(input_tokens=15, output_tokens=30, total_tokens=45)
def test_add_usage_with_details() -> None:
usage1 = UsageMetadata(
input_tokens=10,
output_tokens=20,
total_tokens=30,
input_token_details=InputTokenDetails(audio=5),
output_token_details=OutputTokenDetails(reasoning=10),
)
usage2 = UsageMetadata(
input_tokens=5,
output_tokens=10,
total_tokens=15,
input_token_details=InputTokenDetails(audio=3),
output_token_details=OutputTokenDetails(reasoning=5),
)
result = add_usage(usage1, usage2)
assert result["input_token_details"]["audio"] == 8
assert result["output_token_details"]["reasoning"] == 15
def test_subtract_usage_both_none() -> None:
result = subtract_usage(None, None)
assert result == UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
def test_subtract_usage_one_none() -> None:
usage = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
result = subtract_usage(usage, None)
assert result == usage
def test_subtract_usage_both_present() -> None:
usage1 = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
usage2 = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
result = subtract_usage(usage1, usage2)
assert result == UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
def test_subtract_usage_with_negative_result() -> None:
usage1 = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
usage2 = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
result = subtract_usage(usage1, usage2)
assert result == UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
def test_add_ai_message_chunks_usage() -> None:
chunks = [
AIMessageChunk(content="", usage_metadata=None),
AIMessageChunk(
content="",
usage_metadata=UsageMetadata(
input_tokens=2, output_tokens=3, total_tokens=5
),
),
AIMessageChunk(
content="",
usage_metadata=UsageMetadata(
input_tokens=2,
output_tokens=3,
total_tokens=5,
input_token_details=InputTokenDetails(audio=1, cache_read=1),
output_token_details=OutputTokenDetails(audio=1, reasoning=2),
),
),
]
combined = add_ai_message_chunks(*chunks)
assert combined == AIMessageChunk(
content="",
usage_metadata=UsageMetadata(
input_tokens=4,
output_tokens=6,
total_tokens=10,
input_token_details=InputTokenDetails(audio=1, cache_read=1),
output_token_details=OutputTokenDetails(audio=1, reasoning=2),
),
)

View File

@@ -0,0 +1,38 @@
import pytest
from langchain_core.utils.usage import _dict_int_op
def test_dict_int_op_add() -> None:
left = {"a": 1, "b": 2}
right = {"b": 3, "c": 4}
result = _dict_int_op(left, right, lambda x, y: x + y)
assert result == {"a": 1, "b": 5, "c": 4}
def test_dict_int_op_subtract() -> None:
left = {"a": 5, "b": 10}
right = {"a": 2, "b": 3, "c": 1}
result = _dict_int_op(left, right, lambda x, y: max(x - y, 0))
assert result == {"a": 3, "b": 7, "c": 0}
def test_dict_int_op_nested() -> None:
left = {"a": 1, "b": {"c": 2, "d": 3}}
right = {"a": 2, "b": {"c": 1, "e": 4}}
result = _dict_int_op(left, right, lambda x, y: x + y)
assert result == {"a": 3, "b": {"c": 3, "d": 3, "e": 4}}
def test_dict_int_op_max_depth_exceeded() -> None:
left = {"a": {"b": {"c": 1}}}
right = {"a": {"b": {"c": 2}}}
with pytest.raises(ValueError):
_dict_int_op(left, right, lambda x, y: x + y, max_depth=2)
def test_dict_int_op_invalid_types() -> None:
left = {"a": 1, "b": "string"}
right = {"a": 2, "b": 3}
with pytest.raises(ValueError):
_dict_int_op(left, right, lambda x, y: x + y)