From 7e62e3a137814b6813e11b602b2f78df1dec8d14 Mon Sep 17 00:00:00 2001
From: ccurme
Date: Tue, 25 Mar 2025 21:26:09 -0400
Subject: [PATCH] core[patch]: store model names on usage callback handler
 (#30487)

So we avoid mingling tokens from different models.
---
 libs/core/langchain_core/callbacks/usage.py  | 22 +++++---
 .../callbacks/test_usage_callback.py         | 55 ++++++++++++++++---
 2 files changed, 60 insertions(+), 17 deletions(-)

diff --git a/libs/core/langchain_core/callbacks/usage.py b/libs/core/langchain_core/callbacks/usage.py
index d3ada3e1f3e..dd873adaebd 100644
--- a/libs/core/langchain_core/callbacks/usage.py
+++ b/libs/core/langchain_core/callbacks/usage.py
@@ -39,7 +39,7 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
     def __init__(self) -> None:
         super().__init__()
         self._lock = threading.Lock()
-        self.usage_metadata: Optional[UsageMetadata] = None
+        self.usage_metadata: dict[str, UsageMetadata] = {}
 
     def __repr__(self) -> str:
         return str(self.usage_metadata)
@@ -51,21 +51,27 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
             generation = response.generations[0][0]
         except IndexError:
             generation = None
+
+        usage_metadata = None
+        model_name = None
         if isinstance(generation, ChatGeneration):
             try:
                 message = generation.message
                 if isinstance(message, AIMessage):
                     usage_metadata = message.usage_metadata
-                else:
-                    usage_metadata = None
+                    model_name = message.response_metadata.get("model_name")
             except AttributeError:
-                usage_metadata = None
-        else:
-            usage_metadata = None
+                pass
 
         # update shared state behind lock
-        with self._lock:
-            self.usage_metadata = add_usage(self.usage_metadata, usage_metadata)
+        if usage_metadata and model_name:
+            with self._lock:
+                if model_name not in self.usage_metadata:
+                    self.usage_metadata[model_name] = usage_metadata
+                else:
+                    self.usage_metadata[model_name] = add_usage(
+                        self.usage_metadata[model_name], usage_metadata
+                    )
 
 
 @contextmanager
diff --git a/libs/core/tests/unit_tests/callbacks/test_usage_callback.py b/libs/core/tests/unit_tests/callbacks/test_usage_callback.py
index 80bf196e446..b583faecff0 100644
--- a/libs/core/tests/unit_tests/callbacks/test_usage_callback.py
+++ b/libs/core/tests/unit_tests/callbacks/test_usage_callback.py
@@ -1,4 +1,4 @@
-from itertools import cycle
+from typing import Any
 
 from langchain_core.callbacks import (
     UsageMetadataCallbackHandler,
@@ -12,6 +12,7 @@ from langchain_core.messages.ai import (
     UsageMetadata,
     add_usage,
 )
+from langchain_core.outputs import ChatResult
 
 usage1 = UsageMetadata(
     input_tokens=1,
@@ -45,41 +46,77 @@ messages = [
 ]
 
 
+class FakeChatModelWithResponseMetadata(GenericFakeChatModel):
+    model_name: str
+
+    def _generate(self, *args: Any, **kwargs: Any) -> ChatResult:
+        result = super()._generate(*args, **kwargs)
+        result.generations[0].message.response_metadata = {
+            "model_name": self.model_name
+        }
+        return result
+
+
 def test_usage_callback() -> None:
-    llm = GenericFakeChatModel(messages=cycle(messages))
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages), model_name="test_model"
+    )
 
     # Test context manager
     with get_usage_metadata_callback() as cb:
         _ = llm.invoke("Message 1")
         _ = llm.invoke("Message 2")
         total_1_2 = add_usage(usage1, usage2)
-        assert cb.usage_metadata == total_1_2
+        assert cb.usage_metadata == {"test_model": total_1_2}
         _ = llm.invoke("Message 3")
         _ = llm.invoke("Message 4")
         total_3_4 = add_usage(usage3, usage4)
-        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)
+        assert cb.usage_metadata == {"test_model": add_usage(total_1_2, total_3_4)}
 
     # Test via config
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[:2]), model_name="test_model"
+    )
     callback = UsageMetadataCallbackHandler()
     _ = llm.batch(["Message 1", "Message 2"], config={"callbacks": [callback]})
-    assert callback.usage_metadata == total_1_2
+    assert callback.usage_metadata == {"test_model": total_1_2}
+
+    # Test multiple models
+    llm_1 = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[:2]), model_name="test_model_1"
+    )
+    llm_2 = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[2:4]), model_name="test_model_2"
+    )
+    callback = UsageMetadataCallbackHandler()
+    _ = llm_1.batch(["Message 1", "Message 2"], config={"callbacks": [callback]})
+    _ = llm_2.batch(["Message 3", "Message 4"], config={"callbacks": [callback]})
+    assert callback.usage_metadata == {
+        "test_model_1": total_1_2,
+        "test_model_2": total_3_4,
+    }
 
 
 async def test_usage_callback_async() -> None:
-    llm = GenericFakeChatModel(messages=cycle(messages))
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages), model_name="test_model"
+    )
 
     # Test context manager
     with get_usage_metadata_callback() as cb:
         _ = await llm.ainvoke("Message 1")
         _ = await llm.ainvoke("Message 2")
         total_1_2 = add_usage(usage1, usage2)
-        assert cb.usage_metadata == total_1_2
+        assert cb.usage_metadata == {"test_model": total_1_2}
         _ = await llm.ainvoke("Message 3")
         _ = await llm.ainvoke("Message 4")
         total_3_4 = add_usage(usage3, usage4)
-        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)
+        assert cb.usage_metadata == {"test_model": add_usage(total_1_2, total_3_4)}
 
     # Test via config
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[:2]), model_name="test_model"
+    )
     callback = UsageMetadataCallbackHandler()
     _ = await llm.abatch(["Message 1", "Message 2"], config={"callbacks": [callback]})
-    assert callback.usage_metadata == total_1_2
+    assert callback.usage_metadata == {"test_model": total_1_2}
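
Example of the resulting behavior, as a minimal sketch. The langchain_openai
package and the model names below are illustrative assumptions; any chat model
that sets "model_name" in its response_metadata is handled the same way.

    from langchain_core.callbacks import get_usage_metadata_callback
    from langchain_openai import ChatOpenAI  # assumed provider package

    llm_1 = ChatOpenAI(model="gpt-4o-mini")  # hypothetical model choices
    llm_2 = ChatOpenAI(model="gpt-4.1")

    with get_usage_metadata_callback() as cb:
        llm_1.invoke("Hello")
        llm_2.invoke("Hello")

    # usage_metadata is now a dict keyed by each model's reported name
    # (whatever the provider sets as response_metadata["model_name"]),
    # rather than a single merged UsageMetadata total, e.g.:
    # {"gpt-4o-mini-...": {...}, "gpt-4.1-...": {...}}
    print(cb.usage_metadata)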