diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py
index 240c1ae2d0b..fe61870bfef 100644
--- a/libs/partners/groq/langchain_groq/chat_models.py
+++ b/libs/partners/groq/langchain_groq/chat_models.py
@@ -107,6 +107,12 @@ class ChatGroq(BaseChatModel):
             Sampling temperature. Ranges from 0.0 to 1.0.
         max_tokens: Optional[int]
             Max number of tokens to generate.
+        reasoning_format: Optional[Literal["parsed", "raw", "hidden"]]
+            The format for reasoning output.
+
+            - ``parsed``: Separates reasoning into a dedicated field while keeping the response concise.
+            - ``raw``: Includes reasoning within think tags in the content.
+            - ``hidden``: Returns only the final answer.
         model_kwargs: Dict[str, Any]
             Holds any model parameters valid for create call
             not explicitly specified.
@@ -292,7 +298,7 @@ class ChatGroq(BaseChatModel):
          'system_fingerprint': 'fp_c5f20b5bb1',
          'finish_reason': 'stop',
          'logprobs': None}
-    """
+    """  # noqa: E501
 
     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
@@ -302,6 +308,13 @@ class ChatGroq(BaseChatModel):
     """What sampling temperature to use."""
     stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
+    reasoning_format: Optional[Literal["parsed", "raw", "hidden"]] = None
+    """The format for reasoning output.
+
+    - ``parsed``: Separates reasoning into a dedicated field while keeping the response concise.
+    - ``raw``: Includes reasoning within think tags in the content.
+    - ``hidden``: Returns only the final answer.
+    """  # noqa: E501
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
     groq_api_key: Optional[SecretStr] = Field(
@@ -606,6 +619,7 @@ class ChatGroq(BaseChatModel):
             "n": self.n,
             "temperature": self.temperature,
             "stop": self.stop,
+            "reasoning_format": self.reasoning_format,
             **self.model_kwargs,
         }
         if self.max_tokens is not None:
@@ -1153,6 +1167,8 @@ def _convert_chunk_to_message_chunk(
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
     elif role == "assistant" or default_class == AIMessageChunk:
+        if reasoning := _dict.get("reasoning"):
+            additional_kwargs["reasoning_content"] = reasoning
         if usage := (chunk.get("x_groq") or {}).get("usage"):
             input_tokens = usage.get("prompt_tokens", 0)
             output_tokens = usage.get("completion_tokens", 0)
@@ -1196,6 +1212,8 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     elif role == "assistant":
         content = _dict.get("content", "") or ""
         additional_kwargs: dict = {}
+        if reasoning := _dict.get("reasoning"):
+            additional_kwargs["reasoning_content"] = reasoning
         if function_call := _dict.get("function_call"):
             additional_kwargs["function_call"] = dict(function_call)
         tool_calls = []
diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py
index 4f115336f5c..8fe51839bb1 100644
--- a/libs/partners/groq/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py
@@ -1,7 +1,7 @@
 """Test ChatGroq chat model."""
 
 import json
-from typing import Any, Optional
+from typing import Any, Optional, cast
 
 import pytest
 from langchain_core.messages import (
@@ -212,6 +212,58 @@ async def test_agenerate_streaming() -> None:
             assert generation.text == generation.message.content
 
 
+#
+# Test reasoning output
+#
+def test_reasoning_output_invoke() -> None:
+    """Test reasoning output from ChatGroq with invoke."""
+    chat = ChatGroq(
+        model="deepseek-r1-distill-llama-70b",
+        reasoning_format="parsed",
+    )
+    message = [
+        SystemMessage(
+            content="You are a helpful assistant that translates English to French."
+        ),
+        HumanMessage(content="I love programming."),
+    ]
+    response = chat.invoke(message)
+    assert isinstance(response, AIMessage)
+    assert "reasoning_content" in response.additional_kwargs
+    assert isinstance(response.additional_kwargs["reasoning_content"], str)
+    assert len(response.additional_kwargs["reasoning_content"]) > 0
+
+
+def test_reasoning_output_stream() -> None:
+    """Test reasoning output from ChatGroq with stream."""
+    chat = ChatGroq(
+        model="deepseek-r1-distill-llama-70b",
+        reasoning_format="parsed",
+    )
+    message = [
+        SystemMessage(
+            content="You are a helpful assistant that translates English to French."
+        ),
+        HumanMessage(content="I love programming."),
+    ]
+
+    full_response: Optional[AIMessageChunk] = None
+    for token in chat.stream(message):
+        assert isinstance(token, AIMessageChunk)
+
+        if full_response is None:
+            full_response = token
+        else:
+            # Casting since adding results in a type error
+            full_response = cast(AIMessageChunk, full_response + token)
+
+    assert full_response is not None
+    assert isinstance(full_response, AIMessageChunk)
+    assert "reasoning_content" in full_response.additional_kwargs
+    assert isinstance(full_response.additional_kwargs["reasoning_content"], str)
+    assert len(full_response.additional_kwargs["reasoning_content"]) > 0
+
+
 #
 # Misc tests
 #
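For context, a minimal usage sketch of what this diff enables. It is not part of the change itself: it assumes `GROQ_API_KEY` is set in the environment, and the model name simply mirrors the integration tests above.

```python
# Sketch only: exercises the new reasoning_format option added in this diff.
# Assumes GROQ_API_KEY is exported; the model name follows the integration tests.
from langchain_groq import ChatGroq

chat = ChatGroq(
    model="deepseek-r1-distill-llama-70b",
    reasoning_format="parsed",  # "raw" keeps reasoning inline; "hidden" drops it
)

response = chat.invoke("What is 6 * 7? Think step by step.")
print(response.content)  # the final answer only
# With "parsed", the chain of thought is surfaced in a dedicated field:
print(response.additional_kwargs.get("reasoning_content"))
```

When streaming, each chunk's `reasoning` field is mapped to `reasoning_content` the same way, so summing chunks (as `test_reasoning_output_stream` does) accumulates the reasoning text in the aggregated message's `additional_kwargs`.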