From 22d1a7d7b6d6be0655bb791a6a65b60150c9afc5 Mon Sep 17 00:00:00 2001 From: ccurme Date: Wed, 26 Mar 2025 12:20:53 -0400 Subject: [PATCH] standard-tests[patch]: require model_name in response_metadata if returns_usage_metadata (#30497) We are implementing a token-counting callback handler in `langchain-core` that is intended to work with all chat models supporting usage metadata. The callback will aggregate usage metadata by model. This requires responses to include the model name in its metadata. To support this, if a model `returns_usage_metadata`, we check that it includes a string model name in its `response_metadata` in the `"model_name"` key. More context: https://github.com/langchain-ai/langchain/pull/30487 --- docs/docs/how_to/custom_chat_model.ipynb | 6 +++- .../integration_template/chat_models.py | 6 +++- .../langchain_fireworks/chat_models.py | 2 ++ .../integration_tests/test_chat_models.py | 11 +++++-- .../langchain_mistralai/chat_models.py | 7 ++++- .../integration_tests/test_chat_models.py | 15 ++++++--- .../integration_tests/chat_models.py | 31 +++++++++++++++++++ .../langchain_tests/unit_tests/chat_models.py | 3 ++ .../tests/unit_tests/custom_chat_model.py | 6 +++- 9 files changed, 75 insertions(+), 12 deletions(-) diff --git a/docs/docs/how_to/custom_chat_model.ipynb b/docs/docs/how_to/custom_chat_model.ipynb index 36ff587e11e..b8c8d7f0067 100644 --- a/docs/docs/how_to/custom_chat_model.ipynb +++ b/docs/docs/how_to/custom_chat_model.ipynb @@ -247,6 +247,7 @@ " additional_kwargs={}, # Used to add additional payload to the message\n", " response_metadata={ # Use for response metadata\n", " \"time_in_seconds\": 3,\n", + " \"model_name\": self.model_name,\n", " },\n", " usage_metadata={\n", " \"input_tokens\": ct_input_tokens,\n", @@ -309,7 +310,10 @@ "\n", " # Let's add some other information (e.g., response metadata)\n", " chunk = ChatGenerationChunk(\n", - " message=AIMessageChunk(content=\"\", response_metadata={\"time_in_sec\": 3})\n", + " message=AIMessageChunk(\n", + " content=\"\",\n", + " response_metadata={\"time_in_sec\": 3, \"model_name\": self.model_name},\n", + " )\n", " )\n", " if run_manager:\n", " # This is optional in newer versions of LangChain\n", diff --git a/libs/cli/langchain_cli/integration_template/integration_template/chat_models.py b/libs/cli/langchain_cli/integration_template/integration_template/chat_models.py index 3de9b63179f..9703b50358a 100644 --- a/libs/cli/langchain_cli/integration_template/integration_template/chat_models.py +++ b/libs/cli/langchain_cli/integration_template/integration_template/chat_models.py @@ -329,6 +329,7 @@ class Chat__ModuleName__(BaseChatModel): additional_kwargs={}, # Used to add additional payload to the message response_metadata={ # Use for response metadata "time_in_seconds": 3, + "model_name": self.model_name, }, usage_metadata={ "input_tokens": ct_input_tokens, @@ -391,7 +392,10 @@ class Chat__ModuleName__(BaseChatModel): # Let's add some other information (e.g., response metadata) chunk = ChatGenerationChunk( - message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3}) + message=AIMessageChunk( + content="", + response_metadata={"time_in_sec": 3, "model_name": self.model_name}, + ) ) if run_manager: # This is optional in newer versions of LangChain diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 3f776456559..e2953eab7fe 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -471,6 +471,7 @@ class ChatFireworks(BaseChatModel): generation_info = {} if finish_reason := choice.get("finish_reason"): generation_info["finish_reason"] = finish_reason + generation_info["model_name"] = self.model_name logprobs = choice.get("logprobs") if logprobs: generation_info["logprobs"] = logprobs @@ -565,6 +566,7 @@ class ChatFireworks(BaseChatModel): generation_info = {} if finish_reason := choice.get("finish_reason"): generation_info["finish_reason"] = finish_reason + generation_info["model_name"] = self.model_name logprobs = choice.get("logprobs") if logprobs: generation_info["logprobs"] = logprobs diff --git a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py index ecaa2ebca8a..6a019bd38b7 100644 --- a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py +++ b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py @@ -98,16 +98,19 @@ async def test_astream() -> None: full: Optional[BaseMessageChunk] = None chunks_with_token_counts = 0 + chunks_with_response_metadata = 0 async for token in llm.astream("I'm Pickle Rick"): assert isinstance(token, AIMessageChunk) assert isinstance(token.content, str) full = token if full is None else full + token if token.usage_metadata is not None: chunks_with_token_counts += 1 - if chunks_with_token_counts != 1: + if token.response_metadata: + chunks_with_response_metadata += 1 + if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1: raise AssertionError( - "Expected exactly one chunk with token counts. " - "AIMessageChunk aggregation adds counts. Check that " + "Expected exactly one chunk with token counts or response_metadata. " + "AIMessageChunk aggregation adds / appends counts and metadata. Check that " "this is behaving properly." ) assert isinstance(full, AIMessageChunk) @@ -118,6 +121,8 @@ async def test_astream() -> None: full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"] == full.usage_metadata["total_tokens"] ) + assert isinstance(full.response_metadata["model_name"], str) + assert full.response_metadata["model_name"] async def test_abatch() -> None: diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index a7deb8b5471..7cdac2bcab3 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -236,13 +236,15 @@ async def acompletion_with_retry( def _convert_chunk_to_message_chunk( chunk: Dict, default_class: Type[BaseMessageChunk] ) -> BaseMessageChunk: - _delta = chunk["choices"][0]["delta"] + _choice = chunk["choices"][0] + _delta = _choice["delta"] role = _delta.get("role") content = _delta.get("content") or "" if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) elif role == "assistant" or default_class == AIMessageChunk: additional_kwargs: Dict = {} + response_metadata = {} if raw_tool_calls := _delta.get("tool_calls"): additional_kwargs["tool_calls"] = raw_tool_calls try: @@ -272,11 +274,14 @@ def _convert_chunk_to_message_chunk( } else: usage_metadata = None + if _choice.get("finish_reason") is not None: + response_metadata["model_name"] = chunk.get("model") return AIMessageChunk( content=content, additional_kwargs=additional_kwargs, tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] usage_metadata=usage_metadata, # type: ignore[arg-type] + response_metadata=response_metadata, ) elif role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content) diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py index d3592ef32fd..8bec346d29a 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py @@ -20,7 +20,7 @@ def test_stream() -> None: """Test streaming tokens from ChatMistralAI.""" llm = ChatMistralAI() - for token in llm.stream("I'm Pickle Rick"): + for token in llm.stream("Hello"): assert isinstance(token.content, str) @@ -30,16 +30,19 @@ async def test_astream() -> None: full: Optional[BaseMessageChunk] = None chunks_with_token_counts = 0 - async for token in llm.astream("I'm Pickle Rick"): + chunks_with_response_metadata = 0 + async for token in llm.astream("Hello"): assert isinstance(token, AIMessageChunk) assert isinstance(token.content, str) full = token if full is None else full + token if token.usage_metadata is not None: chunks_with_token_counts += 1 - if chunks_with_token_counts != 1: + if token.response_metadata: + chunks_with_response_metadata += 1 + if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1: raise AssertionError( - "Expected exactly one chunk with token counts. " - "AIMessageChunk aggregation adds counts. Check that " + "Expected exactly one chunk with token counts or response_metadata. " + "AIMessageChunk aggregation adds / appends counts and metadata. Check that " "this is behaving properly." ) assert isinstance(full, AIMessageChunk) @@ -50,6 +53,8 @@ async def test_astream() -> None: full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"] == full.usage_metadata["total_tokens"] ) + assert isinstance(full.response_metadata["model_name"], str) + assert full.response_metadata["model_name"] async def test_abatch() -> None: diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py index 7041f0c9f38..b5e294ffc8d 100644 --- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py @@ -337,6 +337,9 @@ class ChatModelIntegrationTests(ChatModelTests): def returns_usage_metadata(self) -> bool: return False + Models supporting ``usage_metadata`` should also return the name of the + underlying model in the ``response_metadata`` of the AIMessage. + .. dropdown:: supports_anthropic_inputs Boolean property indicating whether the chat model supports Anthropic-style @@ -669,6 +672,11 @@ class ChatModelIntegrationTests(ChatModelTests): This test is optional and should be skipped if the model does not return usage metadata (see Configuration below). + .. versionchanged:: 0.3.17 + + Additionally check for the presence of `model_name` in the response + metadata, which is needed for usage tracking in callback handlers. + .. dropdown:: Configuration By default, this test is run. @@ -739,6 +747,9 @@ class ChatModelIntegrationTests(ChatModelTests): ) )] ) + + Check also that the response includes a ``"model_name"`` key in its + ``usage_metadata``. """ if not self.returns_usage_metadata: pytest.skip("Not implemented.") @@ -750,6 +761,12 @@ class ChatModelIntegrationTests(ChatModelTests): assert isinstance(result.usage_metadata["output_tokens"], int) assert isinstance(result.usage_metadata["total_tokens"], int) + # Check model_name is in response_metadata + # Needed for langchain_core.callbacks.usage + model_name = result.response_metadata.get("model_name") + assert isinstance(model_name, str) + assert model_name + if "audio_input" in self.supported_usage_metadata_details["invoke"]: msg = self.invoke_with_audio_input() assert msg.usage_metadata is not None @@ -809,6 +826,11 @@ class ChatModelIntegrationTests(ChatModelTests): """ Test to verify that the model returns correct usage metadata in streaming mode. + .. versionchanged:: 0.3.17 + + Additionally check for the presence of `model_name` in the response + metadata, which is needed for usage tracking in callback handlers. + .. dropdown:: Configuration By default, this test is run. @@ -891,6 +913,9 @@ class ChatModelIntegrationTests(ChatModelTests): ) )] ) + + Check also that the aggregated response includes a ``"model_name"`` key + in its ``usage_metadata``. """ if not self.returns_usage_metadata: pytest.skip("Not implemented.") @@ -915,6 +940,12 @@ class ChatModelIntegrationTests(ChatModelTests): assert isinstance(full.usage_metadata["output_tokens"], int) assert isinstance(full.usage_metadata["total_tokens"], int) + # Check model_name is in response_metadata + # Needed for langchain_core.callbacks.usage + model_name = full.response_metadata.get("model_name") + assert isinstance(model_name, str) + assert model_name + if "audio_input" in self.supported_usage_metadata_details["stream"]: msg = self.invoke_with_audio_input(stream=True) assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int) # type: ignore[index] diff --git a/libs/standard-tests/langchain_tests/unit_tests/chat_models.py b/libs/standard-tests/langchain_tests/unit_tests/chat_models.py index a470a9b59d5..beec0b98cb1 100644 --- a/libs/standard-tests/langchain_tests/unit_tests/chat_models.py +++ b/libs/standard-tests/langchain_tests/unit_tests/chat_models.py @@ -412,6 +412,9 @@ class ChatModelUnitTests(ChatModelTests): def returns_usage_metadata(self) -> bool: return False + Models supporting ``usage_metadata`` should also return the name of the + underlying model in the ``response_metadata`` of the AIMessage. + .. dropdown:: supports_anthropic_inputs Boolean property indicating whether the chat model supports Anthropic-style diff --git a/libs/standard-tests/tests/unit_tests/custom_chat_model.py b/libs/standard-tests/tests/unit_tests/custom_chat_model.py index 1791138cf35..30135883469 100644 --- a/libs/standard-tests/tests/unit_tests/custom_chat_model.py +++ b/libs/standard-tests/tests/unit_tests/custom_chat_model.py @@ -76,6 +76,7 @@ class ChatParrotLink(BaseChatModel): additional_kwargs={}, # Used to add additional payload to the message response_metadata={ # Use for response metadata "time_in_seconds": 3, + "model_name": self.model_name, }, usage_metadata={ "input_tokens": ct_input_tokens, @@ -138,7 +139,10 @@ class ChatParrotLink(BaseChatModel): # Let's add some other information (e.g., response metadata) chunk = ChatGenerationChunk( - message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3}) + message=AIMessageChunk( + content="", + response_metadata={"time_in_sec": 3, "model_name": self.model_name}, + ) ) if run_manager: # This is optional in newer versions of LangChain