core[patch]: add token counting callback handler (#30481)
Stripped-down version of [OpenAICallbackHandler](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/callbacks/openai_info.py) that just tracks `AIMessage.usage_metadata`.

```python
from langchain_core.callbacks import get_usage_metadata_callback
from langgraph.prebuilt import create_react_agent


def get_weather(location: str) -> str:
    """Get the weather at a location."""
    return "It's sunny."


tools = [get_weather]
agent = create_react_agent("openai:gpt-4o-mini", tools)

with get_usage_metadata_callback() as cb:
    result = await agent.ainvoke({"messages": "What's the weather in Boston?"})
    print(cb.usage_metadata)
```
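A note on running the snippet: it uses top-level `await`, so it works as-is in a notebook or async REPL, but a plain script needs an event loop. A minimal sketch of the same flow wrapped for script use (assuming `langgraph` is installed and OpenAI credentials are configured):

```python
import asyncio

from langchain_core.callbacks import get_usage_metadata_callback
from langgraph.prebuilt import create_react_agent


def get_weather(location: str) -> str:
    """Get the weather at a location."""
    return "It's sunny."


async def main() -> None:
    # Build the agent and run it inside the usage-tracking context.
    agent = create_react_agent("openai:gpt-4o-mini", [get_weather])
    with get_usage_metadata_callback() as cb:
        await agent.ainvoke({"messages": "What's the weather in Boston?"})
    # Aggregated token counts across every model call the agent made.
    print(cb.usage_metadata)


asyncio.run(main())
```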
parent 1d2b1d8e5e
commit 4b9e2e51f3
```diff
@@ -43,6 +43,10 @@ from langchain_core.callbacks.manager import (
 )
 from langchain_core.callbacks.stdout import StdOutCallbackHandler
 from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain_core.callbacks.usage import (
+    UsageMetadataCallbackHandler,
+    get_usage_metadata_callback,
+)
 
 __all__ = [
     "dispatch_custom_event",
@@ -77,4 +81,6 @@ __all__ = [
     "StdOutCallbackHandler",
     "StreamingStdOutCallbackHandler",
     "FileCallbackHandler",
+    "UsageMetadataCallbackHandler",
+    "get_usage_metadata_callback",
 ]
```
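With these re-exports in place, both new names are importable from the package root. A quick sanity check:

```python
# Both entry points added by this commit are exposed on langchain_core.callbacks.
from langchain_core.callbacks import (
    UsageMetadataCallbackHandler,
    get_usage_metadata_callback,
)

handler = UsageMetadataCallbackHandler()
print(handler)  # __repr__ shows the accumulated usage_metadata: None until a call completes
```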
libs/core/langchain_core/callbacks/usage.py (new file, 108 lines)

```python
"""Callback Handler that tracks AIMessage.usage_metadata."""

import threading
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Optional

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import UsageMetadata, add_usage
from langchain_core.outputs import ChatGeneration, LLMResult


class UsageMetadataCallbackHandler(BaseCallbackHandler):
    """Callback Handler that tracks AIMessage.usage_metadata.

    Example:
        .. code-block:: python

            from langchain.chat_models import init_chat_model
            from langchain_core.callbacks import UsageMetadataCallbackHandler

            llm = init_chat_model(model="openai:gpt-4o-mini")

            callback = UsageMetadataCallbackHandler()
            results = llm.batch(["Hello", "Goodbye"], config={"callbacks": [callback]})
            print(callback.usage_metadata)

        .. code-block:: none

            {'output_token_details': {'audio': 0, 'reasoning': 0}, 'input_tokens': 17, 'output_tokens': 31, 'total_tokens': 48, 'input_token_details': {'cache_read': 0, 'audio': 0}}

    .. versionadded:: 0.3.49
    """  # noqa: E501

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()
        self.usage_metadata: Optional[UsageMetadata] = None

    def __repr__(self) -> str:
        return str(self.usage_metadata)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collect token usage."""
        # Check for usage_metadata (langchain-core >= 0.2.2)
        try:
            generation = response.generations[0][0]
        except IndexError:
            generation = None
        if isinstance(generation, ChatGeneration):
            try:
                message = generation.message
                if isinstance(message, AIMessage):
                    usage_metadata = message.usage_metadata
                else:
                    usage_metadata = None
            except AttributeError:
                usage_metadata = None
        else:
            usage_metadata = None

        # update shared state behind lock
        with self._lock:
            self.usage_metadata = add_usage(self.usage_metadata, usage_metadata)


@contextmanager
def get_usage_metadata_callback(
    name: str = "usage_metadata_callback",
) -> Generator[UsageMetadataCallbackHandler, None, None]:
    """Get context manager for tracking usage metadata across chat model calls using
    ``AIMessage.usage_metadata``.

    Args:
        name (str): The name of the context variable. Defaults to
            ``"usage_metadata_callback"``.

    Example:
        .. code-block:: python

            from langchain.chat_models import init_chat_model
            from langchain_core.callbacks import get_usage_metadata_callback

            llm = init_chat_model(model="openai:gpt-4o-mini")

            with get_usage_metadata_callback() as cb:
                llm.invoke("Hello")
                llm.invoke("Goodbye")
                print(cb.usage_metadata)

        .. code-block:: none

            {'output_token_details': {'audio': 0, 'reasoning': 0}, 'input_tokens': 17, 'output_tokens': 31, 'total_tokens': 48, 'input_token_details': {'cache_read': 0, 'audio': 0}}

    .. versionadded:: 0.3.49
    """  # noqa: E501
    from langchain_core.tracers.context import register_configure_hook

    usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = (
        ContextVar(name, default=None)
    )
    register_configure_hook(usage_metadata_callback_var, True)
    cb = UsageMetadataCallbackHandler()
    usage_metadata_callback_var.set(cb)
    yield cb
    usage_metadata_callback_var.set(None)
```
```diff
@@ -33,6 +33,8 @@ EXPECTED_ALL = [
     "FileCallbackHandler",
     "adispatch_custom_event",
     "dispatch_custom_event",
+    "UsageMetadataCallbackHandler",
+    "get_usage_metadata_callback",
 ]
```
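The hunk above comes from the callbacks import test (its file path is not shown in the rendered diff); it pins the expected public API so new exports are deliberate. In spirit it guards an invariant like this sketch:

```python
import langchain_core.callbacks

# The import test compares module exports against EXPECTED_ALL;
# the two names added by this commit must appear in both places.
assert "UsageMetadataCallbackHandler" in langchain_core.callbacks.__all__
assert "get_usage_metadata_callback" in langchain_core.callbacks.__all__
```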
libs/core/tests/unit_tests/callbacks/test_usage_callback.py (new file, 85 lines)

```python
from itertools import cycle

from langchain_core.callbacks import (
    UsageMetadataCallbackHandler,
    get_usage_metadata_callback,
)
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import (
    InputTokenDetails,
    OutputTokenDetails,
    UsageMetadata,
    add_usage,
)

usage1 = UsageMetadata(
    input_tokens=1,
    output_tokens=2,
    total_tokens=3,
)
usage2 = UsageMetadata(
    input_tokens=4,
    output_tokens=5,
    total_tokens=9,
)
usage3 = UsageMetadata(
    input_tokens=10,
    output_tokens=20,
    total_tokens=30,
    input_token_details=InputTokenDetails(audio=5),
    output_token_details=OutputTokenDetails(reasoning=10),
)
usage4 = UsageMetadata(
    input_tokens=5,
    output_tokens=10,
    total_tokens=15,
    input_token_details=InputTokenDetails(audio=3),
    output_token_details=OutputTokenDetails(reasoning=5),
)
messages = [
    AIMessage("Response 1", usage_metadata=usage1),
    AIMessage("Response 2", usage_metadata=usage2),
    AIMessage("Response 3", usage_metadata=usage3),
    AIMessage("Response 4", usage_metadata=usage4),
]


def test_usage_callback() -> None:
    llm = GenericFakeChatModel(messages=cycle(messages))

    # Test context manager
    with get_usage_metadata_callback() as cb:
        _ = llm.invoke("Message 1")
        _ = llm.invoke("Message 2")
        total_1_2 = add_usage(usage1, usage2)
        assert cb.usage_metadata == total_1_2
        _ = llm.invoke("Message 3")
        _ = llm.invoke("Message 4")
        total_3_4 = add_usage(usage3, usage4)
        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)

    # Test via config
    callback = UsageMetadataCallbackHandler()
    _ = llm.batch(["Message 1", "Message 2"], config={"callbacks": [callback]})
    assert callback.usage_metadata == total_1_2


async def test_usage_callback_async() -> None:
    llm = GenericFakeChatModel(messages=cycle(messages))

    # Test context manager
    with get_usage_metadata_callback() as cb:
        _ = await llm.ainvoke("Message 1")
        _ = await llm.ainvoke("Message 2")
        total_1_2 = add_usage(usage1, usage2)
        assert cb.usage_metadata == total_1_2
        _ = await llm.ainvoke("Message 3")
        _ = await llm.ainvoke("Message 4")
        total_3_4 = add_usage(usage3, usage4)
        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)

    # Test via config
    callback = UsageMetadataCallbackHandler()
    _ = await llm.abatch(["Message 1", "Message 2"], config={"callbacks": [callback]})
    assert callback.usage_metadata == total_1_2
```