core[patch]: utils for adding/subtracting usage metadata (#27203)

This commit is contained in:
Bagatur
2024-10-08 13:15:33 -07:00
committed by GitHub
parent e3920f2320
commit e3e9ee8398
4 changed files with 296 additions and 11 deletions

View File

@@ -1,5 +1,6 @@
import json
from typing import Any, Literal, Optional, Union
import operator
from typing import Any, Literal, Optional, Union, cast
from pydantic import model_validator
from typing_extensions import NotRequired, Self, TypedDict
@@ -27,6 +28,7 @@ from langchain_core.messages.tool import (
)
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.usage import _dict_int_op
class InputTokenDetails(TypedDict, total=False):
@@ -432,17 +434,9 @@ def add_ai_message_chunks(
# Token usage
if left.usage_metadata or any(o.usage_metadata is not None for o in others):
usage_metadata_: UsageMetadata = left.usage_metadata or UsageMetadata(
input_tokens=0, output_tokens=0, total_tokens=0
)
usage_metadata: Optional[UsageMetadata] = left.usage_metadata
for other in others:
if other.usage_metadata is not None:
usage_metadata_["input_tokens"] += other.usage_metadata["input_tokens"]
usage_metadata_["output_tokens"] += other.usage_metadata[
"output_tokens"
]
usage_metadata_["total_tokens"] += other.usage_metadata["total_tokens"]
usage_metadata: Optional[UsageMetadata] = usage_metadata_
usage_metadata = add_usage(usage_metadata, other.usage_metadata)
else:
usage_metadata = None
@@ -455,3 +449,115 @@ def add_ai_message_chunks(
usage_metadata=usage_metadata,
id=left.id,
)
def add_usage(
left: Optional[UsageMetadata], right: Optional[UsageMetadata]
) -> UsageMetadata:
"""Recursively add two UsageMetadata objects.
Example:
.. code-block:: python
from langchain_core.messages.ai import add_usage
left = UsageMetadata(
input_tokens=5,
output_tokens=0,
total_tokens=5,
input_token_details=InputTokenDetails(cache_read=3)
)
right = UsageMetadata(
input_tokens=0,
output_tokens=10,
total_tokens=10,
output_token_details=OutputTokenDetails(reasoning=4)
)
add_usage(left, right)
results in
.. code-block:: python
UsageMetadata(
input_tokens=5,
output_tokens=10,
total_tokens=15,
input_token_details=InputTokenDetails(cache_read=3),
output_token_details=OutputTokenDetails(reasoning=4)
)
"""
if not (left or right):
return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
if not (left and right):
return cast(UsageMetadata, left or right)
return UsageMetadata(
**cast(
UsageMetadata,
_dict_int_op(
cast(dict, left),
cast(dict, right),
operator.add,
),
)
)
def subtract_usage(
left: Optional[UsageMetadata], right: Optional[UsageMetadata]
) -> UsageMetadata:
"""Recursively subtract two UsageMetadata objects.
Token counts cannot be negative so the actual operation is max(left - right, 0).
Example:
.. code-block:: python
from langchain_core.messages.ai import subtract_usage
left = UsageMetadata(
input_tokens=5,
output_tokens=10,
total_tokens=15,
input_token_details=InputTokenDetails(cache_read=4)
)
right = UsageMetadata(
input_tokens=3,
output_tokens=8,
total_tokens=11,
output_token_details=OutputTokenDetails(reasoning=4)
)
subtract_usage(left, right)
results in
.. code-block:: python
UsageMetadata(
input_tokens=2,
output_tokens=2,
total_tokens=4,
input_token_details=InputTokenDetails(cache_read=4),
output_token_details=OutputTokenDetails(reasoning=0)
)
"""
if not (left or right):
return UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
if not (left and right):
return cast(UsageMetadata, left or right)
return UsageMetadata(
**cast(
UsageMetadata,
_dict_int_op(
cast(dict, left),
cast(dict, right),
(lambda le, ri: max(le - ri, 0)),
),
)
)

View File

@@ -0,0 +1,37 @@
from typing import Callable
def _dict_int_op(
left: dict,
right: dict,
op: Callable[[int, int], int],
*,
default: int = 0,
depth: int = 0,
max_depth: int = 100,
) -> dict:
if depth >= max_depth:
msg = f"{max_depth=} exceeded, unable to combine dicts."
raise ValueError(msg)
combined: dict = {}
for k in set(left).union(right):
if isinstance(left.get(k, default), int) and isinstance(
right.get(k, default), int
):
combined[k] = op(left.get(k, default), right.get(k, default))
elif isinstance(left.get(k, {}), dict) and isinstance(right.get(k, {}), dict):
combined[k] = _dict_int_op(
left.get(k, {}),
right.get(k, {}),
op,
default=default,
depth=depth + 1,
max_depth=max_depth,
)
else:
types = [type(d[k]) for d in (left, right) if k in d]
msg = (
f"Unknown value types: {types}. Only dict and int values are supported."
)
raise ValueError(msg)
return combined