feat(groq): loosen restrictions on reasoning_effort, inject effort into response metadata, update tests (#32415)

Mason Daugherty 2025-08-05 15:03:38 -04:00 committed by GitHub
parent 419c173225
commit fb490b0c39
2 changed files with 98 additions and 36 deletions
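
Taken together, the changes below mean a caller can pass any effort level Groq accepts (no longer just none/default) and read the effective value back from the response metadata. A minimal usage sketch of the new behavior, assuming a valid GROQ_API_KEY in the environment and a model that supports effort levels; this example is illustrative and not part of the diff:

    from langchain_groq import ChatGroq

    # Assumes GROQ_API_KEY is set; model availability on Groq may change.
    llm = ChatGroq(model="openai/gpt-oss-20b", reasoning_effort="low")
    response = llm.invoke("What is the capital of France?")
    # The effort that was in effect is now echoed in the response metadata.
    print(response.response_metadata.get("reasoning_effort"))  # "low"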


@@ -6,15 +6,7 @@ import json
 import warnings
 from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
 from operator import itemgetter
-from typing import (
-    Any,
-    Callable,
-    Literal,
-    Optional,
-    TypedDict,
-    Union,
-    cast,
-)
+from typing import Any, Callable, Literal, Optional, TypedDict, Union, cast
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -46,10 +38,7 @@ from langchain_core.messages import (
     ToolMessage,
     ToolMessageChunk,
 )
-from langchain_core.output_parsers import (
-    JsonOutputParser,
-    PydanticOutputParser,
-)
+from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_tools import (
     JsonOutputKeyToolsParser,
@@ -60,23 +49,13 @@ from langchain_core.output_parsers.openai_tools import (
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
-from langchain_core.utils import (
-    from_env,
-    get_pydantic_field_names,
-    secret_from_env,
-)
+from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
 from langchain_core.utils.function_calling import (
     convert_to_openai_function,
     convert_to_openai_tool,
 )
 from langchain_core.utils.pydantic import is_basemodel_subclass
-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    Field,
-    SecretStr,
-    model_validator,
-)
+from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
 from typing_extensions import Self
 
 from langchain_groq.version import __version__
@@ -122,7 +101,7 @@ class ChatGroq(BaseChatModel):
         See the `Groq documentation
         <https://console.groq.com/docs/reasoning#reasoning>`__ for more
-        details and a list of supported reasoning models.
+        details and a list of supported models.
     model_kwargs: Dict[str, Any]
         Holds any model parameters valid for create call not
         explicitly specified.
@@ -328,20 +307,15 @@
     overridden in ``reasoning_effort``.
     See the `Groq documentation <https://console.groq.com/docs/reasoning#reasoning>`__
-    for more details and a list of supported reasoning models.
+    for more details and a list of supported models.
     """
 
-    reasoning_effort: Optional[Literal["none", "default"]] = Field(default=None)
+    reasoning_effort: Optional[str] = Field(default=None)
     """The level of effort the model will put into reasoning. Groq will default to
-    enabling reasoning if left undefined. If set to ``none``, ``reasoning_format`` will
-    not apply and ``reasoning_content`` will not be returned.
-
-    - ``'none'``: Disable reasoning. The model will not use any reasoning tokens when
-      generating a response.
-    - ``'default'``: Enable reasoning.
+    enabling reasoning if left undefined.
 
     See the `Groq documentation
     <https://console.groq.com/docs/reasoning#options-for-reasoning-effort>`__ for more
-    details and a list of models that support setting a reasoning effort.
+    details and a list of options and models that support setting a reasoning effort.
     """
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
@@ -601,6 +575,11 @@ class ChatGroq(BaseChatModel):
                 generation_info["system_fingerprint"] = system_fingerprint
             service_tier = params.get("service_tier") or self.service_tier
             generation_info["service_tier"] = service_tier
+            reasoning_effort = (
+                params.get("reasoning_effort") or self.reasoning_effort
+            )
+            if reasoning_effort:
+                generation_info["reasoning_effort"] = reasoning_effort
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
@@ -644,6 +623,11 @@ class ChatGroq(BaseChatModel):
                 generation_info["system_fingerprint"] = system_fingerprint
             service_tier = params.get("service_tier") or self.service_tier
             generation_info["service_tier"] = service_tier
+            reasoning_effort = (
+                params.get("reasoning_effort") or self.reasoning_effort
+            )
+            if reasoning_effort:
+                generation_info["reasoning_effort"] = reasoning_effort
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
@@ -714,6 +698,9 @@ class ChatGroq(BaseChatModel):
             "system_fingerprint": response.get("system_fingerprint", ""),
         }
         llm_output["service_tier"] = params.get("service_tier") or self.service_tier
+        reasoning_effort = params.get("reasoning_effort") or self.reasoning_effort
+        if reasoning_effort:
+            llm_output["reasoning_effort"] = reasoning_effort
         return ChatResult(generations=generations, llm_output=llm_output)
 
     def _create_message_dicts(
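
All three injection sites above share the same fallback expression, params.get("reasoning_effort") or self.reasoning_effort, so a per-call value takes precedence over the instance default and the key is only written when some effort is in play. Because the fallback uses `or`, a falsy per-call value (None or an empty string) silently falls back to the instance default. A self-contained sketch of that precedence, using illustrative names rather than the diff's own code:

    def effective_effort(params: dict, instance_default: str | None) -> str | None:
        # Per-call value wins; falsy values (None, "") fall back to the default.
        return params.get("reasoning_effort") or instance_default

    assert effective_effort({"reasoning_effort": "low"}, "high") == "low"
    assert effective_effort({}, "high") == "high"
    assert effective_effort({"reasoning_effort": ""}, "high") == "high"
    assert effective_effort({}, None) is None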


@@ -25,6 +25,8 @@ from tests.unit_tests.fake.callbacks import (
 )
 
 DEFAULT_MODEL_NAME = "openai/gpt-oss-20b"
+
+# gpt-oss doesn't support `reasoning_effort`
 REASONING_MODEL_NAME = "deepseek-r1-distill-llama-70b"
@@ -272,7 +274,7 @@ def test_reasoning_output_stream() -> None:
 def test_reasoning_effort_none() -> None:
     """Test that no reasoning output is returned if effort is set to none."""
     chat = ChatGroq(
-        model="qwen/qwen3-32b",  # Only qwen3 currently supports reasoning_effort
+        model="qwen/qwen3-32b",  # Only qwen3 currently supports reasoning_effort = none
         reasoning_effort="none",
     )
     message = HumanMessage(content="What is the capital of France?")
@@ -282,6 +284,79 @@ def test_reasoning_effort_none() -> None:
     assert "<think>" not in response.content and "<think/>" not in response.content
 
 
+@pytest.mark.parametrize("effort", ["low", "medium", "high"])
+def test_reasoning_effort_levels(effort: str) -> None:
+    """Test reasoning effort options for different levels."""
+    # As of now, only the new gpt-oss models support `'low'`, `'medium'`, and `'high'`
+    chat = ChatGroq(
+        model=DEFAULT_MODEL_NAME,
+        reasoning_effort=effort,
+    )
+    message = HumanMessage(content="What is the capital of France?")
+    response = chat.invoke([message])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+    assert len(response.content) > 0
+    assert response.response_metadata.get("reasoning_effort") == effort
+
+
+@pytest.mark.parametrize("effort", ["low", "medium", "high"])
+def test_reasoning_effort_invoke_override(effort: str) -> None:
+    """Test that reasoning_effort in invoke() overrides class-level setting."""
+    # Create chat with no reasoning effort at class level
+    chat = ChatGroq(
+        model=DEFAULT_MODEL_NAME,
+    )
+    message = HumanMessage(content="What is the capital of France?")
+    # Override reasoning_effort in invoke()
+    response = chat.invoke([message], reasoning_effort=effort)
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+    assert len(response.content) > 0
+    assert response.response_metadata.get("reasoning_effort") == effort
+
+
+def test_reasoning_effort_invoke_override_different_level() -> None:
+    """Test that reasoning_effort in invoke() overrides class-level setting."""
+    # Create chat with reasoning effort at class level
+    chat = ChatGroq(
+        model=DEFAULT_MODEL_NAME,  # openai/gpt-oss-20b supports reasoning_effort
+        reasoning_effort="high",
+    )
+    message = HumanMessage(content="What is the capital of France?")
+    # Override reasoning_effort to 'low' in invoke()
+    response = chat.invoke([message], reasoning_effort="low")
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+    assert len(response.content) > 0
+    # Should reflect the overridden value, not the class-level setting
+    assert response.response_metadata.get("reasoning_effort") == "low"
+
+
+def test_reasoning_effort_streaming() -> None:
+    """Test that reasoning_effort is captured in streaming response metadata."""
+    chat = ChatGroq(
+        model=DEFAULT_MODEL_NAME,
+        reasoning_effort="medium",
+    )
+    message = HumanMessage(content="What is the capital of France?")
+    chunks = list(chat.stream([message]))
+    assert len(chunks) > 0
+    # Find the final chunk with finish_reason
+    final_chunk = None
+    for chunk in chunks:
+        if chunk.response_metadata.get("finish_reason"):
+            final_chunk = chunk
+            break
+    assert final_chunk is not None
+    assert final_chunk.response_metadata.get("reasoning_effort") == "medium"
 
 
 #
 # Misc tests
 #