From fb490b0c390b79ff75489b04de264077ed7ee45f Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Tue, 5 Aug 2025 15:03:38 -0400 Subject: [PATCH] feat(groq): loosen restrictions on `reasoning_effort`, inject effort in meta, update tests (#32415) --- .../groq/langchain_groq/chat_models.py | 57 ++++++-------- .../integration_tests/test_chat_models.py | 77 ++++++++++++++++++- 2 files changed, 98 insertions(+), 36 deletions(-) diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index c56da31d6aa..443692dbe6c 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -6,15 +6,7 @@ import json import warnings from collections.abc import AsyncIterator, Iterator, Mapping, Sequence from operator import itemgetter -from typing import ( - Any, - Callable, - Literal, - Optional, - TypedDict, - Union, - cast, -) +from typing import Any, Callable, Literal, Optional, TypedDict, Union, cast from langchain_core._api import deprecated from langchain_core.callbacks import ( @@ -46,10 +38,7 @@ from langchain_core.messages import ( ToolMessage, ToolMessageChunk, ) -from langchain_core.output_parsers import ( - JsonOutputParser, - PydanticOutputParser, -) +from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( JsonOutputKeyToolsParser, @@ -60,23 +49,13 @@ from langchain_core.output_parsers.openai_tools import ( from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool -from langchain_core.utils import ( - from_env, - get_pydantic_field_names, - secret_from_env, -) +from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env from
langchain_core.utils.function_calling import ( convert_to_openai_function, convert_to_openai_tool, ) from langchain_core.utils.pydantic import is_basemodel_subclass -from pydantic import ( - BaseModel, - ConfigDict, - Field, - SecretStr, - model_validator, -) +from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self from langchain_groq.version import __version__ @@ -122,7 +101,7 @@ class ChatGroq(BaseChatModel): See the `Groq documentation `__ for more - details and a list of supported reasoning models. + details and a list of supported models. model_kwargs: Dict[str, Any] Holds any model parameters valid for create call not explicitly specified. @@ -328,20 +307,15 @@ class ChatGroq(BaseChatModel): overridden in ``reasoning_effort``. See the `Groq documentation `__ - for more details and a list of supported reasoning models. + for more details and a list of supported models. """ - reasoning_effort: Optional[Literal["none", "default"]] = Field(default=None) + reasoning_effort: Optional[str] = Field(default=None) """The level of effort the model will put into reasoning. Groq will default to - enabling reasoning if left undefined. If set to ``none``, ``reasoning_format`` will - not apply and ``reasoning_content`` will not be returned. - - - ``'none'``: Disable reasoning. The model will not use any reasoning tokens when - generating a response. - - ``'default'``: Enable reasoning. + enabling reasoning if left undefined. See the `Groq documentation `__ for more - details and a list of models that support setting a reasoning effort. + details and a list of options and models that support setting a reasoning effort. 
""" model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" @@ -601,6 +575,11 @@ class ChatGroq(BaseChatModel): generation_info["system_fingerprint"] = system_fingerprint service_tier = params.get("service_tier") or self.service_tier generation_info["service_tier"] = service_tier + reasoning_effort = ( + params.get("reasoning_effort") or self.reasoning_effort + ) + if reasoning_effort: + generation_info["reasoning_effort"] = reasoning_effort logprobs = choice.get("logprobs") if logprobs: generation_info["logprobs"] = logprobs @@ -644,6 +623,11 @@ class ChatGroq(BaseChatModel): generation_info["system_fingerprint"] = system_fingerprint service_tier = params.get("service_tier") or self.service_tier generation_info["service_tier"] = service_tier + reasoning_effort = ( + params.get("reasoning_effort") or self.reasoning_effort + ) + if reasoning_effort: + generation_info["reasoning_effort"] = reasoning_effort logprobs = choice.get("logprobs") if logprobs: generation_info["logprobs"] = logprobs @@ -714,6 +698,9 @@ class ChatGroq(BaseChatModel): "system_fingerprint": response.get("system_fingerprint", ""), } llm_output["service_tier"] = params.get("service_tier") or self.service_tier + reasoning_effort = params.get("reasoning_effort") or self.reasoning_effort + if reasoning_effort: + llm_output["reasoning_effort"] = reasoning_effort return ChatResult(generations=generations, llm_output=llm_output) def _create_message_dicts( diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py index ec63ab96071..e35e1108505 100644 --- a/libs/partners/groq/tests/integration_tests/test_chat_models.py +++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py @@ -25,6 +25,8 @@ from tests.unit_tests.fake.callbacks import ( ) DEFAULT_MODEL_NAME = "openai/gpt-oss-20b" + +# gpt-oss doesn't support `reasoning_effort` 
REASONING_MODEL_NAME = "deepseek-r1-distill-llama-70b" @@ -272,7 +274,7 @@ def test_reasoning_output_stream() -> None: def test_reasoning_effort_none() -> None: """Test that no reasoning output is returned if effort is set to none.""" chat = ChatGroq( - model="qwen/qwen3-32b", # Only qwen3 currently supports reasoning_effort + model="qwen/qwen3-32b", # Only qwen3 currently supports reasoning_effort = none reasoning_effort="none", ) message = HumanMessage(content="What is the capital of France?") @@ -282,6 +284,79 @@ def test_reasoning_effort_none() -> None: assert "<think>" not in response.content and "<think/>" not in response.content +@pytest.mark.parametrize("effort", ["low", "medium", "high"]) +def test_reasoning_effort_levels(effort: str) -> None: + """Test reasoning effort options for different levels.""" + # As of now, only the new gpt-oss models support `'low'`, `'medium'`, and `'high'` + chat = ChatGroq( + model=DEFAULT_MODEL_NAME, + reasoning_effort=effort, + ) + message = HumanMessage(content="What is the capital of France?") + response = chat.invoke([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert len(response.content) > 0 + assert response.response_metadata.get("reasoning_effort") == effort + + +@pytest.mark.parametrize("effort", ["low", "medium", "high"]) +def test_reasoning_effort_invoke_override(effort: str) -> None: + """Test that reasoning_effort in invoke() overrides class-level setting.""" + # Create chat with no reasoning effort at class level + chat = ChatGroq( + model=DEFAULT_MODEL_NAME, + ) + message = HumanMessage(content="What is the capital of France?") + + # Override reasoning_effort in invoke() + response = chat.invoke([message], reasoning_effort=effort) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert len(response.content) > 0 + assert response.response_metadata.get("reasoning_effort") == effort + + +def
test_reasoning_effort_invoke_override_different_level() -> None: + """Test that reasoning_effort in invoke() overrides class-level setting.""" + # Create chat with reasoning effort at class level + chat = ChatGroq( + model=DEFAULT_MODEL_NAME, # openai/gpt-oss-20b supports reasoning_effort + reasoning_effort="high", + ) + message = HumanMessage(content="What is the capital of France?") + + # Override reasoning_effort to 'low' in invoke() + response = chat.invoke([message], reasoning_effort="low") + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert len(response.content) > 0 + # Should reflect the overridden value, not the class-level setting + assert response.response_metadata.get("reasoning_effort") == "low" + + +def test_reasoning_effort_streaming() -> None: + """Test that reasoning_effort is captured in streaming response metadata.""" + chat = ChatGroq( + model=DEFAULT_MODEL_NAME, + reasoning_effort="medium", + ) + message = HumanMessage(content="What is the capital of France?") + + chunks = list(chat.stream([message])) + assert len(chunks) > 0 + + # Find the final chunk with finish_reason + final_chunk = None + for chunk in chunks: + if chunk.response_metadata.get("finish_reason"): + final_chunk = chunk + break + + assert final_chunk is not None + assert final_chunk.response_metadata.get("reasoning_effort") == "medium" + + # # Misc tests #