mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 23:12:38 +00:00
core[patch]: utils for adding/subtracting usage metadata (#27203)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
from typing import Any, Literal, Optional, Union
|
||||
import operator
|
||||
from typing import Any, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import NotRequired, Self, TypedDict
|
||||
@@ -27,6 +28,7 @@ from langchain_core.messages.tool import (
|
||||
)
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
from langchain_core.utils.usage import _dict_int_op
|
||||
|
||||
|
||||
class InputTokenDetails(TypedDict, total=False):
|
||||
@@ -432,17 +434,9 @@ def add_ai_message_chunks(
|
||||
|
||||
# Token usage
|
||||
if left.usage_metadata or any(o.usage_metadata is not None for o in others):
|
||||
usage_metadata_: UsageMetadata = left.usage_metadata or UsageMetadata(
|
||||
input_tokens=0, output_tokens=0, total_tokens=0
|
||||
)
|
||||
usage_metadata: Optional[UsageMetadata] = left.usage_metadata
|
||||
for other in others:
|
||||
if other.usage_metadata is not None:
|
||||
usage_metadata_["input_tokens"] += other.usage_metadata["input_tokens"]
|
||||
usage_metadata_["output_tokens"] += other.usage_metadata[
|
||||
"output_tokens"
|
||||
]
|
||||
usage_metadata_["total_tokens"] += other.usage_metadata["total_tokens"]
|
||||
usage_metadata: Optional[UsageMetadata] = usage_metadata_
|
||||
usage_metadata = add_usage(usage_metadata, other.usage_metadata)
|
||||
else:
|
||||
usage_metadata = None
|
||||
|
||||
@@ -455,3 +449,115 @@ def add_ai_message_chunks(
|
||||
usage_metadata=usage_metadata,
|
||||
id=left.id,
|
||||
)
|
||||
|
||||
|
||||
def add_usage(
    left: Optional[UsageMetadata], right: Optional[UsageMetadata]
) -> UsageMetadata:
    """Sum two UsageMetadata objects, descending into nested detail dicts.

    Integer counts that share a key are added together; a key present on
    only one side is combined with ``0``. Nested dicts are merged the same
    way, one level at a time.

    If only one operand is non-empty, that operand itself is returned (it is
    not copied). If neither operand is non-empty, a zeroed ``UsageMetadata``
    is returned.

    Example:
        .. code-block:: python

            from langchain_core.messages.ai import add_usage

            left = UsageMetadata(
                input_tokens=5,
                output_tokens=0,
                total_tokens=5,
                input_token_details=InputTokenDetails(cache_read=3)
            )
            right = UsageMetadata(
                input_tokens=0,
                output_tokens=10,
                total_tokens=10,
                output_token_details=OutputTokenDetails(reasoning=4)
            )

            add_usage(left, right)

        results in

        .. code-block:: python

            UsageMetadata(
                input_tokens=5,
                output_tokens=10,
                total_tokens=15,
                input_token_details=InputTokenDetails(cache_read=3),
                output_token_details=OutputTokenDetails(reasoning=4)
            )

    """
    if left and right:
        # Both sides carry data: merge key-by-key with integer addition.
        summed = _dict_int_op(cast(dict, left), cast(dict, right), operator.add)
        return UsageMetadata(**cast(UsageMetadata, summed))

    # At most one side carries data; hand it back unchanged.
    if left:
        return cast(UsageMetadata, left)
    if right:
        return cast(UsageMetadata, right)

    # Neither side carries data.
    return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
|
||||
|
||||
|
||||
def subtract_usage(
    left: Optional[UsageMetadata], right: Optional[UsageMetadata]
) -> UsageMetadata:
    """Recursively subtract two UsageMetadata objects.

    Token counts cannot be negative so the actual operation is
    ``max(left - right, 0)``, applied to every integer count and, one level
    at a time, to every nested detail dict. A key present on only one side is
    combined with ``0`` for the other side.

    Empty or missing operands:

    - neither operand non-empty: a zeroed ``UsageMetadata`` is returned;
    - ``right`` empty or ``None``: there is nothing to subtract, so ``left``
      itself is returned (it is not copied);
    - ``left`` empty or ``None``: it is treated as zero usage, so the result
      has the keys of ``right`` with every count clamped to ``0``.

    Example:
        .. code-block:: python

            from langchain_core.messages.ai import subtract_usage

            left = UsageMetadata(
                input_tokens=5,
                output_tokens=10,
                total_tokens=15,
                input_token_details=InputTokenDetails(cache_read=4)
            )
            right = UsageMetadata(
                input_tokens=3,
                output_tokens=8,
                total_tokens=11,
                output_token_details=OutputTokenDetails(reasoning=4)
            )

            subtract_usage(left, right)

        results in

        .. code-block:: python

            UsageMetadata(
                input_tokens=2,
                output_tokens=2,
                total_tokens=4,
                input_token_details=InputTokenDetails(cache_read=4),
                output_token_details=OutputTokenDetails(reasoning=0)
            )

    """
    if not (left or right):
        return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
    if not right:
        # Nothing to subtract: the minuend is already the answer.
        return cast(UsageMetadata, left)

    # ``left`` may be missing here. Returning ``right`` in that case would
    # claim that "nothing minus right" equals ``right``; instead, subtract
    # from an empty dict so every count in ``right`` clamps to zero.
    return UsageMetadata(
        **cast(
            UsageMetadata,
            _dict_int_op(
                cast(dict, left or {}),
                cast(dict, right),
                (lambda le, ri: max(le - ri, 0)),
            ),
        )
    )
|
||||
|
37
libs/core/langchain_core/utils/usage.py
Normal file
37
libs/core/langchain_core/utils/usage.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def _dict_int_op(
|
||||
left: dict,
|
||||
right: dict,
|
||||
op: Callable[[int, int], int],
|
||||
*,
|
||||
default: int = 0,
|
||||
depth: int = 0,
|
||||
max_depth: int = 100,
|
||||
) -> dict:
|
||||
if depth >= max_depth:
|
||||
msg = f"{max_depth=} exceeded, unable to combine dicts."
|
||||
raise ValueError(msg)
|
||||
combined: dict = {}
|
||||
for k in set(left).union(right):
|
||||
if isinstance(left.get(k, default), int) and isinstance(
|
||||
right.get(k, default), int
|
||||
):
|
||||
combined[k] = op(left.get(k, default), right.get(k, default))
|
||||
elif isinstance(left.get(k, {}), dict) and isinstance(right.get(k, {}), dict):
|
||||
combined[k] = _dict_int_op(
|
||||
left.get(k, {}),
|
||||
right.get(k, {}),
|
||||
op,
|
||||
default=default,
|
||||
depth=depth + 1,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
else:
|
||||
types = [type(d[k]) for d in (left, right) if k in d]
|
||||
msg = (
|
||||
f"Unknown value types: {types}. Only dict and int values are supported."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return combined
|
Reference in New Issue
Block a user