diff --git a/libs/core/langchain_core/callbacks/__init__.py b/libs/core/langchain_core/callbacks/__init__.py
index a55f41ce344..eca2b14feb5 100644
--- a/libs/core/langchain_core/callbacks/__init__.py
+++ b/libs/core/langchain_core/callbacks/__init__.py
@@ -43,6 +43,10 @@ from langchain_core.callbacks.manager import (
 )
 from langchain_core.callbacks.stdout import StdOutCallbackHandler
 from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain_core.callbacks.usage import (
+    UsageMetadataCallbackHandler,
+    get_usage_metadata_callback,
+)
 
 __all__ = [
     "dispatch_custom_event",
@@ -77,4 +81,6 @@ __all__ = [
     "StdOutCallbackHandler",
     "StreamingStdOutCallbackHandler",
     "FileCallbackHandler",
+    "UsageMetadataCallbackHandler",
+    "get_usage_metadata_callback",
 ]
diff --git a/libs/core/langchain_core/callbacks/usage.py b/libs/core/langchain_core/callbacks/usage.py
new file mode 100644
index 00000000000..519b857bb1f
--- /dev/null
+++ b/libs/core/langchain_core/callbacks/usage.py
@@ -0,0 +1,108 @@
+"""Callback Handler that tracks AIMessage.usage_metadata."""
+
+import threading
+from collections.abc import Generator
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import Any, Optional
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.messages import AIMessage
+from langchain_core.messages.ai import UsageMetadata, add_usage
+from langchain_core.outputs import ChatGeneration, LLMResult
+
+
+class UsageMetadataCallbackHandler(BaseCallbackHandler):
+    """Callback Handler that tracks AIMessage.usage_metadata.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chat_models import init_chat_model
+            from langchain_core.callbacks import UsageMetadataCallbackHandler
+
+            llm = init_chat_model(model="openai:gpt-4o-mini")
+
+            callback = UsageMetadataCallbackHandler()
+            results = llm.batch(["Hello", "Goodbye"], config={"callbacks": [callback]})
+            print(callback.usage_metadata)
+
+        .. code-block:: none
+
+            {'output_token_details': {'audio': 0, 'reasoning': 0}, 'input_tokens': 17, 'output_tokens': 31, 'total_tokens': 48, 'input_token_details': {'cache_read': 0, 'audio': 0}}
+
+    .. versionadded:: 0.3.49
+    """  # noqa: E501
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._lock = threading.Lock()
+        self.usage_metadata: Optional[UsageMetadata] = None
+
+    def __repr__(self) -> str:
+        return str(self.usage_metadata)
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Collect token usage."""
+        # Check for usage_metadata (langchain-core >= 0.2.2)
+        try:
+            generation = response.generations[0][0]
+        except IndexError:
+            generation = None
+        if isinstance(generation, ChatGeneration):
+            try:
+                message = generation.message
+                if isinstance(message, AIMessage):
+                    usage_metadata = message.usage_metadata
+                else:
+                    usage_metadata = None
+            except AttributeError:
+                usage_metadata = None
+        else:
+            usage_metadata = None
+
+        # update shared state behind lock
+        with self._lock:
+            self.usage_metadata = add_usage(self.usage_metadata, usage_metadata)
+
+
+@contextmanager
+def get_usage_metadata_callback(
+    name: str = "usage_metadata_callback",
+) -> Generator[UsageMetadataCallbackHandler, None, None]:
+    """Get context manager for tracking usage metadata across chat model calls using
+    ``AIMessage.usage_metadata``.
+
+    Args:
+        name (str): The name of the context variable. Defaults to
+            ``"usage_metadata_callback"``.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chat_models import init_chat_model
+            from langchain_core.callbacks import get_usage_metadata_callback
+
+            llm = init_chat_model(model="openai:gpt-4o-mini")
+
+            with get_usage_metadata_callback() as cb:
+                llm.invoke("Hello")
+                llm.invoke("Goodbye")
+                print(cb.usage_metadata)
+
+        .. code-block:: none
+
+            {'output_token_details': {'audio': 0, 'reasoning': 0}, 'input_tokens': 17, 'output_tokens': 31, 'total_tokens': 48, 'input_token_details': {'cache_read': 0, 'audio': 0}}
+
+    .. versionadded:: 0.3.49
+    """  # noqa: E501
+    from langchain_core.tracers.context import register_configure_hook
+
+    usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = (
+        ContextVar(name, default=None)
+    )
+    register_configure_hook(usage_metadata_callback_var, True)
+    cb = UsageMetadataCallbackHandler()
+    usage_metadata_callback_var.set(cb)
+    yield cb
+    usage_metadata_callback_var.set(None)
diff --git a/libs/core/tests/unit_tests/callbacks/test_imports.py b/libs/core/tests/unit_tests/callbacks/test_imports.py
index f10f8323253..441b899f3ba 100644
--- a/libs/core/tests/unit_tests/callbacks/test_imports.py
+++ b/libs/core/tests/unit_tests/callbacks/test_imports.py
@@ -33,6 +33,8 @@ EXPECTED_ALL = [
     "FileCallbackHandler",
     "adispatch_custom_event",
     "dispatch_custom_event",
+    "UsageMetadataCallbackHandler",
+    "get_usage_metadata_callback",
 ]
diff --git a/libs/core/tests/unit_tests/callbacks/test_usage_callback.py b/libs/core/tests/unit_tests/callbacks/test_usage_callback.py
new file mode 100644
index 00000000000..80bf196e446
--- /dev/null
+++ b/libs/core/tests/unit_tests/callbacks/test_usage_callback.py
@@ -0,0 +1,85 @@
+from itertools import cycle
+
+from langchain_core.callbacks import (
+    UsageMetadataCallbackHandler,
+    get_usage_metadata_callback,
+)
+from langchain_core.language_models import GenericFakeChatModel
+from langchain_core.messages import AIMessage
+from langchain_core.messages.ai import (
+    InputTokenDetails,
+    OutputTokenDetails,
+    UsageMetadata,
+    add_usage,
+)
+
+usage1 = UsageMetadata(
+    input_tokens=1,
+    output_tokens=2,
+    total_tokens=3,
+)
+usage2 = UsageMetadata(
+    input_tokens=4,
+    output_tokens=5,
+    total_tokens=9,
+)
+usage3 = UsageMetadata(
+    input_tokens=10,
+    output_tokens=20,
+    total_tokens=30,
+    input_token_details=InputTokenDetails(audio=5),
+    output_token_details=OutputTokenDetails(reasoning=10),
+)
+usage4 = UsageMetadata(
+    input_tokens=5,
+    output_tokens=10,
+    total_tokens=15,
+    input_token_details=InputTokenDetails(audio=3),
+    output_token_details=OutputTokenDetails(reasoning=5),
+)
+messages = [
+    AIMessage("Response 1", usage_metadata=usage1),
+    AIMessage("Response 2", usage_metadata=usage2),
+    AIMessage("Response 3", usage_metadata=usage3),
+    AIMessage("Response 4", usage_metadata=usage4),
+]
+
+
+def test_usage_callback() -> None:
+    llm = GenericFakeChatModel(messages=cycle(messages))
+
+    # Test context manager
+    with get_usage_metadata_callback() as cb:
+        _ = llm.invoke("Message 1")
+        _ = llm.invoke("Message 2")
+        total_1_2 = add_usage(usage1, usage2)
+        assert cb.usage_metadata == total_1_2
+        _ = llm.invoke("Message 3")
+        _ = llm.invoke("Message 4")
+        total_3_4 = add_usage(usage3, usage4)
+        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)
+
+    # Test via config
+    callback = UsageMetadataCallbackHandler()
+    _ = llm.batch(["Message 1", "Message 2"], config={"callbacks": [callback]})
+    assert callback.usage_metadata == total_1_2
+
+
+async def test_usage_callback_async() -> None:
+    llm = GenericFakeChatModel(messages=cycle(messages))
+
+    # Test context manager
+    with get_usage_metadata_callback() as cb:
+        _ = await llm.ainvoke("Message 1")
+        _ = await llm.ainvoke("Message 2")
+        total_1_2 = add_usage(usage1, usage2)
+        assert cb.usage_metadata == total_1_2
+        _ = await llm.ainvoke("Message 3")
+        _ = await llm.ainvoke("Message 4")
+        total_3_4 = add_usage(usage3, usage4)
+        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)
+
+    # Test via config
+    callback = UsageMetadataCallbackHandler()
+    _ = await llm.abatch(["Message 1", "Message 2"], config={"callbacks": [callback]})
+    assert callback.usage_metadata == total_1_2
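
A note on the accumulation semantics in `on_llm_end`: the handler folds each response's usage into `self.usage_metadata` via `add_usage`, which sums counts field by field, including the nested token-detail dicts, and treats `None` as an empty accumulator. A minimal sketch of that behavior, reusing the `usage3`/`usage4` fixture values from `test_usage_callback.py` above:

```python
from langchain_core.messages.ai import (
    InputTokenDetails,
    OutputTokenDetails,
    UsageMetadata,
    add_usage,
)

merged = add_usage(
    UsageMetadata(
        input_tokens=10,
        output_tokens=20,
        total_tokens=30,
        input_token_details=InputTokenDetails(audio=5),
        output_token_details=OutputTokenDetails(reasoning=10),
    ),
    UsageMetadata(
        input_tokens=5,
        output_tokens=10,
        total_tokens=15,
        input_token_details=InputTokenDetails(audio=3),
        output_token_details=OutputTokenDetails(reasoning=5),
    ),
)
# Top-level counts and nested detail dicts both sum:
# input_tokens=15, output_tokens=30, total_tokens=45,
# input_token_details={'audio': 8}, output_token_details={'reasoning': 15}
print(merged)

# None acts as an empty accumulator, which is why the handler can start
# with usage_metadata = None and fold each response into it.
assert add_usage(None, merged) == merged
```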
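The context-manager path needs no explicit `config={"callbacks": [...]}`: `register_configure_hook` arranges for any handler stored in the context var to be added to every callback manager configured inside the `with` block, which is what the unit tests rely on. A minimal offline sketch of that, using `GenericFakeChatModel` (as the tests do) so it runs without API keys; the fixture message and token counts here are illustrative:

```python
from itertools import cycle

from langchain_core.callbacks import get_usage_metadata_callback
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import UsageMetadata

# Fake model that attaches fixed usage metadata to every response.
llm = GenericFakeChatModel(
    messages=cycle(
        [
            AIMessage(
                "Hi",
                usage_metadata=UsageMetadata(
                    input_tokens=1, output_tokens=2, total_tokens=3
                ),
            )
        ]
    )
)

with get_usage_metadata_callback() as cb:
    # No callbacks passed explicitly; the configure hook injects `cb`.
    llm.invoke("Hello")
    llm.invoke("Goodbye")

# Usage from both calls, summed by add_usage, e.g.:
# {'input_tokens': 2, 'output_tokens': 4, 'total_tokens': 6}
print(cb.usage_metadata)
```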