diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index 560e9a5d479..82b69076466 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -541,6 +541,9 @@ class ChatGroq(BaseChatModel): generation_info = {} if finish_reason := choice.get("finish_reason"): generation_info["finish_reason"] = finish_reason + generation_info["model_name"] = self.model_name + if system_fingerprint := chunk.get("system_fingerprint"): + generation_info["system_fingerprint"] = system_fingerprint logprobs = choice.get("logprobs") if logprobs: generation_info["logprobs"] = logprobs @@ -579,6 +582,9 @@ class ChatGroq(BaseChatModel): generation_info = {} if finish_reason := choice.get("finish_reason"): generation_info["finish_reason"] = finish_reason + generation_info["model_name"] = self.model_name + if system_fingerprint := chunk.get("system_fingerprint"): + generation_info["system_fingerprint"] = system_fingerprint logprobs = choice.get("logprobs") if logprobs: generation_info["logprobs"] = logprobs diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py index 9d74ef4f2ab..4f115336f5c 100644 --- a/libs/partners/groq/tests/integration_tests/test_chat_models.py +++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py @@ -98,16 +98,19 @@ async def test_astream() -> None: full: Optional[BaseMessageChunk] = None chunks_with_token_counts = 0 + chunks_with_response_metadata = 0 async for token in chat.astream("Welcome to the Groqetship!"): assert isinstance(token, AIMessageChunk) assert isinstance(token.content, str) full = token if full is None else full + token if token.usage_metadata is not None: chunks_with_token_counts += 1 - if chunks_with_token_counts != 1: + if token.response_metadata: + chunks_with_response_metadata += 1 + if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1: raise AssertionError( - "Expected exactly one chunk with token counts. " - "AIMessageChunk aggregation adds counts. Check that " + "Expected exactly one chunk with token counts or metadata. " + "AIMessageChunk aggregation adds / appends these metadata. Check that " "this is behaving properly." ) assert isinstance(full, AIMessageChunk) @@ -118,6 +121,8 @@ async def test_astream() -> None: full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"] == full.usage_metadata["total_tokens"] ) + for expected_metadata in ["model_name", "system_fingerprint"]: + assert full.response_metadata[expected_metadata] #