mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 14:50:00 +00:00
feat(groq): loosen restrictions on reasoning_effort, inject effort in meta, update tests (#32415)
This commit is contained in:
parent
419c173225
commit
fb490b0c39
@ -6,15 +6,7 @@ import json
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import Any, Callable, Literal, Optional, TypedDict, Union, cast
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
@ -46,10 +38,7 @@ from langchain_core.messages import (
|
||||
ToolMessage,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.output_parsers import (
|
||||
JsonOutputParser,
|
||||
PydanticOutputParser,
|
||||
)
|
||||
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
@ -60,23 +49,13 @@ from langchain_core.output_parsers.openai_tools import (
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
from_env,
|
||||
get_pydantic_field_names,
|
||||
secret_from_env,
|
||||
)
|
||||
from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_function,
|
||||
convert_to_openai_tool,
|
||||
)
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_groq.version import __version__
|
||||
@ -122,7 +101,7 @@ class ChatGroq(BaseChatModel):
|
||||
|
||||
See the `Groq documentation
|
||||
<https://console.groq.com/docs/reasoning#reasoning>`__ for more
|
||||
details and a list of supported reasoning models.
|
||||
details and a list of supported models.
|
||||
model_kwargs: Dict[str, Any]
|
||||
Holds any model parameters valid for create call not
|
||||
explicitly specified.
|
||||
@ -328,20 +307,15 @@ class ChatGroq(BaseChatModel):
|
||||
overridden in ``reasoning_effort``.
|
||||
|
||||
See the `Groq documentation <https://console.groq.com/docs/reasoning#reasoning>`__
|
||||
for more details and a list of supported reasoning models.
|
||||
for more details and a list of supported models.
|
||||
"""
|
||||
reasoning_effort: Optional[Literal["none", "default"]] = Field(default=None)
|
||||
reasoning_effort: Optional[str] = Field(default=None)
|
||||
"""The level of effort the model will put into reasoning. Groq will default to
|
||||
enabling reasoning if left undefined. If set to ``none``, ``reasoning_format`` will
|
||||
not apply and ``reasoning_content`` will not be returned.
|
||||
|
||||
- ``'none'``: Disable reasoning. The model will not use any reasoning tokens when
|
||||
generating a response.
|
||||
- ``'default'``: Enable reasoning.
|
||||
enabling reasoning if left undefined.
|
||||
|
||||
See the `Groq documentation
|
||||
<https://console.groq.com/docs/reasoning#options-for-reasoning-effort>`__ for more
|
||||
details and a list of models that support setting a reasoning effort.
|
||||
details and a list of options and models that support setting a reasoning effort.
|
||||
"""
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
@ -601,6 +575,11 @@ class ChatGroq(BaseChatModel):
|
||||
generation_info["system_fingerprint"] = system_fingerprint
|
||||
service_tier = params.get("service_tier") or self.service_tier
|
||||
generation_info["service_tier"] = service_tier
|
||||
reasoning_effort = (
|
||||
params.get("reasoning_effort") or self.reasoning_effort
|
||||
)
|
||||
if reasoning_effort:
|
||||
generation_info["reasoning_effort"] = reasoning_effort
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
@ -644,6 +623,11 @@ class ChatGroq(BaseChatModel):
|
||||
generation_info["system_fingerprint"] = system_fingerprint
|
||||
service_tier = params.get("service_tier") or self.service_tier
|
||||
generation_info["service_tier"] = service_tier
|
||||
reasoning_effort = (
|
||||
params.get("reasoning_effort") or self.reasoning_effort
|
||||
)
|
||||
if reasoning_effort:
|
||||
generation_info["reasoning_effort"] = reasoning_effort
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
@ -714,6 +698,9 @@ class ChatGroq(BaseChatModel):
|
||||
"system_fingerprint": response.get("system_fingerprint", ""),
|
||||
}
|
||||
llm_output["service_tier"] = params.get("service_tier") or self.service_tier
|
||||
reasoning_effort = params.get("reasoning_effort") or self.reasoning_effort
|
||||
if reasoning_effort:
|
||||
llm_output["reasoning_effort"] = reasoning_effort
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _create_message_dicts(
|
||||
|
@ -25,6 +25,8 @@ from tests.unit_tests.fake.callbacks import (
|
||||
)
|
||||
|
||||
DEFAULT_MODEL_NAME = "openai/gpt-oss-20b"
|
||||
|
||||
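# NOTE(review): gpt-oss does accept `reasoning_effort` ('low'/'medium'/'high', see the tests below); this separate model is presumably kept for other reasoning options such as `reasoning_format` - confirm.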
|
||||
REASONING_MODEL_NAME = "deepseek-r1-distill-llama-70b"
|
||||
|
||||
|
||||
@ -272,7 +274,7 @@ def test_reasoning_output_stream() -> None:
|
||||
def test_reasoning_effort_none() -> None:
|
||||
"""Test that no reasoning output is returned if effort is set to none."""
|
||||
chat = ChatGroq(
|
||||
model="qwen/qwen3-32b", # Only qwen3 currently supports reasoning_effort
|
||||
model="qwen/qwen3-32b", # Only qwen3 currently supports reasoning_effort = none
|
||||
reasoning_effort="none",
|
||||
)
|
||||
message = HumanMessage(content="What is the capital of France?")
|
||||
@ -282,6 +284,79 @@ def test_reasoning_effort_none() -> None:
|
||||
assert "<think>" not in response.content and "<think/>" not in response.content
|
||||
|
||||
|
||||
@pytest.mark.parametrize("effort", ["low", "medium", "high"])
def test_reasoning_effort_levels(effort: str) -> None:
    """Check that a constructor-level reasoning effort shows up in the metadata."""
    # As of now, only the new gpt-oss models support `'low'`, `'medium'`, and `'high'`
    llm = ChatGroq(model=DEFAULT_MODEL_NAME, reasoning_effort=effort)
    prompt = [HumanMessage(content="What is the capital of France?")]

    reply = llm.invoke(prompt)

    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
    assert len(reply.content) > 0
    # The effort configured on the model must be reported back to the caller.
    assert reply.response_metadata.get("reasoning_effort") == effort
|
||||
|
||||
|
||||
@pytest.mark.parametrize("effort", ["low", "medium", "high"])
def test_reasoning_effort_invoke_override(effort: str) -> None:
    """Check that a reasoning_effort passed to invoke() is applied and reported.

    The model is built without any reasoning effort, so the value seen in the
    response metadata can only come from the per-call keyword argument.
    """
    # No reasoning effort is configured on the model itself.
    llm = ChatGroq(model=DEFAULT_MODEL_NAME)
    prompt = [HumanMessage(content="What is the capital of France?")]

    # Supply reasoning_effort only at call time.
    reply = llm.invoke(prompt, reasoning_effort=effort)

    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
    assert len(reply.content) > 0
    assert reply.response_metadata.get("reasoning_effort") == effort
|
||||
|
||||
|
||||
def test_reasoning_effort_invoke_override_different_level() -> None:
    """Check that invoke()'s reasoning_effort wins over the class-level value."""
    # openai/gpt-oss-20b supports reasoning_effort; start from 'high' on the model.
    llm = ChatGroq(model=DEFAULT_MODEL_NAME, reasoning_effort="high")
    prompt = [HumanMessage(content="What is the capital of France?")]

    # Ask for 'low' on this single call instead.
    reply = llm.invoke(prompt, reasoning_effort="low")

    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
    assert len(reply.content) > 0
    # Metadata must show the per-call value, not the one set on the model.
    assert reply.response_metadata.get("reasoning_effort") == "low"
|
||||
|
||||
|
||||
def test_reasoning_effort_streaming() -> None:
    """Check that streamed response metadata carries the reasoning effort."""
    llm = ChatGroq(model=DEFAULT_MODEL_NAME, reasoning_effort="medium")
    prompt = [HumanMessage(content="What is the capital of France?")]

    streamed = list(llm.stream(prompt))
    assert len(streamed) > 0

    # The first chunk that reports a finish_reason is treated as the final one.
    finished = next(
        (part for part in streamed if part.response_metadata.get("finish_reason")),
        None,
    )

    assert finished is not None
    assert finished.response_metadata.get("reasoning_effort") == "medium"
|
||||
|
||||
|
||||
#
|
||||
# Misc tests
|
||||
#
|
||||
|
Loading…
Reference in New Issue
Block a user