core[patch]: add token counting callback handler (#30481)
Stripped-down version of [OpenAICallbackHandler](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/callbacks/openai_info.py) that just tracks `AIMessage.usage_metadata`.

```python
from langchain_core.callbacks import get_usage_metadata_callback
from langgraph.prebuilt import create_react_agent


def get_weather(location: str) -> str:
    """Get the weather at a location."""
    return "It's sunny."


tools = [get_weather]
agent = create_react_agent("openai:gpt-4o-mini", tools)

with get_usage_metadata_callback() as cb:
    result = await agent.ainvoke({"messages": "What's the weather in Boston?"})
    print(cb.usage_metadata)
```
This commit is contained in:
parent 1d2b1d8e5e
commit 4b9e2e51f3
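Before the diff itself, a quick orientation: the context manager installs the handler for every chat model call made inside the `with` block, and totals accumulate on `cb.usage_metadata`. A minimal synchronous sketch (assuming OpenAI credentials are configured; it mirrors the docstring examples added below):

```python
from langchain.chat_models import init_chat_model
from langchain_core.callbacks import get_usage_metadata_callback

llm = init_chat_model(model="openai:gpt-4o-mini")

with get_usage_metadata_callback() as cb:
    llm.invoke("Hello")
    llm.invoke("Goodbye")

# UsageMetadata is a TypedDict, so totals are read as plain dict keys.
print(cb.usage_metadata["input_tokens"], cb.usage_metadata["output_tokens"])
```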
`libs/core/langchain_core/callbacks/__init__.py`

```diff
@@ -43,6 +43,10 @@ from langchain_core.callbacks.manager import (
 )
 from langchain_core.callbacks.stdout import StdOutCallbackHandler
 from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain_core.callbacks.usage import (
+    UsageMetadataCallbackHandler,
+    get_usage_metadata_callback,
+)
 
 __all__ = [
     "dispatch_custom_event",
@@ -77,4 +81,6 @@ __all__ = [
     "StdOutCallbackHandler",
     "StreamingStdOutCallbackHandler",
     "FileCallbackHandler",
+    "UsageMetadataCallbackHandler",
+    "get_usage_metadata_callback",
 ]
```
`libs/core/langchain_core/callbacks/usage.py` (new file, 108 lines)

```python
"""Callback Handler that tracks AIMessage.usage_metadata."""

import threading
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Optional

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import UsageMetadata, add_usage
from langchain_core.outputs import ChatGeneration, LLMResult


class UsageMetadataCallbackHandler(BaseCallbackHandler):
    """Callback Handler that tracks AIMessage.usage_metadata.

    Example:
        .. code-block:: python

            from langchain.chat_models import init_chat_model
            from langchain_core.callbacks import UsageMetadataCallbackHandler

            llm = init_chat_model(model="openai:gpt-4o-mini")

            callback = UsageMetadataCallbackHandler()
            results = llm.batch(["Hello", "Goodbye"], config={"callbacks": [callback]})
            print(callback.usage_metadata)

        .. code-block:: none

            {'output_token_details': {'audio': 0, 'reasoning': 0}, 'input_tokens': 17, 'output_tokens': 31, 'total_tokens': 48, 'input_token_details': {'cache_read': 0, 'audio': 0}}

    .. versionadded:: 0.3.49
    """  # noqa: E501

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()
        self.usage_metadata: Optional[UsageMetadata] = None

    def __repr__(self) -> str:
        return str(self.usage_metadata)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collect token usage."""
        # Check for usage_metadata (langchain-core >= 0.2.2)
        try:
            generation = response.generations[0][0]
        except IndexError:
            generation = None
        if isinstance(generation, ChatGeneration):
            try:
                message = generation.message
                if isinstance(message, AIMessage):
                    usage_metadata = message.usage_metadata
                else:
                    usage_metadata = None
            except AttributeError:
                usage_metadata = None
        else:
            usage_metadata = None

        # update shared state behind lock
        with self._lock:
            self.usage_metadata = add_usage(self.usage_metadata, usage_metadata)


@contextmanager
def get_usage_metadata_callback(
    name: str = "usage_metadata_callback",
) -> Generator[UsageMetadataCallbackHandler, None, None]:
    """Get context manager for tracking usage metadata across chat model calls using
    ``AIMessage.usage_metadata``.

    Args:
        name (str): The name of the context variable. Defaults to
            ``"usage_metadata_callback"``.

    Example:
        .. code-block:: python

            from langchain.chat_models import init_chat_model
            from langchain_core.callbacks import get_usage_metadata_callback

            llm = init_chat_model(model="openai:gpt-4o-mini")

            with get_usage_metadata_callback() as cb:
                llm.invoke("Hello")
                llm.invoke("Goodbye")
                print(cb.usage_metadata)

        .. code-block:: none

            {'output_token_details': {'audio': 0, 'reasoning': 0}, 'input_tokens': 17, 'output_tokens': 31, 'total_tokens': 48, 'input_token_details': {'cache_read': 0, 'audio': 0}}

    .. versionadded:: 0.3.49
    """  # noqa: E501
    from langchain_core.tracers.context import register_configure_hook

    usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = (
        ContextVar(name, default=None)
    )
    register_configure_hook(usage_metadata_callback_var, True)
    cb = UsageMetadataCallbackHandler()
    usage_metadata_callback_var.set(cb)
    yield cb
    usage_metadata_callback_var.set(None)
```
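The accumulation above leans entirely on `add_usage` from `langchain_core.messages.ai`. As a worked example of the merge (a sketch using the fixture values from the new test file below):

```python
from langchain_core.messages.ai import UsageMetadata, add_usage

u1 = UsageMetadata(input_tokens=1, output_tokens=2, total_tokens=3)
u2 = UsageMetadata(input_tokens=4, output_tokens=5, total_tokens=9)

# Counts are summed field by field: 1 + 4 inputs, 2 + 5 outputs, 3 + 9 total.
assert add_usage(u1, u2) == UsageMetadata(input_tokens=5, output_tokens=7, total_tokens=12)

# A None accumulator is accepted, which is how the handler starts out.
assert add_usage(None, u1) == u1
```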
`libs/core/tests/unit_tests/callbacks/test_imports.py`

```diff
@@ -33,6 +33,8 @@ EXPECTED_ALL = [
     "FileCallbackHandler",
     "adispatch_custom_event",
     "dispatch_custom_event",
+    "UsageMetadataCallbackHandler",
+    "get_usage_metadata_callback",
 ]
 
 
```
`libs/core/tests/unit_tests/callbacks/test_usage_callback.py` (new file, 85 lines)

```python
from itertools import cycle

from langchain_core.callbacks import (
    UsageMetadataCallbackHandler,
    get_usage_metadata_callback,
)
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import (
    InputTokenDetails,
    OutputTokenDetails,
    UsageMetadata,
    add_usage,
)

usage1 = UsageMetadata(
    input_tokens=1,
    output_tokens=2,
    total_tokens=3,
)
usage2 = UsageMetadata(
    input_tokens=4,
    output_tokens=5,
    total_tokens=9,
)
usage3 = UsageMetadata(
    input_tokens=10,
    output_tokens=20,
    total_tokens=30,
    input_token_details=InputTokenDetails(audio=5),
    output_token_details=OutputTokenDetails(reasoning=10),
)
usage4 = UsageMetadata(
    input_tokens=5,
    output_tokens=10,
    total_tokens=15,
    input_token_details=InputTokenDetails(audio=3),
    output_token_details=OutputTokenDetails(reasoning=5),
)
messages = [
    AIMessage("Response 1", usage_metadata=usage1),
    AIMessage("Response 2", usage_metadata=usage2),
    AIMessage("Response 3", usage_metadata=usage3),
    AIMessage("Response 4", usage_metadata=usage4),
]


def test_usage_callback() -> None:
    llm = GenericFakeChatModel(messages=cycle(messages))

    # Test context manager
    with get_usage_metadata_callback() as cb:
        _ = llm.invoke("Message 1")
        _ = llm.invoke("Message 2")
        total_1_2 = add_usage(usage1, usage2)
        assert cb.usage_metadata == total_1_2
        _ = llm.invoke("Message 3")
        _ = llm.invoke("Message 4")
        total_3_4 = add_usage(usage3, usage4)
        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)

    # Test via config
    callback = UsageMetadataCallbackHandler()
    _ = llm.batch(["Message 1", "Message 2"], config={"callbacks": [callback]})
    assert callback.usage_metadata == total_1_2


async def test_usage_callback_async() -> None:
    llm = GenericFakeChatModel(messages=cycle(messages))

    # Test context manager
    with get_usage_metadata_callback() as cb:
        _ = await llm.ainvoke("Message 1")
        _ = await llm.ainvoke("Message 2")
        total_1_2 = add_usage(usage1, usage2)
        assert cb.usage_metadata == total_1_2
        _ = await llm.ainvoke("Message 3")
        _ = await llm.ainvoke("Message 4")
        total_3_4 = add_usage(usage3, usage4)
        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)

    # Test via config
    callback = UsageMetadataCallbackHandler()
    _ = await llm.abatch(["Message 1", "Message 2"], config={"callbacks": [callback]})
    assert callback.usage_metadata == total_1_2
```
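One design note on the tests above: `on_llm_end` merges into shared state behind a `threading.Lock`, because `.batch()` can deliver callbacks from several worker threads at once. A hypothetical sketch (not part of the diff) of sharing one handler across threads, reusing `GenericFakeChatModel` as the tests do:

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import cycle

from langchain_core.callbacks import UsageMetadataCallbackHandler
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import UsageMetadata

usage = UsageMetadata(input_tokens=1, output_tokens=2, total_tokens=3)
llm = GenericFakeChatModel(messages=cycle([AIMessage("ok", usage_metadata=usage)]))
cb = UsageMetadataCallbackHandler()

with ThreadPoolExecutor(max_workers=4) as pool:
    # Four concurrent invocations share the same handler instance.
    list(pool.map(lambda p: llm.invoke(p, config={"callbacks": [cb]}), ["a", "b", "c", "d"]))

# The lock prevents lost updates; expect input=4, output=8, total=12 here.
print(cb.usage_metadata)
```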