feat(groq): loosen restrictions on reasoning_effort, inject effort in meta, update tests (#32415)

This commit is contained in:
Mason Daugherty 2025-08-05 15:03:38 -04:00 committed by GitHub
parent 419c173225
commit fb490b0c39
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 98 additions and 36 deletions

View File

@ -6,15 +6,7 @@ import json
import warnings
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from operator import itemgetter
from typing import (
Any,
Callable,
Literal,
Optional,
TypedDict,
Union,
cast,
)
from typing import Any, Callable, Literal, Optional, TypedDict, Union, cast
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -46,10 +38,7 @@ from langchain_core.messages import (
ToolMessage,
ToolMessageChunk,
)
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
@ -60,23 +49,13 @@ from langchain_core.output_parsers.openai_tools import (
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import (
from_env,
get_pydantic_field_names,
secret_from_env,
)
from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
from langchain_core.utils.function_calling import (
convert_to_openai_function,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
from langchain_groq.version import __version__
@ -122,7 +101,7 @@ class ChatGroq(BaseChatModel):
See the `Groq documentation
<https://console.groq.com/docs/reasoning#reasoning>`__ for more
details and a list of supported reasoning models.
details and a list of supported models.
model_kwargs: Dict[str, Any]
Holds any model parameters valid for create call not
explicitly specified.
@ -328,20 +307,15 @@ class ChatGroq(BaseChatModel):
overridden in ``reasoning_effort``.
See the `Groq documentation <https://console.groq.com/docs/reasoning#reasoning>`__
for more details and a list of supported reasoning models.
for more details and a list of supported models.
"""
reasoning_effort: Optional[Literal["none", "default"]] = Field(default=None)
reasoning_effort: Optional[str] = Field(default=None)
"""The level of effort the model will put into reasoning. Groq will default to
enabling reasoning if left undefined. If set to ``none``, ``reasoning_format`` will
not apply and ``reasoning_content`` will not be returned.
- ``'none'``: Disable reasoning. The model will not use any reasoning tokens when
generating a response.
- ``'default'``: Enable reasoning.
enabling reasoning if left undefined.
See the `Groq documentation
<https://console.groq.com/docs/reasoning#options-for-reasoning-effort>`__ for more
details and a list of models that support setting a reasoning effort.
details and a list of options and models that support setting a reasoning effort.
"""
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
@ -601,6 +575,11 @@ class ChatGroq(BaseChatModel):
generation_info["system_fingerprint"] = system_fingerprint
service_tier = params.get("service_tier") or self.service_tier
generation_info["service_tier"] = service_tier
reasoning_effort = (
params.get("reasoning_effort") or self.reasoning_effort
)
if reasoning_effort:
generation_info["reasoning_effort"] = reasoning_effort
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
@ -644,6 +623,11 @@ class ChatGroq(BaseChatModel):
generation_info["system_fingerprint"] = system_fingerprint
service_tier = params.get("service_tier") or self.service_tier
generation_info["service_tier"] = service_tier
reasoning_effort = (
params.get("reasoning_effort") or self.reasoning_effort
)
if reasoning_effort:
generation_info["reasoning_effort"] = reasoning_effort
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
@ -714,6 +698,9 @@ class ChatGroq(BaseChatModel):
"system_fingerprint": response.get("system_fingerprint", ""),
}
llm_output["service_tier"] = params.get("service_tier") or self.service_tier
reasoning_effort = params.get("reasoning_effort") or self.reasoning_effort
if reasoning_effort:
llm_output["reasoning_effort"] = reasoning_effort
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(

View File

@ -25,6 +25,8 @@ from tests.unit_tests.fake.callbacks import (
)
DEFAULT_MODEL_NAME = "openai/gpt-oss-20b"
# gpt-oss doesn't support `reasoning_effort`
REASONING_MODEL_NAME = "deepseek-r1-distill-llama-70b"
@ -272,7 +274,7 @@ def test_reasoning_output_stream() -> None:
def test_reasoning_effort_none() -> None:
"""Test that no reasoning output is returned if effort is set to none."""
chat = ChatGroq(
model="qwen/qwen3-32b", # Only qwen3 currently supports reasoning_effort
model="qwen/qwen3-32b", # Only qwen3 currently supports reasoning_effort = none
reasoning_effort="none",
)
message = HumanMessage(content="What is the capital of France?")
@ -282,6 +284,79 @@ def test_reasoning_effort_none() -> None:
assert "<think>" not in response.content and "<think/>" not in response.content
@pytest.mark.parametrize("effort", ["low", "medium", "high"])
def test_reasoning_effort_levels(effort: str) -> None:
    """Test that each supported reasoning-effort level round-trips into metadata."""
    # As of now, only the new gpt-oss models support `'low'`, `'medium'`, and `'high'`
    llm = ChatGroq(model=DEFAULT_MODEL_NAME, reasoning_effort=effort)

    prompt = HumanMessage(content="What is the capital of France?")
    result = llm.invoke([prompt])

    assert isinstance(result, AIMessage)
    assert isinstance(result.content, str)
    # A non-empty answer should always come back for a simple factual question.
    assert len(result.content) > 0
    # The effort level requested at construction time must surface in metadata.
    assert result.response_metadata.get("reasoning_effort") == effort
@pytest.mark.parametrize("effort", ["low", "medium", "high"])
def test_reasoning_effort_invoke_override(effort: str) -> None:
    """Test that reasoning_effort in invoke() overrides class-level setting."""
    # No reasoning effort is configured on the model itself; it is supplied
    # per-call instead.
    llm = ChatGroq(model=DEFAULT_MODEL_NAME)
    prompt = HumanMessage(content="What is the capital of France?")

    # The per-call kwarg should win and be echoed back in response metadata.
    result = llm.invoke([prompt], reasoning_effort=effort)

    assert isinstance(result, AIMessage)
    assert isinstance(result.content, str)
    assert len(result.content) > 0
    assert result.response_metadata.get("reasoning_effort") == effort
def test_reasoning_effort_invoke_override_different_level() -> None:
    """Test that a per-invoke reasoning_effort beats the class-level value."""
    # Class-level effort is 'high'; the call below overrides it with 'low'.
    llm = ChatGroq(
        model=DEFAULT_MODEL_NAME,  # openai/gpt-oss-20b supports reasoning_effort
        reasoning_effort="high",
    )
    prompt = HumanMessage(content="What is the capital of France?")

    result = llm.invoke([prompt], reasoning_effort="low")

    assert isinstance(result, AIMessage)
    assert isinstance(result.content, str)
    assert len(result.content) > 0
    # Metadata must reflect the overridden value, not the class-level setting.
    assert result.response_metadata.get("reasoning_effort") == "low"
def test_reasoning_effort_streaming() -> None:
    """Test that reasoning_effort is captured in streaming response metadata."""
    llm = ChatGroq(model=DEFAULT_MODEL_NAME, reasoning_effort="medium")
    prompt = HumanMessage(content="What is the capital of France?")

    pieces = list(llm.stream([prompt]))
    assert pieces

    # Only the chunk carrying finish_reason holds the full response metadata;
    # locate the first such chunk.
    terminal = next(
        (c for c in pieces if c.response_metadata.get("finish_reason")),
        None,
    )

    assert terminal is not None
    assert terminal.response_metadata.get("reasoning_effort") == "medium"
#
# Misc tests
#