def add_usage(
    left: Optional[UsageMetadata], right: Optional[UsageMetadata]
) -> UsageMetadata:
    """Sum two UsageMetadata dicts, including their nested detail dicts.

    When both arguments are non-empty, every integer field is added
    key-by-key and nested dicts (such as ``input_token_details``) are merged
    the same way. When only one argument is non-empty, that very object is
    returned unchanged (it is not copied). When neither is, an all-zero
    ``UsageMetadata`` is returned.

    Args:
        left: First usage record; may be ``None`` or empty.
        right: Second usage record; may be ``None`` or empty.

    Returns:
        The combined usage.

    Example:
        .. code-block:: python

            from langchain_core.messages.ai import add_usage

            left = UsageMetadata(
                input_tokens=5,
                output_tokens=0,
                total_tokens=5,
                input_token_details=InputTokenDetails(cache_read=3)
            )
            right = UsageMetadata(
                input_tokens=0,
                output_tokens=10,
                total_tokens=10,
                output_token_details=OutputTokenDetails(reasoning=4)
            )

            add_usage(left, right)

        results in

        .. code-block:: python

            UsageMetadata(
                input_tokens=5,
                output_tokens=10,
                total_tokens=15,
                input_token_details=InputTokenDetails(cache_read=3),
                output_token_details=OutputTokenDetails(reasoning=4)
            )

    """
    if left and right:
        # Both sides carry data: add them field by field.
        summed = _dict_int_op(cast(dict, left), cast(dict, right), operator.add)
        return UsageMetadata(**cast(UsageMetadata, summed))

    # At most one side carries data; hand it back untouched if there is one.
    present = left or right
    if present:
        return cast(UsageMetadata, present)

    return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
code-block:: python + + UsageMetadata( + input_tokens=2, + output_tokens=2, + total_tokens=4, + input_token_details=InputTokenDetails(cache_read=4), + output_token_details=OutputTokenDetails(reasoning=0) + ) + + """ + if not (left or right): + return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0) + if not (left and right): + return cast(UsageMetadata, left or right) + + return UsageMetadata( + **cast( + UsageMetadata, + _dict_int_op( + cast(dict, left), + cast(dict, right), + (lambda le, ri: max(le - ri, 0)), + ), + ) + ) diff --git a/libs/core/langchain_core/utils/usage.py b/libs/core/langchain_core/utils/usage.py new file mode 100644 index 00000000000..ce198e268ba --- /dev/null +++ b/libs/core/langchain_core/utils/usage.py @@ -0,0 +1,37 @@ +from typing import Callable + + +def _dict_int_op( + left: dict, + right: dict, + op: Callable[[int, int], int], + *, + default: int = 0, + depth: int = 0, + max_depth: int = 100, +) -> dict: + if depth >= max_depth: + msg = f"{max_depth=} exceeded, unable to combine dicts." + raise ValueError(msg) + combined: dict = {} + for k in set(left).union(right): + if isinstance(left.get(k, default), int) and isinstance( + right.get(k, default), int + ): + combined[k] = op(left.get(k, default), right.get(k, default)) + elif isinstance(left.get(k, {}), dict) and isinstance(right.get(k, {}), dict): + combined[k] = _dict_int_op( + left.get(k, {}), + right.get(k, {}), + op, + default=default, + depth=depth + 1, + max_depth=max_depth, + ) + else: + types = [type(d[k]) for d in (left, right) if k in d] + msg = ( + f"Unknown value types: {types}. Only dict and int values are supported." 
+ ) + raise ValueError(msg) + return combined diff --git a/libs/core/tests/unit_tests/messages/test_ai.py b/libs/core/tests/unit_tests/messages/test_ai.py index 1de40acf94c..d36d0347128 100644 --- a/libs/core/tests/unit_tests/messages/test_ai.py +++ b/libs/core/tests/unit_tests/messages/test_ai.py @@ -1,5 +1,13 @@ from langchain_core.load import dumpd, load from langchain_core.messages import AIMessage, AIMessageChunk +from langchain_core.messages.ai import ( + InputTokenDetails, + OutputTokenDetails, + UsageMetadata, + add_ai_message_chunks, + add_usage, + subtract_usage, +) from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call from langchain_core.messages.tool import tool_call as create_tool_call from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk @@ -92,3 +100,99 @@ def test_serdes_message_chunk() -> None: actual = dumpd(chunk) assert actual == expected assert load(actual) == chunk + + +def test_add_usage_both_none() -> None: + result = add_usage(None, None) + assert result == UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0) + + +def test_add_usage_one_none() -> None: + usage = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30) + result = add_usage(usage, None) + assert result == usage + + +def test_add_usage_both_present() -> None: + usage1 = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30) + usage2 = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15) + result = add_usage(usage1, usage2) + assert result == UsageMetadata(input_tokens=15, output_tokens=30, total_tokens=45) + + +def test_add_usage_with_details() -> None: + usage1 = UsageMetadata( + input_tokens=10, + output_tokens=20, + total_tokens=30, + input_token_details=InputTokenDetails(audio=5), + output_token_details=OutputTokenDetails(reasoning=10), + ) + usage2 = UsageMetadata( + input_tokens=5, + output_tokens=10, + total_tokens=15, + input_token_details=InputTokenDetails(audio=3), 
def test_add_ai_message_chunks_usage() -> None:
    """Merging chunks sums usage and keeps token details from the one chunk that has them."""
    # First chunk carries no usage at all.
    no_usage = AIMessageChunk(content="", usage_metadata=None)
    # Second chunk carries only the top-level token counts.
    plain_usage = AIMessageChunk(
        content="",
        usage_metadata=UsageMetadata(
            input_tokens=2, output_tokens=3, total_tokens=5
        ),
    )
    # Third chunk additionally carries per-category token details.
    detailed_usage = AIMessageChunk(
        content="",
        usage_metadata=UsageMetadata(
            input_tokens=2,
            output_tokens=3,
            total_tokens=5,
            input_token_details=InputTokenDetails(audio=1, cache_read=1),
            output_token_details=OutputTokenDetails(audio=1, reasoning=2),
        ),
    )

    combined = add_ai_message_chunks(no_usage, plain_usage, detailed_usage)

    expected_usage = UsageMetadata(
        input_tokens=4,
        output_tokens=6,
        total_tokens=10,
        input_token_details=InputTokenDetails(audio=1, cache_read=1),
        output_token_details=OutputTokenDetails(audio=1, reasoning=2),
    )
    assert combined == AIMessageChunk(content="", usage_metadata=expected_usage)