diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index dc8b313fde6..7c346c24eab 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -615,6 +615,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             generation.message.response_metadata = _gen_info_and_msg_metadata(
                 generation
             )
+        if len(result.generations) == 1 and result.llm_output is not None:
+            result.generations[0].message.response_metadata = {
+                **result.llm_output,
+                **result.generations[0].message.response_metadata,
+            }
         if check_cache and llm_cache:
             llm_cache.update(prompt, llm_string, result.generations)
         return result
@@ -651,6 +656,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             generation.message.response_metadata = _gen_info_and_msg_metadata(
                 generation
             )
+        if len(result.generations) == 1 and result.llm_output is not None:
+            result.generations[0].message.response_metadata = {
+                **result.llm_output,
+                **result.generations[0].message.response_metadata,
+            }
         if check_cache and llm_cache:
             await llm_cache.aupdate(prompt, llm_string, result.generations)
         return result
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index 462e2531cf6..e40ab2e654a 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -403,35 +403,67 @@ def test_invoke() -> None:
     assert isinstance(result.content, str)
 
 
-def test_logprobs() -> None:
+def test_response_metadata() -> None:
     llm = ChatOpenAI()
     result = llm.invoke([HumanMessage(content="I'm PickleRick")], logprobs=True)
     assert result.response_metadata
+    assert all(
+        k in result.response_metadata
+        for k in (
+            "token_usage",
+            "model_name",
+            "logprobs",
+            "system_fingerprint",
+            "finish_reason",
+        )
+    )
     assert "content" in result.response_metadata["logprobs"]
 
 
-async def test_async_logprobs() -> None:
+async def test_async_response_metadata() -> None:
     llm = ChatOpenAI()
     result = await llm.ainvoke([HumanMessage(content="I'm PickleRick")], logprobs=True)
     assert result.response_metadata
+    assert all(
+        k in result.response_metadata
+        for k in (
+            "token_usage",
+            "model_name",
+            "logprobs",
+            "system_fingerprint",
+            "finish_reason",
+        )
+    )
     assert "content" in result.response_metadata["logprobs"]
 
 
-def test_logprobs_streaming() -> None:
+def test_response_metadata_streaming() -> None:
     llm = ChatOpenAI()
     full: Optional[BaseMessageChunk] = None
     for chunk in llm.stream("I'm Pickle Rick", logprobs=True):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
-    assert cast(BaseMessageChunk, full).response_metadata
+    assert all(
+        k in cast(BaseMessageChunk, full).response_metadata
+        for k in (
+            "logprobs",
+            "finish_reason",
+        )
+    )
     assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]
 
 
-async def test_async_logprobs_streaming() -> None:
+async def test_async_response_metadata_streaming() -> None:
     llm = ChatOpenAI()
     full: Optional[BaseMessageChunk] = None
     async for chunk in llm.astream("I'm Pickle Rick", logprobs=True):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
-    assert cast(BaseMessageChunk, full).response_metadata
+    assert all(
+        k in cast(BaseMessageChunk, full).response_metadata
+        for k in (
+            "logprobs",
+            "finish_reason",
+        )
+    )
     assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]
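With this change, when a chat result contains a single generation, the provider-level `llm_output` dict is merged into that generation's `message.response_metadata` (existing per-generation keys win on conflict), so usage and model info can be read directly off the returned message. A minimal sketch of the resulting behavior with `ChatOpenAI`, assuming an OpenAI API key is configured; the key names mirror the assertions in the updated tests:

```python
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()
msg = llm.invoke([HumanMessage(content="I'm Pickle Rick")], logprobs=True)

# Fields from llm_output (token usage, model name, system fingerprint) are now
# surfaced on the message's response_metadata alongside per-generation info.
print(msg.response_metadata["model_name"])
print(msg.response_metadata["token_usage"])  # e.g. {"prompt_tokens": ..., "completion_tokens": ..., ...}
print(msg.response_metadata["system_fingerprint"])
print(msg.response_metadata["finish_reason"])
assert "content" in msg.response_metadata["logprobs"]
```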