mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 05:43:55 +00:00
core[patch]: utils for adding/subtracting usage metadata (#27203)
This commit is contained in:
parent
e3920f2320
commit
e3e9ee8398
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Literal, Optional, Union
|
import operator
|
||||||
|
from typing import Any, Literal, Optional, Union, cast
|
||||||
|
|
||||||
from pydantic import model_validator
|
from pydantic import model_validator
|
||||||
from typing_extensions import NotRequired, Self, TypedDict
|
from typing_extensions import NotRequired, Self, TypedDict
|
||||||
@ -27,6 +28,7 @@ from langchain_core.messages.tool import (
|
|||||||
)
|
)
|
||||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||||
from langchain_core.utils.json import parse_partial_json
|
from langchain_core.utils.json import parse_partial_json
|
||||||
|
from langchain_core.utils.usage import _dict_int_op
|
||||||
|
|
||||||
|
|
||||||
class InputTokenDetails(TypedDict, total=False):
|
class InputTokenDetails(TypedDict, total=False):
|
||||||
@ -432,17 +434,9 @@ def add_ai_message_chunks(
|
|||||||
|
|
||||||
# Token usage
|
# Token usage
|
||||||
if left.usage_metadata or any(o.usage_metadata is not None for o in others):
|
if left.usage_metadata or any(o.usage_metadata is not None for o in others):
|
||||||
usage_metadata_: UsageMetadata = left.usage_metadata or UsageMetadata(
|
usage_metadata: Optional[UsageMetadata] = left.usage_metadata
|
||||||
input_tokens=0, output_tokens=0, total_tokens=0
|
|
||||||
)
|
|
||||||
for other in others:
|
for other in others:
|
||||||
if other.usage_metadata is not None:
|
usage_metadata = add_usage(usage_metadata, other.usage_metadata)
|
||||||
usage_metadata_["input_tokens"] += other.usage_metadata["input_tokens"]
|
|
||||||
usage_metadata_["output_tokens"] += other.usage_metadata[
|
|
||||||
"output_tokens"
|
|
||||||
]
|
|
||||||
usage_metadata_["total_tokens"] += other.usage_metadata["total_tokens"]
|
|
||||||
usage_metadata: Optional[UsageMetadata] = usage_metadata_
|
|
||||||
else:
|
else:
|
||||||
usage_metadata = None
|
usage_metadata = None
|
||||||
|
|
||||||
@ -455,3 +449,115 @@ def add_ai_message_chunks(
|
|||||||
usage_metadata=usage_metadata,
|
usage_metadata=usage_metadata,
|
||||||
id=left.id,
|
id=left.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def add_usage(
    left: Optional[UsageMetadata], right: Optional[UsageMetadata]
) -> UsageMetadata:
    """Recursively add two UsageMetadata objects.

    Integer counters are summed key-by-key; the nested
    ``input_token_details`` / ``output_token_details`` dicts are merged
    recursively. A key present on only one side is kept as-is.

    Args:
        left: First usage object, or None.
        right: Second usage object, or None.

    Returns:
        A UsageMetadata with the combined counts. If both inputs are None,
        a zeroed-out UsageMetadata is returned (never None).

    Example:
        .. code-block:: python

            from langchain_core.messages.ai import add_usage

            left = UsageMetadata(
                input_tokens=5,
                output_tokens=0,
                total_tokens=5,
                input_token_details=InputTokenDetails(cache_read=3)
            )
            right = UsageMetadata(
                input_tokens=0,
                output_tokens=10,
                total_tokens=10,
                output_token_details=OutputTokenDetails(reasoning=4)
            )

            add_usage(left, right)

        results in

        .. code-block:: python

            UsageMetadata(
                input_tokens=5,
                output_tokens=10,
                total_tokens=15,
                input_token_details=InputTokenDetails(cache_read=3),
                output_token_details=OutputTokenDetails(reasoning=4)
            )

    """
    # Both absent: return a zeroed UsageMetadata so the return type is
    # always a UsageMetadata, never None.
    if not (left or right):
        return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
    # Exactly one present: nothing to add, return it unchanged.
    if not (left and right):
        return cast(UsageMetadata, left or right)

    # Both present: recursively sum all integer fields.
    return UsageMetadata(
        **cast(
            UsageMetadata,
            _dict_int_op(
                cast(dict, left),
                cast(dict, right),
                operator.add,
            ),
        )
    )
|
||||||
|
|
||||||
|
|
||||||
|
def subtract_usage(
    left: Optional[UsageMetadata], right: Optional[UsageMetadata]
) -> UsageMetadata:
    """Recursively subtract two UsageMetadata objects.

    Token counts cannot be negative so the actual operation is
    ``max(left - right, 0)``, applied key-by-key and recursively into the
    nested ``*_token_details`` dicts.

    Args:
        left: Usage to subtract from, or None.
        right: Usage to subtract, or None.

    Returns:
        A UsageMetadata with the clamped differences. If both inputs are
        None, a zeroed-out UsageMetadata is returned (never None).

    Example:
        .. code-block:: python

            from langchain_core.messages.ai import subtract_usage

            left = UsageMetadata(
                input_tokens=5,
                output_tokens=10,
                total_tokens=15,
                input_token_details=InputTokenDetails(cache_read=4)
            )
            right = UsageMetadata(
                input_tokens=3,
                output_tokens=8,
                total_tokens=11,
                output_token_details=OutputTokenDetails(reasoning=4)
            )

            subtract_usage(left, right)

        results in

        .. code-block:: python

            UsageMetadata(
                input_tokens=2,
                output_tokens=2,
                total_tokens=4,
                input_token_details=InputTokenDetails(cache_read=4),
                output_token_details=OutputTokenDetails(reasoning=0)
            )

    """
    # Both absent: return a zeroed UsageMetadata so the return type is
    # always a UsageMetadata, never None.
    if not (left or right):
        return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
    # Exactly one present: nothing to subtract against, return it unchanged.
    if not (left and right):
        return cast(UsageMetadata, left or right)

    # Both present: recursively subtract, clamping each count at zero.
    return UsageMetadata(
        **cast(
            UsageMetadata,
            _dict_int_op(
                cast(dict, left),
                cast(dict, right),
                (lambda le, ri: max(le - ri, 0)),
            ),
        )
    )
|
||||||
|
37
libs/core/langchain_core/utils/usage.py
Normal file
37
libs/core/langchain_core/utils/usage.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
|
||||||
|
def _dict_int_op(
|
||||||
|
left: dict,
|
||||||
|
right: dict,
|
||||||
|
op: Callable[[int, int], int],
|
||||||
|
*,
|
||||||
|
default: int = 0,
|
||||||
|
depth: int = 0,
|
||||||
|
max_depth: int = 100,
|
||||||
|
) -> dict:
|
||||||
|
if depth >= max_depth:
|
||||||
|
msg = f"{max_depth=} exceeded, unable to combine dicts."
|
||||||
|
raise ValueError(msg)
|
||||||
|
combined: dict = {}
|
||||||
|
for k in set(left).union(right):
|
||||||
|
if isinstance(left.get(k, default), int) and isinstance(
|
||||||
|
right.get(k, default), int
|
||||||
|
):
|
||||||
|
combined[k] = op(left.get(k, default), right.get(k, default))
|
||||||
|
elif isinstance(left.get(k, {}), dict) and isinstance(right.get(k, {}), dict):
|
||||||
|
combined[k] = _dict_int_op(
|
||||||
|
left.get(k, {}),
|
||||||
|
right.get(k, {}),
|
||||||
|
op,
|
||||||
|
default=default,
|
||||||
|
depth=depth + 1,
|
||||||
|
max_depth=max_depth,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
types = [type(d[k]) for d in (left, right) if k in d]
|
||||||
|
msg = (
|
||||||
|
f"Unknown value types: {types}. Only dict and int values are supported."
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
return combined
|
@ -1,5 +1,13 @@
|
|||||||
from langchain_core.load import dumpd, load
|
from langchain_core.load import dumpd, load
|
||||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||||
|
from langchain_core.messages.ai import (
|
||||||
|
InputTokenDetails,
|
||||||
|
OutputTokenDetails,
|
||||||
|
UsageMetadata,
|
||||||
|
add_ai_message_chunks,
|
||||||
|
add_usage,
|
||||||
|
subtract_usage,
|
||||||
|
)
|
||||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||||
@ -92,3 +100,99 @@ def test_serdes_message_chunk() -> None:
|
|||||||
actual = dumpd(chunk)
|
actual = dumpd(chunk)
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
assert load(actual) == chunk
|
assert load(actual) == chunk
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_usage_both_none() -> None:
    """add_usage(None, None) yields a zeroed UsageMetadata, not None."""
    result = add_usage(None, None)
    assert result == UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)


def test_add_usage_one_none() -> None:
    """Adding None to a usage object returns the usage object unchanged."""
    usage = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
    result = add_usage(usage, None)
    assert result == usage


def test_add_usage_both_present() -> None:
    """Top-level token counts are summed key-by-key."""
    usage1 = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
    usage2 = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
    result = add_usage(usage1, usage2)
    assert result == UsageMetadata(input_tokens=15, output_tokens=30, total_tokens=45)


def test_add_usage_with_details() -> None:
    """Nested input/output token details are summed recursively."""
    usage1 = UsageMetadata(
        input_tokens=10,
        output_tokens=20,
        total_tokens=30,
        input_token_details=InputTokenDetails(audio=5),
        output_token_details=OutputTokenDetails(reasoning=10),
    )
    usage2 = UsageMetadata(
        input_tokens=5,
        output_tokens=10,
        total_tokens=15,
        input_token_details=InputTokenDetails(audio=3),
        output_token_details=OutputTokenDetails(reasoning=5),
    )
    result = add_usage(usage1, usage2)
    assert result["input_token_details"]["audio"] == 8
    assert result["output_token_details"]["reasoning"] == 15


def test_subtract_usage_both_none() -> None:
    """subtract_usage(None, None) yields a zeroed UsageMetadata, not None."""
    result = subtract_usage(None, None)
    assert result == UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)


def test_subtract_usage_one_none() -> None:
    """Subtracting None leaves the usage object unchanged."""
    usage = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
    result = subtract_usage(usage, None)
    assert result == usage


def test_subtract_usage_both_present() -> None:
    """Top-level token counts are subtracted key-by-key."""
    usage1 = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
    usage2 = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
    result = subtract_usage(usage1, usage2)
    assert result == UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)


def test_subtract_usage_with_negative_result() -> None:
    """Counts are clamped at zero; subtraction never goes negative."""
    usage1 = UsageMetadata(input_tokens=5, output_tokens=10, total_tokens=15)
    usage2 = UsageMetadata(input_tokens=10, output_tokens=20, total_tokens=30)
    result = subtract_usage(usage1, usage2)
    assert result == UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)


def test_add_ai_message_chunks_usage() -> None:
    """Merging chunks combines usage metadata, skipping chunks without it."""
    chunks = [
        AIMessageChunk(content="", usage_metadata=None),
        AIMessageChunk(
            content="",
            usage_metadata=UsageMetadata(
                input_tokens=2, output_tokens=3, total_tokens=5
            ),
        ),
        AIMessageChunk(
            content="",
            usage_metadata=UsageMetadata(
                input_tokens=2,
                output_tokens=3,
                total_tokens=5,
                input_token_details=InputTokenDetails(audio=1, cache_read=1),
                output_token_details=OutputTokenDetails(audio=1, reasoning=2),
            ),
        ),
    ]
    combined = add_ai_message_chunks(*chunks)
    assert combined == AIMessageChunk(
        content="",
        usage_metadata=UsageMetadata(
            input_tokens=4,
            output_tokens=6,
            total_tokens=10,
            input_token_details=InputTokenDetails(audio=1, cache_read=1),
            output_token_details=OutputTokenDetails(audio=1, reasoning=2),
        ),
    )
|
||||||
|
38
libs/core/tests/unit_tests/utils/test_usage.py
Normal file
38
libs/core/tests/unit_tests/utils/test_usage.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_core.utils.usage import _dict_int_op
|
||||||
|
|
||||||
|
|
||||||
|
def test_dict_int_op_add() -> None:
    """Keys missing on one side fall back to the default (0) before op."""
    left = {"a": 1, "b": 2}
    right = {"b": 3, "c": 4}
    result = _dict_int_op(left, right, lambda x, y: x + y)
    assert result == {"a": 1, "b": 5, "c": 4}


def test_dict_int_op_subtract() -> None:
    """Clamped subtraction works with one-sided keys."""
    left = {"a": 5, "b": 10}
    right = {"a": 2, "b": 3, "c": 1}
    result = _dict_int_op(left, right, lambda x, y: max(x - y, 0))
    assert result == {"a": 3, "b": 7, "c": 0}


def test_dict_int_op_nested() -> None:
    """Nested dict values are combined recursively."""
    left = {"a": 1, "b": {"c": 2, "d": 3}}
    right = {"a": 2, "b": {"c": 1, "e": 4}}
    result = _dict_int_op(left, right, lambda x, y: x + y)
    assert result == {"a": 3, "b": {"c": 3, "d": 3, "e": 4}}


def test_dict_int_op_max_depth_exceeded() -> None:
    """Recursion deeper than max_depth raises ValueError."""
    left = {"a": {"b": {"c": 1}}}
    right = {"a": {"b": {"c": 2}}}
    with pytest.raises(ValueError):
        _dict_int_op(left, right, lambda x, y: x + y, max_depth=2)


def test_dict_int_op_invalid_types() -> None:
    """Values that are neither int nor dict raise ValueError."""
    left = {"a": 1, "b": "string"}
    right = {"a": 2, "b": 3}
    with pytest.raises(ValueError):
        _dict_int_op(left, right, lambda x, y: x + y)
|
Loading…
Reference in New Issue
Block a user