groq: add support for accessing reasoning output from Groq models (#31662)
**Description:** Return [reasoning](https://console.groq.com/docs/reasoning) output in `additional_kwargs` as `reasoning_content`.

**Issue:** Resolves #31052
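For a quick sense of the feature, here is a minimal usage sketch (assumes a valid `GROQ_API_KEY` in the environment and a reasoning-capable model; the model name and `reasoning_format` value mirror the integration tests added in this PR):

```python
from langchain_core.messages import HumanMessage
from langchain_groq import ChatGroq

# reasoning_format="parsed" separates the model's reasoning from the final
# answer and surfaces it in additional_kwargs["reasoning_content"].
chat = ChatGroq(
    model="deepseek-r1-distill-llama-70b",
    reasoning_format="parsed",
)

response = chat.invoke([HumanMessage(content="What is 17 * 23?")])
print(response.content)  # the final answer
print(response.additional_kwargs.get("reasoning_content"))  # the reasoning trace
```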
This commit is contained in: parent `af2188b848`, commit `dcf5c7b472`.
```diff
@@ -107,6 +107,12 @@ class ChatGroq(BaseChatModel):
         Sampling temperature. Ranges from 0.0 to 1.0.
     max_tokens: Optional[int]
         Max number of tokens to generate.
+    reasoning_format: Optional[Literal["parsed", "raw", "hidden"]]
+        The format for reasoning output.
+
+        - ``parsed``: Separates reasoning into a dedicated field while keeping the response concise.
+        - ``raw``: Includes reasoning within think tags in the content.
+        - ``hidden``: Returns only the final answer.
     model_kwargs: Dict[str, Any]
         Holds any model parameters valid for create call not
         explicitly specified.
@@ -292,7 +298,7 @@ class ChatGroq(BaseChatModel):
          'system_fingerprint': 'fp_c5f20b5bb1',
          'finish_reason': 'stop',
          'logprobs': None}
-    """
+    """  # noqa: E501

     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
@@ -302,6 +308,13 @@ class ChatGroq(BaseChatModel):
     """What sampling temperature to use."""
     stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
+    reasoning_format: Optional[Literal["parsed", "raw", "hidden"]] = None
+    """The format for reasoning output.
+
+    - ``parsed``: Separates reasoning into a dedicated field while keeping the response concise.
+    - ``raw``: Includes reasoning within think tags in the content.
+    - ``hidden``: Returns only the final answer.
+    """  # noqa: E501
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
     groq_api_key: Optional[SecretStr] = Field(
```
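To make the three options concrete, here is a sketch of the shape each value produces (the answer and reasoning strings are invented placeholders; only the field placement follows the docstring above):

```python
# Illustrative shapes only; actual text is model-dependent.
#
# reasoning_format="parsed"  -> reasoning moves to a dedicated field:
#   response.content                                == "391"
#   response.additional_kwargs["reasoning_content"] == "17 * 23 = 340 + 51 ..."
#
# reasoning_format="raw"     -> reasoning stays inline, wrapped in think tags:
#   response.content == "<think>17 * 23 = 340 + 51 ...</think>391"
#
# reasoning_format="hidden"  -> only the final answer is returned:
#   response.content == "391"
```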
```diff
@@ -606,6 +619,7 @@ class ChatGroq(BaseChatModel):
             "n": self.n,
             "temperature": self.temperature,
             "stop": self.stop,
+            "reasoning_format": self.reasoning_format,
             **self.model_kwargs,
         }
         if self.max_tokens is not None:
```
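The new field then rides along with every request via `_default_params`. A hypothetical spot-check (`_default_params` is a private property, poked here only for illustration):

```python
chat = ChatGroq(model="deepseek-r1-distill-llama-70b", reasoning_format="parsed")
# The new key is merged into the payload for chat.completions.create(...).
assert chat._default_params["reasoning_format"] == "parsed"
```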
```diff
@@ -1153,6 +1167,8 @@ def _convert_chunk_to_message_chunk(
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
     elif role == "assistant" or default_class == AIMessageChunk:
+        if reasoning := _dict.get("reasoning"):
+            additional_kwargs["reasoning_content"] = reasoning
         if usage := (chunk.get("x_groq") or {}).get("usage"):
             input_tokens = usage.get("prompt_tokens", 0)
             output_tokens = usage.get("completion_tokens", 0)
@@ -1196,6 +1212,8 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     elif role == "assistant":
         content = _dict.get("content", "") or ""
         additional_kwargs: dict = {}
+        if reasoning := _dict.get("reasoning"):
+            additional_kwargs["reasoning_content"] = reasoning
         if function_call := _dict.get("function_call"):
             additional_kwargs["function_call"] = dict(function_call)
         tool_calls = []
```
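Both conversion paths (streaming and non-streaming) gain the same two-line branch. A self-contained toy version of what it does to a raw assistant message dict (the dict is fabricated for illustration):

```python
# Fabricated Groq-style assistant message carrying a "reasoning" field.
_dict = {
    "role": "assistant",
    "content": "J'adore la programmation.",
    "reasoning": "The user wants an English-to-French translation.",
}

additional_kwargs: dict = {}
# The walrus only fires when "reasoning" is present and truthy, so messages
# without reasoning (e.g. with reasoning_format="hidden") are left untouched.
if reasoning := _dict.get("reasoning"):
    additional_kwargs["reasoning_content"] = reasoning

assert additional_kwargs == {
    "reasoning_content": "The user wants an English-to-French translation."
}
```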
The accompanying integration tests:

```diff
@@ -1,7 +1,7 @@
 """Test ChatGroq chat model."""

 import json
-from typing import Any, Optional
+from typing import Any, Optional, cast

 import pytest
 from langchain_core.messages import (
@@ -212,6 +212,58 @@ async def test_agenerate_streaming() -> None:
     assert generation.text == generation.message.content


+#
+# Test reasoning output
+#
+def test_reasoning_output_invoke() -> None:
+    """Test reasoning output from ChatGroq with invoke."""
+    chat = ChatGroq(
+        model="deepseek-r1-distill-llama-70b",
+        reasoning_format="parsed",
+    )
+    message = [
+        SystemMessage(
+            content="You are a helpful assistant that translates English to French."
+        ),
+        HumanMessage(content="I love programming."),
+    ]
+    response = chat.invoke(message)
+    assert isinstance(response, AIMessage)
+    assert "reasoning_content" in response.additional_kwargs
+    assert isinstance(response.additional_kwargs["reasoning_content"], str)
+    assert len(response.additional_kwargs["reasoning_content"]) > 0
+
+
+def test_reasoning_output_stream() -> None:
+    """Test reasoning output from ChatGroq with stream."""
+    chat = ChatGroq(
+        model="deepseek-r1-distill-llama-70b",
+        reasoning_format="parsed",
+    )
+    message = [
+        SystemMessage(
+            content="You are a helpful assistant that translates English to French."
+        ),
+        HumanMessage(content="I love programming."),
+    ]
+
+    full_response: Optional[AIMessageChunk] = None
+    for token in chat.stream(message):
+        assert isinstance(token, AIMessageChunk)
+
+        if full_response is None:
+            full_response = token
+        else:
+            # Casting since adding results in a type error
+            full_response = cast(AIMessageChunk, full_response + token)
+
+    assert full_response is not None
+    assert isinstance(full_response, AIMessageChunk)
+    assert "reasoning_content" in full_response.additional_kwargs
+    assert isinstance(full_response.additional_kwargs["reasoning_content"], str)
+    assert len(full_response.additional_kwargs["reasoning_content"]) > 0
+
+
 #
 # Misc tests
 #
```
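The streaming test relies on `AIMessageChunk` supporting `+`, which concatenates `content` and merges `additional_kwargs` (string values are concatenated), so the per-chunk reasoning deltas add up to the full trace. Outside pytest the same aggregation could look like this (a sketch, reusing the `chat` instance from the example above):

```python
from functools import reduce

chunks = list(chat.stream("Translate 'I love programming.' to French."))
full = reduce(lambda a, b: a + b, chunks)

print(full.additional_kwargs["reasoning_content"])  # complete reasoning trace
```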