core[patch]: store model names on usage callback handler (#30487)

So we avoid mingling tokens from different models.
ccurme 2025-03-25 21:26:09 -04:00 committed by GitHub
parent 32827765bf
commit 7e62e3a137
2 changed files with 60 additions and 17 deletions


@@ -39,7 +39,7 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
     def __init__(self) -> None:
         super().__init__()
         self._lock = threading.Lock()
-        self.usage_metadata: Optional[UsageMetadata] = None
+        self.usage_metadata: dict[str, UsageMetadata] = {}
 
     def __repr__(self) -> str:
         return str(self.usage_metadata)
@@ -51,21 +51,27 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
             generation = response.generations[0][0]
         except IndexError:
             generation = None
+
+        usage_metadata = None
+        model_name = None
         if isinstance(generation, ChatGeneration):
             try:
                 message = generation.message
                 if isinstance(message, AIMessage):
                     usage_metadata = message.usage_metadata
-                else:
-                    usage_metadata = None
+                    model_name = message.response_metadata.get("model_name")
             except AttributeError:
-                usage_metadata = None
-        else:
-            usage_metadata = None
+                pass
 
         # update shared state behind lock
-        with self._lock:
-            self.usage_metadata = add_usage(self.usage_metadata, usage_metadata)
+        if usage_metadata and model_name:
+            with self._lock:
+                if model_name not in self.usage_metadata:
+                    self.usage_metadata[model_name] = usage_metadata
+                else:
+                    self.usage_metadata[model_name] = add_usage(
+                        self.usage_metadata[model_name], usage_metadata
+                    )
 
 
 @contextmanager
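In effect, the handler's `usage_metadata` attribute changes from a single merged `UsageMetadata` (or `None`) to a `dict[str, UsageMetadata]` keyed by the model name reported in each response's `response_metadata`. A minimal sketch of the new behavior, driving the callback directly instead of through a real chat model (the `fake_result` helper and the model names are illustrative, not part of this change):

from langchain_core.callbacks import UsageMetadataCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult


def fake_result(model_name: str, input_tokens: int, output_tokens: int) -> LLMResult:
    # Build the minimal LLMResult shape the handler inspects: a ChatGeneration
    # whose AIMessage carries usage_metadata plus a model_name in response_metadata.
    message = AIMessage(
        content="",
        usage_metadata={
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
        },
        response_metadata={"model_name": model_name},
    )
    return LLMResult(generations=[[ChatGeneration(message=message)]])


callback = UsageMetadataCallbackHandler()
# Invoke the hook directly just to exercise the handler, bypassing run bookkeeping.
callback.on_llm_end(fake_result("model_a", 5, 10))
callback.on_llm_end(fake_result("model_a", 5, 10))
callback.on_llm_end(fake_result("model_b", 2, 3))

# Tokens are now tallied per model instead of merged into one running total:
# {'model_a': {'input_tokens': 10, 'output_tokens': 20, 'total_tokens': 30},
#  'model_b': {'input_tokens': 2, 'output_tokens': 3, 'total_tokens': 5}}
print(callback.usage_metadata)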


@@ -1,4 +1,4 @@
-from itertools import cycle
+from typing import Any
 
 from langchain_core.callbacks import (
     UsageMetadataCallbackHandler,
@@ -12,6 +12,7 @@ from langchain_core.messages.ai import (
     UsageMetadata,
     add_usage,
 )
+from langchain_core.outputs import ChatResult
 
 usage1 = UsageMetadata(
     input_tokens=1,
@@ -45,41 +46,77 @@ messages = [
 ]
 
 
+class FakeChatModelWithResponseMetadata(GenericFakeChatModel):
+    model_name: str
+
+    def _generate(self, *args: Any, **kwargs: Any) -> ChatResult:
+        result = super()._generate(*args, **kwargs)
+        result.generations[0].message.response_metadata = {
+            "model_name": self.model_name
+        }
+        return result
+
+
 def test_usage_callback() -> None:
-    llm = GenericFakeChatModel(messages=cycle(messages))
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages), model_name="test_model"
+    )
 
     # Test context manager
     with get_usage_metadata_callback() as cb:
         _ = llm.invoke("Message 1")
         _ = llm.invoke("Message 2")
         total_1_2 = add_usage(usage1, usage2)
-        assert cb.usage_metadata == total_1_2
+        assert cb.usage_metadata == {"test_model": total_1_2}
         _ = llm.invoke("Message 3")
         _ = llm.invoke("Message 4")
         total_3_4 = add_usage(usage3, usage4)
-        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)
+        assert cb.usage_metadata == {"test_model": add_usage(total_1_2, total_3_4)}
 
     # Test via config
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[:2]), model_name="test_model"
+    )
     callback = UsageMetadataCallbackHandler()
     _ = llm.batch(["Message 1", "Message 2"], config={"callbacks": [callback]})
-    assert callback.usage_metadata == total_1_2
+    assert callback.usage_metadata == {"test_model": total_1_2}
+
+    # Test multiple models
+    llm_1 = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[:2]), model_name="test_model_1"
+    )
+    llm_2 = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[2:4]), model_name="test_model_2"
+    )
+    callback = UsageMetadataCallbackHandler()
+    _ = llm_1.batch(["Message 1", "Message 2"], config={"callbacks": [callback]})
+    _ = llm_2.batch(["Message 3", "Message 4"], config={"callbacks": [callback]})
+    assert callback.usage_metadata == {
+        "test_model_1": total_1_2,
+        "test_model_2": total_3_4,
+    }
 
 
 async def test_usage_callback_async() -> None:
-    llm = GenericFakeChatModel(messages=cycle(messages))
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages), model_name="test_model"
+    )
 
     # Test context manager
     with get_usage_metadata_callback() as cb:
         _ = await llm.ainvoke("Message 1")
         _ = await llm.ainvoke("Message 2")
         total_1_2 = add_usage(usage1, usage2)
-        assert cb.usage_metadata == total_1_2
+        assert cb.usage_metadata == {"test_model": total_1_2}
         _ = await llm.ainvoke("Message 3")
         _ = await llm.ainvoke("Message 4")
         total_3_4 = add_usage(usage3, usage4)
-        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)
+        assert cb.usage_metadata == {"test_model": add_usage(total_1_2, total_3_4)}
 
     # Test via config
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[:2]), model_name="test_model"
+    )
     callback = UsageMetadataCallbackHandler()
     _ = await llm.abatch(["Message 1", "Message 2"], config={"callbacks": [callback]})
-    assert callback.usage_metadata == total_1_2
+    assert callback.usage_metadata == {"test_model": total_1_2}