standard-tests[patch]: require model_name in response_metadata if returns_usage_metadata (#30497)

We are implementing a token-counting callback handler in
`langchain-core` that is intended to work with all chat models
supporting usage metadata. The callback will aggregate usage metadata by
model. This requires each response to include the model name in its
response metadata.

To support this, if a model `returns_usage_metadata`, we check that it
includes a string model name in its `response_metadata` under the
`"model_name"` key.

More context: https://github.com/langchain-ai/langchain/pull/30487
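
For illustration, a minimal sketch of the kind of per-model aggregation this enables (not the actual `langchain-core` implementation; the handler name and aggregation details here are assumptions):

```python
# Illustrative sketch only -- not the langchain-core implementation.
# It shows why "model_name" must be present in response_metadata: without it,
# usage cannot be keyed by model.
from collections import defaultdict
from typing import Any

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class UsageByModelHandler(BaseCallbackHandler):  # hypothetical name
    """Aggregate token usage keyed by the model name in response_metadata."""

    def __init__(self) -> None:
        self.usage_by_model: dict[str, dict[str, int]] = defaultdict(
            lambda: {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
        )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        for generations in response.generations:
            for generation in generations:
                # Chat generations carry the AIMessage with usage metadata.
                message = getattr(generation, "message", None)
                if message is None or message.usage_metadata is None:
                    continue
                # This is the key this PR requires chat models to populate.
                model = message.response_metadata.get("model_name", "unknown")
                for key in ("input_tokens", "output_tokens", "total_tokens"):
                    self.usage_by_model[model][key] += message.usage_metadata.get(key, 0)
```
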
ccurme, 2025-03-26 12:20:53 -04:00 (committed via GitHub)
commit 22d1a7d7b6, parent 20f82502e5
9 changed files with 75 additions and 12 deletions


@ -247,6 +247,7 @@
" additional_kwargs={}, # Used to add additional payload to the message\n",
" response_metadata={ # Use for response metadata\n",
" \"time_in_seconds\": 3,\n",
" \"model_name\": self.model_name,\n",
" },\n",
" usage_metadata={\n",
" \"input_tokens\": ct_input_tokens,\n",
@ -309,7 +310,10 @@
"\n",
" # Let's add some other information (e.g., response metadata)\n",
" chunk = ChatGenerationChunk(\n",
" message=AIMessageChunk(content=\"\", response_metadata={\"time_in_sec\": 3})\n",
" message=AIMessageChunk(\n",
" content=\"\",\n",
" response_metadata={\"time_in_sec\": 3, \"model_name\": self.model_name},\n",
" )\n",
" )\n",
" if run_manager:\n",
" # This is optional in newer versions of LangChain\n",


@ -329,6 +329,7 @@ class Chat__ModuleName__(BaseChatModel):
additional_kwargs={}, # Used to add additional payload to the message
response_metadata={ # Use for response metadata
"time_in_seconds": 3,
"model_name": self.model_name,
},
usage_metadata={
"input_tokens": ct_input_tokens,
@ -391,7 +392,10 @@ class Chat__ModuleName__(BaseChatModel):
# Let's add some other information (e.g., response metadata)
chunk = ChatGenerationChunk(
message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3})
message=AIMessageChunk(
content="",
response_metadata={"time_in_sec": 3, "model_name": self.model_name},
)
)
if run_manager:
# This is optional in newer versions of LangChain


@ -471,6 +471,7 @@ class ChatFireworks(BaseChatModel):
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
generation_info["model_name"] = self.model_name
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
@ -565,6 +566,7 @@ class ChatFireworks(BaseChatModel):
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
generation_info["model_name"] = self.model_name
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs


@ -98,16 +98,19 @@ async def test_astream() -> None:
full: Optional[BaseMessageChunk] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token, AIMessageChunk)
assert isinstance(token.content, str)
full = token if full is None else full + token
if token.usage_metadata is not None:
chunks_with_token_counts += 1
if chunks_with_token_counts != 1:
if token.response_metadata:
chunks_with_response_metadata += 1
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
raise AssertionError(
"Expected exactly one chunk with token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"Expected exactly one chunk with token counts or response_metadata. "
"AIMessageChunk aggregation adds / appends counts and metadata. Check that "
"this is behaving properly."
)
assert isinstance(full, AIMessageChunk)
@ -118,6 +121,8 @@ async def test_astream() -> None:
full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
== full.usage_metadata["total_tokens"]
)
assert isinstance(full.response_metadata["model_name"], str)
assert full.response_metadata["model_name"]
async def test_abatch() -> None:


@ -236,13 +236,15 @@ async def acompletion_with_retry(
def _convert_chunk_to_message_chunk(
chunk: Dict, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
_delta = chunk["choices"][0]["delta"]
_choice = chunk["choices"][0]
_delta = _choice["delta"]
role = _delta.get("role")
content = _delta.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: Dict = {}
response_metadata = {}
if raw_tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
@ -272,11 +274,14 @@ def _convert_chunk_to_message_chunk(
}
else:
usage_metadata = None
if _choice.get("finish_reason") is not None:
response_metadata["model_name"] = chunk.get("model")
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata=response_metadata,
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
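
For illustration, roughly why attaching `model_name` only on the chunk that carries `finish_reason` is sufficient (a hedged sketch; the model id string is a placeholder): `AIMessageChunk` addition merges `response_metadata` across chunks, so the aggregated message still carries the key while exactly one streamed chunk has non-empty `response_metadata`, which the updated tests assert.

```python
# Illustrative only: chunk values are made up; the model id is a placeholder.
from langchain_core.messages import AIMessageChunk

chunks = [
    AIMessageChunk(content="Hel"),
    AIMessageChunk(content="lo"),
    # Only the final chunk (the one with finish_reason) carries model_name.
    AIMessageChunk(content="", response_metadata={"model_name": "mistral-small-latest"}),
]

full = chunks[0]
for chunk in chunks[1:]:
    full = full + chunk  # chunk addition concatenates content and merges metadata

assert full.content == "Hello"
assert full.response_metadata["model_name"] == "mistral-small-latest"
```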


@ -20,7 +20,7 @@ def test_stream() -> None:
"""Test streaming tokens from ChatMistralAI."""
llm = ChatMistralAI()
for token in llm.stream("I'm Pickle Rick"):
for token in llm.stream("Hello"):
assert isinstance(token.content, str)
@ -30,16 +30,19 @@ async def test_astream() -> None:
full: Optional[BaseMessageChunk] = None
chunks_with_token_counts = 0
async for token in llm.astream("I'm Pickle Rick"):
chunks_with_response_metadata = 0
async for token in llm.astream("Hello"):
assert isinstance(token, AIMessageChunk)
assert isinstance(token.content, str)
full = token if full is None else full + token
if token.usage_metadata is not None:
chunks_with_token_counts += 1
if chunks_with_token_counts != 1:
if token.response_metadata:
chunks_with_response_metadata += 1
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
raise AssertionError(
"Expected exactly one chunk with token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"Expected exactly one chunk with token counts or response_metadata. "
"AIMessageChunk aggregation adds / appends counts and metadata. Check that "
"this is behaving properly."
)
assert isinstance(full, AIMessageChunk)
@ -50,6 +53,8 @@ async def test_astream() -> None:
full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
== full.usage_metadata["total_tokens"]
)
assert isinstance(full.response_metadata["model_name"], str)
assert full.response_metadata["model_name"]
async def test_abatch() -> None:


@ -337,6 +337,9 @@ class ChatModelIntegrationTests(ChatModelTests):
def returns_usage_metadata(self) -> bool:
return False
Models supporting ``usage_metadata`` should also return the name of the
underlying model in the ``response_metadata`` of the AIMessage.
.. dropdown:: supports_anthropic_inputs
Boolean property indicating whether the chat model supports Anthropic-style
@ -669,6 +672,11 @@ class ChatModelIntegrationTests(ChatModelTests):
This test is optional and should be skipped if the model does not return
usage metadata (see Configuration below).
.. versionchanged:: 0.3.17
Additionally check for the presence of `model_name` in the response
metadata, which is needed for usage tracking in callback handlers.
.. dropdown:: Configuration
By default, this test is run.
@ -739,6 +747,9 @@ class ChatModelIntegrationTests(ChatModelTests):
)
)]
)
Check also that the response includes a ``"model_name"`` key in its
``response_metadata``.
"""
if not self.returns_usage_metadata:
pytest.skip("Not implemented.")
@ -750,6 +761,12 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(result.usage_metadata["output_tokens"], int)
assert isinstance(result.usage_metadata["total_tokens"], int)
# Check model_name is in response_metadata
# Needed for langchain_core.callbacks.usage
model_name = result.response_metadata.get("model_name")
assert isinstance(model_name, str)
assert model_name
if "audio_input" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_audio_input()
assert msg.usage_metadata is not None
@ -809,6 +826,11 @@ class ChatModelIntegrationTests(ChatModelTests):
"""
Test to verify that the model returns correct usage metadata in streaming mode.
.. versionchanged:: 0.3.17
Additionally check for the presence of `model_name` in the response
metadata, which is needed for usage tracking in callback handlers.
.. dropdown:: Configuration
By default, this test is run.
@ -891,6 +913,9 @@ class ChatModelIntegrationTests(ChatModelTests):
)
)]
)
Check also that the aggregated response includes a ``"model_name"`` key
in its ``response_metadata``.
"""
if not self.returns_usage_metadata:
pytest.skip("Not implemented.")
@ -915,6 +940,12 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(full.usage_metadata["output_tokens"], int)
assert isinstance(full.usage_metadata["total_tokens"], int)
# Check model_name is in response_metadata
# Needed for langchain_core.callbacks.usage
model_name = full.response_metadata.get("model_name")
assert isinstance(model_name, str)
assert model_name
if "audio_input" in self.supported_usage_metadata_details["stream"]:
msg = self.invoke_with_audio_input(stream=True)
assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int) # type: ignore[index]
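
For reference, a hedged sketch of how a partner package picks up these stricter checks via the standard test suite (the test class and the model id below are illustrative, not part of this change):

```python
# Illustrative sketch: how an integration package opts into the standard tests.
# The model id is a placeholder; see the langchain-tests docs for real setup.
from typing import Type

from langchain_fireworks import ChatFireworks
from langchain_tests.integration_tests import ChatModelIntegrationTests


class TestChatFireworksStandard(ChatModelIntegrationTests):
    @property
    def chat_model_class(self) -> Type[ChatFireworks]:
        return ChatFireworks

    @property
    def chat_model_params(self) -> dict:
        return {"model": "accounts/fireworks/models/llama-v3p1-70b-instruct"}

    # returns_usage_metadata defaults to True, so test_usage_metadata and
    # test_usage_metadata_streaming now also require a string "model_name"
    # in response_metadata, as added above.
```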


@ -412,6 +412,9 @@ class ChatModelUnitTests(ChatModelTests):
def returns_usage_metadata(self) -> bool:
return False
Models supporting ``usage_metadata`` should also return the name of the
underlying model in the ``response_metadata`` of the AIMessage.
.. dropdown:: supports_anthropic_inputs
Boolean property indicating whether the chat model supports Anthropic-style


@ -76,6 +76,7 @@ class ChatParrotLink(BaseChatModel):
additional_kwargs={}, # Used to add additional payload to the message
response_metadata={ # Use for response metadata
"time_in_seconds": 3,
"model_name": self.model_name,
},
usage_metadata={
"input_tokens": ct_input_tokens,
@ -138,7 +139,10 @@ class ChatParrotLink(BaseChatModel):
# Let's add some other information (e.g., response metadata)
chunk = ChatGenerationChunk(
message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3})
message=AIMessageChunk(
content="",
response_metadata={"time_in_sec": 3, "model_name": self.model_name},
)
)
if run_manager:
# This is optional in newer versions of LangChain
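
Finally, an illustrative end-to-end check of what the new requirement buys us, assuming the `ChatParrotLink` example model defined above (its `model` / `parrot_buffer_length` parameters follow the docs example and are assumptions here):

```python
# Illustrative only: assumes the ChatParrotLink example model from the docs,
# with its "model" (aliased to model_name) and parrot_buffer_length fields.
llm = ChatParrotLink(model="bird-brain-001", parrot_buffer_length=3)

result = llm.invoke("Hello!")
assert result.usage_metadata is not None  # token counts are reported
assert result.response_metadata["model_name"] == "bird-brain-001"  # required by this PR
```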