diff --git a/docs/docs/how_to/custom_chat_model.ipynb b/docs/docs/how_to/custom_chat_model.ipynb
index 36ff587e11e..b8c8d7f0067 100644
--- a/docs/docs/how_to/custom_chat_model.ipynb
+++ b/docs/docs/how_to/custom_chat_model.ipynb
@@ -247,6 +247,7 @@
     "            additional_kwargs={},  # Used to add additional payload to the message\n",
     "            response_metadata={  # Use for response metadata\n",
     "                \"time_in_seconds\": 3,\n",
+    "                \"model_name\": self.model_name,\n",
     "            },\n",
     "            usage_metadata={\n",
     "                \"input_tokens\": ct_input_tokens,\n",
@@ -309,7 +310,10 @@
     "\n",
     "        # Let's add some other information (e.g., response metadata)\n",
     "        chunk = ChatGenerationChunk(\n",
-    "            message=AIMessageChunk(content=\"\", response_metadata={\"time_in_sec\": 3})\n",
+    "            message=AIMessageChunk(\n",
+    "                content=\"\",\n",
+    "                response_metadata={\"time_in_sec\": 3, \"model_name\": self.model_name},\n",
+    "            )\n",
     "        )\n",
     "        if run_manager:\n",
     "            # This is optional in newer versions of LangChain\n",
diff --git a/libs/cli/langchain_cli/integration_template/integration_template/chat_models.py b/libs/cli/langchain_cli/integration_template/integration_template/chat_models.py
index 3de9b63179f..9703b50358a 100644
--- a/libs/cli/langchain_cli/integration_template/integration_template/chat_models.py
+++ b/libs/cli/langchain_cli/integration_template/integration_template/chat_models.py
@@ -329,6 +329,7 @@ class Chat__ModuleName__(BaseChatModel):
             additional_kwargs={},  # Used to add additional payload to the message
             response_metadata={  # Use for response metadata
                 "time_in_seconds": 3,
+                "model_name": self.model_name,
             },
             usage_metadata={
                 "input_tokens": ct_input_tokens,
@@ -391,7 +392,10 @@ class Chat__ModuleName__(BaseChatModel):
 
         # Let's add some other information (e.g., response metadata)
         chunk = ChatGenerationChunk(
-            message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3})
+            message=AIMessageChunk(
+                content="",
+                response_metadata={"time_in_sec": 3, "model_name": self.model_name},
+            )
         )
         if run_manager:
             # This is optional in newer versions of LangChain
diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py
index 3f776456559..e2953eab7fe 100644
--- a/libs/partners/fireworks/langchain_fireworks/chat_models.py
+++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py
@@ -471,6 +471,7 @@ class ChatFireworks(BaseChatModel):
             generation_info = {}
             if finish_reason := choice.get("finish_reason"):
                 generation_info["finish_reason"] = finish_reason
+                generation_info["model_name"] = self.model_name
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
@@ -565,6 +566,7 @@ class ChatFireworks(BaseChatModel):
             generation_info = {}
             if finish_reason := choice.get("finish_reason"):
                 generation_info["finish_reason"] = finish_reason
+                generation_info["model_name"] = self.model_name
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
diff --git a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
index ecaa2ebca8a..6a019bd38b7 100644
--- a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
@@ -98,16 +98,19 @@ async def test_astream() -> None:
 
     full: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
+    chunks_with_response_metadata = 0
     async for token in llm.astream("I'm Pickle Rick"):
         assert isinstance(token, AIMessageChunk)
         assert isinstance(token.content, str)
         full = token if full is None else full + token
         if token.usage_metadata is not None:
             chunks_with_token_counts += 1
-    if chunks_with_token_counts != 1:
+        if token.response_metadata:
+            chunks_with_response_metadata += 1
+    if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
         raise AssertionError(
-            "Expected exactly one chunk with token counts. "
-            "AIMessageChunk aggregation adds counts. Check that "
+            "Expected exactly one chunk with token counts or response_metadata. "
+            "AIMessageChunk aggregation adds counts and appends metadata. Check that "
             "this is behaving properly."
         )
     assert isinstance(full, AIMessageChunk)
@@ -118,6 +121,8 @@ async def test_astream() -> None:
         full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
         == full.usage_metadata["total_tokens"]
     )
+    assert isinstance(full.response_metadata["model_name"], str)
+    assert full.response_metadata["model_name"]
 
 
 async def test_abatch() -> None:
diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py
index a7deb8b5471..7cdac2bcab3 100644
--- a/libs/partners/mistralai/langchain_mistralai/chat_models.py
+++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -236,13 +236,15 @@ async def acompletion_with_retry(
 def _convert_chunk_to_message_chunk(
     chunk: Dict, default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
-    _delta = chunk["choices"][0]["delta"]
+    _choice = chunk["choices"][0]
+    _delta = _choice["delta"]
     role = _delta.get("role")
     content = _delta.get("content") or ""
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
     elif role == "assistant" or default_class == AIMessageChunk:
         additional_kwargs: Dict = {}
+        response_metadata = {}
         if raw_tool_calls := _delta.get("tool_calls"):
             additional_kwargs["tool_calls"] = raw_tool_calls
             try:
@@ -272,11 +274,14 @@ def _convert_chunk_to_message_chunk(
             }
         else:
             usage_metadata = None
+        if _choice.get("finish_reason") is not None:
+            response_metadata["model_name"] = chunk.get("model")
         return AIMessageChunk(
             content=content,
             additional_kwargs=additional_kwargs,
             tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
             usage_metadata=usage_metadata,  # type: ignore[arg-type]
+            response_metadata=response_metadata,
         )
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
index d3592ef32fd..8bec346d29a 100644
--- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
@@ -20,7 +20,7 @@ def test_stream() -> None:
     """Test streaming tokens from ChatMistralAI."""
     llm = ChatMistralAI()
 
-    for token in llm.stream("I'm Pickle Rick"):
+    for token in llm.stream("Hello"):
         assert isinstance(token.content, str)
 
 
@@ -30,16 +30,19 @@ async def test_astream() -> None:
 
     full: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
-    async for token in llm.astream("I'm Pickle Rick"):
+    chunks_with_response_metadata = 0
+    async for token in llm.astream("Hello"):
         assert isinstance(token, AIMessageChunk)
         assert isinstance(token.content, str)
         full = token if full is None else full + token
         if token.usage_metadata is not None:
             chunks_with_token_counts += 1
-    if chunks_with_token_counts != 1:
+        if token.response_metadata:
+            chunks_with_response_metadata += 1
+    if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
         raise AssertionError(
-            "Expected exactly one chunk with token counts. "
-            "AIMessageChunk aggregation adds counts. Check that "
+            "Expected exactly one chunk with token counts or response_metadata. "
+            "AIMessageChunk aggregation adds counts and appends metadata. Check that "
             "this is behaving properly."
         )
     assert isinstance(full, AIMessageChunk)
@@ -50,6 +53,8 @@ async def test_astream() -> None:
         full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
         == full.usage_metadata["total_tokens"]
     )
+    assert isinstance(full.response_metadata["model_name"], str)
+    assert full.response_metadata["model_name"]
 
 
 async def test_abatch() -> None:
diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
index 7041f0c9f38..b5e294ffc8d 100644
--- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
+++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
@@ -337,6 +337,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             def returns_usage_metadata(self) -> bool:
                 return False
 
+        Models supporting ``usage_metadata`` should also return the name of the
+        underlying model in the ``response_metadata`` of the AIMessage.
+
     .. dropdown:: supports_anthropic_inputs
 
         Boolean property indicating whether the chat model supports Anthropic-style
@@ -669,6 +672,11 @@ class ChatModelIntegrationTests(ChatModelTests):
         This test is optional and should be skipped if the model does not return
         usage metadata (see Configuration below).
 
+        .. versionchanged:: 0.3.17
+
+            Additionally check for the presence of ``model_name`` in the response
+            metadata, which is needed for usage tracking in callback handlers.
+
         .. dropdown:: Configuration
 
             By default, this test is run.
@@ -739,6 +747,9 @@ class ChatModelIntegrationTests(ChatModelTests):
                         )
                     )]
                 )
+
+        Check also that the response includes a ``"model_name"`` key in its
+        ``response_metadata``.
         """
         if not self.returns_usage_metadata:
             pytest.skip("Not implemented.")
@@ -750,6 +761,12 @@ class ChatModelIntegrationTests(ChatModelTests):
         assert isinstance(result.usage_metadata["output_tokens"], int)
         assert isinstance(result.usage_metadata["total_tokens"], int)
 
+        # Check model_name is in response_metadata
+        # Needed for langchain_core.callbacks.usage
+        model_name = result.response_metadata.get("model_name")
+        assert isinstance(model_name, str)
+        assert model_name
+
         if "audio_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_audio_input()
             assert msg.usage_metadata is not None
@@ -809,6 +826,11 @@ class ChatModelIntegrationTests(ChatModelTests):
         """
         Test to verify that the model returns correct usage metadata in streaming mode.
 
+        .. versionchanged:: 0.3.17
+
+            Additionally check for the presence of ``model_name`` in the response
+            metadata, which is needed for usage tracking in callback handlers.
+
         .. dropdown:: Configuration
 
             By default, this test is run.
@@ -891,6 +913,9 @@ class ChatModelIntegrationTests(ChatModelTests):
                         )
                     )]
                 )
+
+        Check also that the aggregated response includes a ``"model_name"`` key
+        in its ``response_metadata``.
""" if not self.returns_usage_metadata: pytest.skip("Not implemented.") @@ -915,6 +940,12 @@ class ChatModelIntegrationTests(ChatModelTests): assert isinstance(full.usage_metadata["output_tokens"], int) assert isinstance(full.usage_metadata["total_tokens"], int) + # Check model_name is in response_metadata + # Needed for langchain_core.callbacks.usage + model_name = full.response_metadata.get("model_name") + assert isinstance(model_name, str) + assert model_name + if "audio_input" in self.supported_usage_metadata_details["stream"]: msg = self.invoke_with_audio_input(stream=True) assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int) # type: ignore[index] diff --git a/libs/standard-tests/langchain_tests/unit_tests/chat_models.py b/libs/standard-tests/langchain_tests/unit_tests/chat_models.py index a470a9b59d5..beec0b98cb1 100644 --- a/libs/standard-tests/langchain_tests/unit_tests/chat_models.py +++ b/libs/standard-tests/langchain_tests/unit_tests/chat_models.py @@ -412,6 +412,9 @@ class ChatModelUnitTests(ChatModelTests): def returns_usage_metadata(self) -> bool: return False + Models supporting ``usage_metadata`` should also return the name of the + underlying model in the ``response_metadata`` of the AIMessage. + .. dropdown:: supports_anthropic_inputs Boolean property indicating whether the chat model supports Anthropic-style diff --git a/libs/standard-tests/tests/unit_tests/custom_chat_model.py b/libs/standard-tests/tests/unit_tests/custom_chat_model.py index 1791138cf35..30135883469 100644 --- a/libs/standard-tests/tests/unit_tests/custom_chat_model.py +++ b/libs/standard-tests/tests/unit_tests/custom_chat_model.py @@ -76,6 +76,7 @@ class ChatParrotLink(BaseChatModel): additional_kwargs={}, # Used to add additional payload to the message response_metadata={ # Use for response metadata "time_in_seconds": 3, + "model_name": self.model_name, }, usage_metadata={ "input_tokens": ct_input_tokens, @@ -138,7 +139,10 @@ class ChatParrotLink(BaseChatModel): # Let's add some other information (e.g., response metadata) chunk = ChatGenerationChunk( - message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3}) + message=AIMessageChunk( + content="", + response_metadata={"time_in_sec": 3, "model_name": self.model_name}, + ) ) if run_manager: # This is optional in newer versions of LangChain