core[patch]: store model names on usage callback handler (#30487)
This avoids mingling token counts from different models.
commit 7e62e3a137
parent 32827765bf
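In effect, UsageMetadataCallbackHandler.usage_metadata changes from a single merged UsageMetadata to a dict keyed by model name. A minimal sketch of the new behavior, reusing the FakeChatModelWithResponseMetadata helper defined in the test diff below (the message content and token counts here are illustrative, not the test fixtures):

from typing import Any

from langchain_core.callbacks import get_usage_metadata_callback
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatResult


class FakeChatModelWithResponseMetadata(GenericFakeChatModel):
    """Fake chat model that reports a model name, as in the tests below."""

    model_name: str

    def _generate(self, *args: Any, **kwargs: Any) -> ChatResult:
        result = super()._generate(*args, **kwargs)
        result.generations[0].message.response_metadata = {
            "model_name": self.model_name
        }
        return result


message = AIMessage(
    content="hello",
    usage_metadata=UsageMetadata(input_tokens=1, output_tokens=2, total_tokens=3),
)
llm = FakeChatModelWithResponseMetadata(
    messages=iter([message]), model_name="test_model"
)

with get_usage_metadata_callback() as cb:
    llm.invoke("Message 1")

# Usage is now keyed by model name instead of being merged across models:
print(cb.usage_metadata)
# {'test_model': {'input_tokens': 1, 'output_tokens': 2, 'total_tokens': 3}}

Note that the handler now records usage only when both usage_metadata and a model_name are present, so responses that do not report a model name in response_metadata are skipped.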
@@ -39,7 +39,7 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
     def __init__(self) -> None:
         super().__init__()
         self._lock = threading.Lock()
-        self.usage_metadata: Optional[UsageMetadata] = None
+        self.usage_metadata: dict[str, UsageMetadata] = {}
 
     def __repr__(self) -> str:
         return str(self.usage_metadata)
@@ -51,21 +51,27 @@ class UsageMetadataCallbackHandler(BaseCallbackHandler):
             generation = response.generations[0][0]
         except IndexError:
             generation = None
+
+        usage_metadata = None
+        model_name = None
         if isinstance(generation, ChatGeneration):
             try:
                 message = generation.message
                 if isinstance(message, AIMessage):
                     usage_metadata = message.usage_metadata
-                else:
-                    usage_metadata = None
+                    model_name = message.response_metadata.get("model_name")
             except AttributeError:
-                usage_metadata = None
-        else:
-            usage_metadata = None
+                pass
 
         # update shared state behind lock
-        with self._lock:
-            self.usage_metadata = add_usage(self.usage_metadata, usage_metadata)
+        if usage_metadata and model_name:
+            with self._lock:
+                if model_name not in self.usage_metadata:
+                    self.usage_metadata[model_name] = usage_metadata
+                else:
+                    self.usage_metadata[model_name] = add_usage(
+                        self.usage_metadata[model_name], usage_metadata
+                    )
 
 
 @contextmanager
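For repeated calls that report the same model name, the handler above merges counts field-wise with add_usage while holding the lock. A quick illustration of that arithmetic (token values invented for the example); the unit-test diff that follows exercises the same behavior end to end:

from langchain_core.messages.ai import UsageMetadata, add_usage

u1 = UsageMetadata(input_tokens=1, output_tokens=2, total_tokens=3)
u2 = UsageMetadata(input_tokens=4, output_tokens=5, total_tokens=9)

# add_usage sums the token counts field-wise; the handler now applies it
# per model name while holding the lock.
print(add_usage(u1, u2))
# {'input_tokens': 5, 'output_tokens': 7, 'total_tokens': 12}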
@@ -1,4 +1,4 @@
-from itertools import cycle
+from typing import Any
 
 from langchain_core.callbacks import (
     UsageMetadataCallbackHandler,
@@ -12,6 +12,7 @@ from langchain_core.messages.ai import (
     UsageMetadata,
     add_usage,
 )
+from langchain_core.outputs import ChatResult
 
 usage1 = UsageMetadata(
     input_tokens=1,
@@ -45,41 +46,77 @@ messages = [
 ]
 
 
+class FakeChatModelWithResponseMetadata(GenericFakeChatModel):
+    model_name: str
+
+    def _generate(self, *args: Any, **kwargs: Any) -> ChatResult:
+        result = super()._generate(*args, **kwargs)
+        result.generations[0].message.response_metadata = {
+            "model_name": self.model_name
+        }
+        return result
+
+
 def test_usage_callback() -> None:
-    llm = GenericFakeChatModel(messages=cycle(messages))
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages), model_name="test_model"
+    )
 
     # Test context manager
     with get_usage_metadata_callback() as cb:
         _ = llm.invoke("Message 1")
         _ = llm.invoke("Message 2")
         total_1_2 = add_usage(usage1, usage2)
-        assert cb.usage_metadata == total_1_2
+        assert cb.usage_metadata == {"test_model": total_1_2}
         _ = llm.invoke("Message 3")
         _ = llm.invoke("Message 4")
         total_3_4 = add_usage(usage3, usage4)
-        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)
+        assert cb.usage_metadata == {"test_model": add_usage(total_1_2, total_3_4)}
 
     # Test via config
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[:2]), model_name="test_model"
+    )
     callback = UsageMetadataCallbackHandler()
     _ = llm.batch(["Message 1", "Message 2"], config={"callbacks": [callback]})
-    assert callback.usage_metadata == total_1_2
+    assert callback.usage_metadata == {"test_model": total_1_2}
+
+    # Test multiple models
+    llm_1 = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[:2]), model_name="test_model_1"
+    )
+    llm_2 = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[2:4]), model_name="test_model_2"
+    )
+    callback = UsageMetadataCallbackHandler()
+    _ = llm_1.batch(["Message 1", "Message 2"], config={"callbacks": [callback]})
+    _ = llm_2.batch(["Message 3", "Message 4"], config={"callbacks": [callback]})
+    assert callback.usage_metadata == {
+        "test_model_1": total_1_2,
+        "test_model_2": total_3_4,
+    }
 
 
 async def test_usage_callback_async() -> None:
-    llm = GenericFakeChatModel(messages=cycle(messages))
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages), model_name="test_model"
+    )
 
     # Test context manager
    with get_usage_metadata_callback() as cb:
         _ = await llm.ainvoke("Message 1")
         _ = await llm.ainvoke("Message 2")
         total_1_2 = add_usage(usage1, usage2)
-        assert cb.usage_metadata == total_1_2
+        assert cb.usage_metadata == {"test_model": total_1_2}
         _ = await llm.ainvoke("Message 3")
         _ = await llm.ainvoke("Message 4")
         total_3_4 = add_usage(usage3, usage4)
-        assert cb.usage_metadata == add_usage(total_1_2, total_3_4)
+        assert cb.usage_metadata == {"test_model": add_usage(total_1_2, total_3_4)}
 
     # Test via config
+    llm = FakeChatModelWithResponseMetadata(
+        messages=iter(messages[:2]), model_name="test_model"
+    )
     callback = UsageMetadataCallbackHandler()
     _ = await llm.abatch(["Message 1", "Message 2"], config={"callbacks": [callback]})
-    assert callback.usage_metadata == total_1_2
+    assert callback.usage_metadata == {"test_model": total_1_2}