core[patch]: add token counting callback handler (#30481)

Stripped-down version of
[OpenAICallbackHandler](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/callbacks/openai_info.py)
that just tracks `AIMessage.usage_metadata`.

```python
from langchain_core.callbacks import get_usage_metadata_callback
from langgraph.prebuilt import create_react_agent

def get_weather(location: str) -> str:
    """Get the weather at a location."""
    return "It's sunny."

tools = [get_weather]
agent = create_react_agent("openai:gpt-4o-mini", tools)

with get_usage_metadata_callback() as cb:
    result = await agent.ainvoke({"messages": "What's the weather in Boston?"})
    print(cb.usage_metadata)
```
This commit is contained in:
ccurme 2025-03-25 18:16:39 -04:00 committed by GitHub
parent 1d2b1d8e5e
commit 4b9e2e51f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 201 additions and 0 deletions

View File

@@ -43,6 +43,10 @@ from langchain_core.callbacks.manager import (
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.callbacks.usage import (
UsageMetadataCallbackHandler,
get_usage_metadata_callback,
)
__all__ = [ __all__ = [
"dispatch_custom_event", "dispatch_custom_event",
@@ -77,4 +81,6 @@ __all__ = [
    "StdOutCallbackHandler",
    "StreamingStdOutCallbackHandler",
    "FileCallbackHandler",
    "UsageMetadataCallbackHandler",
    "get_usage_metadata_callback",
]

View File

@@ -0,0 +1,108 @@
"""Callback Handler that tracks AIMessage.usage_metadata."""
import threading
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Optional
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import UsageMetadata, add_usage
from langchain_core.outputs import ChatGeneration, LLMResult
class UsageMetadataCallbackHandler(BaseCallbackHandler):
    """Callback handler that accumulates ``AIMessage.usage_metadata``.

    Every chat-model call routed through this handler adds its reported token
    usage into :attr:`usage_metadata`, so one handler instance can total usage
    across many calls (including batched/threaded ones).

    Example:
        .. code-block:: python

            from langchain.chat_models import init_chat_model
            from langchain_core.callbacks import UsageMetadataCallbackHandler

            llm = init_chat_model(model="openai:gpt-4o-mini")
            callback = UsageMetadataCallbackHandler()
            results = llm.batch(["Hello", "Goodbye"], config={"callbacks": [callback]})
            print(callback.usage_metadata)

        .. code-block:: none

            {'output_token_details': {'audio': 0, 'reasoning': 0}, 'input_tokens': 17, 'output_tokens': 31, 'total_tokens': 48, 'input_token_details': {'cache_read': 0, 'audio': 0}}

    .. versionadded:: 0.3.49
    """  # noqa: E501

    def __init__(self) -> None:
        super().__init__()
        # Serializes updates to usage_metadata; llm.batch may invoke
        # on_llm_end concurrently from multiple worker threads.
        self._lock = threading.Lock()
        # Running total; None until the first generation reports usage.
        self.usage_metadata: Optional[UsageMetadata] = None

    def __repr__(self) -> str:
        return str(self.usage_metadata)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Fold the finished generation's token usage into the running total."""
        # usage_metadata is only populated for langchain-core >= 0.2.2.
        usage: Optional[UsageMetadata] = None
        try:
            first_generation = response.generations[0][0]
        except IndexError:
            first_generation = None
        if isinstance(first_generation, ChatGeneration):
            # Defensive: tolerate generation objects lacking expected attrs.
            try:
                message = first_generation.message
                if isinstance(message, AIMessage):
                    usage = message.usage_metadata
            except AttributeError:
                pass

        # Update shared state behind the lock.
        with self._lock:
            self.usage_metadata = add_usage(self.usage_metadata, usage)
@contextmanager
def get_usage_metadata_callback(
    name: str = "usage_metadata_callback",
) -> Generator[UsageMetadataCallbackHandler, None, None]:
    """Get context manager for tracking usage metadata across chat model calls using
    ``AIMessage.usage_metadata``.

    Registers a :class:`UsageMetadataCallbackHandler` in a context variable so that
    all chat model calls made inside the ``with`` block are tracked, and restores
    the previous value of the variable on exit.

    Args:
        name (str): The name of the context variable. Defaults to
            ``"usage_metadata_callback"``.

    Yields:
        The handler accumulating usage for calls made inside the block.

    Example:
        .. code-block:: python

            from langchain.chat_models import init_chat_model
            from langchain_core.callbacks import get_usage_metadata_callback

            llm = init_chat_model(model="openai:gpt-4o-mini")
            with get_usage_metadata_callback() as cb:
                llm.invoke("Hello")
                llm.invoke("Goodbye")
                print(cb.usage_metadata)

        .. code-block:: none

            {'output_token_details': {'audio': 0, 'reasoning': 0}, 'input_tokens': 17, 'output_tokens': 31, 'total_tokens': 48, 'input_token_details': {'cache_read': 0, 'audio': 0}}

    .. versionadded:: 0.3.49
    """  # noqa: E501
    # Imported lazily to avoid a circular import with the tracers package.
    from langchain_core.tracers.context import register_configure_hook

    usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = (
        ContextVar(name, default=None)
    )
    register_configure_hook(usage_metadata_callback_var, True)
    cb = UsageMetadataCallbackHandler()
    token = usage_metadata_callback_var.set(cb)
    try:
        yield cb
    finally:
        # reset(token) (rather than set(None)) restores any outer value and,
        # via finally, guarantees cleanup even when the with-body raises.
        usage_metadata_callback_var.reset(token)

View File

@@ -33,6 +33,8 @@ EXPECTED_ALL = [
    "FileCallbackHandler",
    "adispatch_custom_event",
    "dispatch_custom_event",
    "UsageMetadataCallbackHandler",
    "get_usage_metadata_callback",
]

View File

@@ -0,0 +1,85 @@
from itertools import cycle
from langchain_core.callbacks import (
UsageMetadataCallbackHandler,
get_usage_metadata_callback,
)
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import (
InputTokenDetails,
OutputTokenDetails,
UsageMetadata,
add_usage,
)
# Fixture usage payloads. usage1/usage2 carry only top-level token counts;
# usage3/usage4 additionally populate the input/output token-detail
# breakdowns, exercising add_usage's nested-dict merging.
usage1 = UsageMetadata(
    input_tokens=1,
    output_tokens=2,
    total_tokens=3,
)
usage2 = UsageMetadata(
    input_tokens=4,
    output_tokens=5,
    total_tokens=9,
)
usage3 = UsageMetadata(
    input_tokens=10,
    output_tokens=20,
    total_tokens=30,
    input_token_details=InputTokenDetails(audio=5),
    output_token_details=OutputTokenDetails(reasoning=10),
)
usage4 = UsageMetadata(
    input_tokens=5,
    output_tokens=10,
    total_tokens=15,
    input_token_details=InputTokenDetails(audio=3),
    output_token_details=OutputTokenDetails(reasoning=5),
)
# One fake AI response per usage fixture; the tests wrap these in cycle(), so
# consecutive model calls consume usage1..usage4 in order, repeating.
messages = [
    AIMessage("Response 1", usage_metadata=usage1),
    AIMessage("Response 2", usage_metadata=usage2),
    AIMessage("Response 3", usage_metadata=usage3),
    AIMessage("Response 4", usage_metadata=usage4),
]
def test_usage_callback() -> None:
    """Usage accumulates via the context manager and via explicit config."""
    model = GenericFakeChatModel(messages=cycle(messages))

    # Context-manager API: totals grow with every call inside the block.
    with get_usage_metadata_callback() as tracker:
        model.invoke("Message 1")
        model.invoke("Message 2")
        expected_first_pair = add_usage(usage1, usage2)
        assert tracker.usage_metadata == expected_first_pair

        model.invoke("Message 3")
        model.invoke("Message 4")
        expected_second_pair = add_usage(usage3, usage4)
        assert tracker.usage_metadata == add_usage(
            expected_first_pair, expected_second_pair
        )

    # Handler passed explicitly through the run config.
    handler = UsageMetadataCallbackHandler()
    model.batch(["Message 1", "Message 2"], config={"callbacks": [handler]})
    assert handler.usage_metadata == expected_first_pair
async def test_usage_callback_async() -> None:
    """Async mirror of test_usage_callback: ainvoke/abatch paths."""
    model = GenericFakeChatModel(messages=cycle(messages))

    # Context-manager API: totals grow with every awaited call in the block.
    with get_usage_metadata_callback() as tracker:
        await model.ainvoke("Message 1")
        await model.ainvoke("Message 2")
        expected_first_pair = add_usage(usage1, usage2)
        assert tracker.usage_metadata == expected_first_pair

        await model.ainvoke("Message 3")
        await model.ainvoke("Message 4")
        expected_second_pair = add_usage(usage3, usage4)
        assert tracker.usage_metadata == add_usage(
            expected_first_pair, expected_second_pair
        )

    # Handler passed explicitly through the run config.
    handler = UsageMetadataCallbackHandler()
    await model.abatch(["Message 1", "Message 2"], config={"callbacks": [handler]})
    assert handler.usage_metadata == expected_first_pair