mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-16 10:03:39 +00:00
feat(groq): Strict Mode for Groq (#35029)
This commit is contained in:
committed by
GitHub
parent
fb31c91076
commit
3101794dde
@@ -75,6 +75,12 @@ from langchain_groq.data._profiles import _PROFILES
|
||||
from langchain_groq.version import __version__
|
||||
|
||||
_MODEL_PROFILES = cast("ModelProfileRegistry", _PROFILES)
|
||||
_STRICT_STRUCTURED_OUTPUT_MODELS = frozenset(
|
||||
{
|
||||
"openai/gpt-oss-20b",
|
||||
"openai/gpt-oss-120b",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_default_model_profile(model_name: str) -> ModelProfile:
|
||||
@@ -907,6 +913,7 @@ class ChatGroq(BaseChatModel):
|
||||
"function_calling", "json_mode", "json_schema"
|
||||
] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
strict: bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, dict | BaseModel]:
|
||||
r"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
@@ -979,6 +986,16 @@ class ChatGroq(BaseChatModel):
|
||||
The final output is always a `dict` with keys `'raw'`, `'parsed'`, and
|
||||
`'parsing_error'`.
|
||||
|
||||
strict:
|
||||
Only used with `method="json_schema"`. When `True`, Groq's Structured
|
||||
Output API uses constrained decoding to guarantee schema compliance.
|
||||
This requires every object to set `additionalProperties: false` and
|
||||
all properties to be listed in `required`. When `False`, schema
|
||||
adherence is best-effort. If `None`, the argument is omitted.
|
||||
|
||||
Strict mode is only supported for `openai/gpt-oss-20b` and
|
||||
`openai/gpt-oss-120b`. For other models, `strict=True` is ignored.
|
||||
|
||||
kwargs:
|
||||
Any additional parameters to pass to the `langchain.runnable.Runnable`
|
||||
constructor.
|
||||
@@ -1172,7 +1189,6 @@ class ChatGroq(BaseChatModel):
|
||||
```
|
||||
|
||||
""" # noqa: E501
|
||||
_ = kwargs.pop("strict", None)
|
||||
is_pydantic_schema = _is_pydantic_class(schema)
|
||||
if method == "function_calling":
|
||||
if schema is None:
|
||||
@@ -1210,14 +1226,25 @@ class ChatGroq(BaseChatModel):
|
||||
"Received None."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
json_schema = convert_to_json_schema(schema)
|
||||
if (
|
||||
strict is True
|
||||
and self.model_name not in _STRICT_STRUCTURED_OUTPUT_MODELS
|
||||
):
|
||||
# Ignore unsupported strict=True to preserve backward compatibility.
|
||||
strict = None
|
||||
json_schema = convert_to_json_schema(schema, strict=strict)
|
||||
schema_name = json_schema.get("title", "")
|
||||
response_format = {
|
||||
response_format: dict[str, Any] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {"name": schema_name, "schema": json_schema},
|
||||
}
|
||||
if strict is not None:
|
||||
response_format["json_schema"]["strict"] = strict
|
||||
ls_format_kwargs: dict[str, Any] = {"method": "json_schema"}
|
||||
if strict is not None:
|
||||
ls_format_kwargs["strict"] = strict
|
||||
ls_format_info = {
|
||||
"kwargs": {"method": "json_schema"},
|
||||
"kwargs": ls_format_kwargs,
|
||||
"schema": json_schema,
|
||||
}
|
||||
llm = self.bind(
|
||||
@@ -1247,8 +1274,9 @@ class ChatGroq(BaseChatModel):
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
||||
f"'json_mode'. Received: '{method}'"
|
||||
"Unrecognized method argument. Expected one of "
|
||||
"'function_calling', 'json_mode', or 'json_schema'. "
|
||||
f"Received: '{method}'"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@@ -16,6 +16,8 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
)
|
||||
from langchain_core.runnables import RunnableBinding, RunnableSequence
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_groq.chat_models import (
|
||||
ChatGroq,
|
||||
@@ -256,6 +258,49 @@ def test_chat_groq_invalid_streaming_params() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_with_structured_output_json_schema_strict() -> None:
|
||||
class Response(BaseModel):
|
||||
"""Response schema."""
|
||||
|
||||
foo: str
|
||||
|
||||
structured_model = ChatGroq(model="openai/gpt-oss-20b").with_structured_output(
|
||||
Response, method="json_schema", strict=True
|
||||
)
|
||||
|
||||
assert isinstance(structured_model, RunnableSequence)
|
||||
first_step = structured_model.steps[0]
|
||||
assert isinstance(first_step, RunnableBinding)
|
||||
response_format = first_step.kwargs["response_format"]
|
||||
assert response_format["type"] == "json_schema"
|
||||
json_schema = response_format["json_schema"]
|
||||
assert json_schema["strict"] is True
|
||||
assert json_schema["name"] == "Response"
|
||||
assert json_schema["schema"]["properties"]["foo"]["type"] == "string"
|
||||
assert "foo" in json_schema["schema"]["required"]
|
||||
assert json_schema["schema"]["additionalProperties"] is False
|
||||
|
||||
|
||||
def test_with_structured_output_json_schema_strict_ignored_on_unsupported_model() -> (
|
||||
None
|
||||
):
|
||||
class Response(BaseModel):
|
||||
"""Response schema."""
|
||||
|
||||
foo: str
|
||||
|
||||
structured_model = ChatGroq(model="llama-3.1-8b-instant").with_structured_output(
|
||||
Response, method="json_schema", strict=True
|
||||
)
|
||||
|
||||
assert isinstance(structured_model, RunnableSequence)
|
||||
first_step = structured_model.steps[0]
|
||||
assert isinstance(first_step, RunnableBinding)
|
||||
response_format = first_step.kwargs["response_format"]
|
||||
assert response_format["type"] == "json_schema"
|
||||
assert "strict" not in response_format["json_schema"]
|
||||
|
||||
|
||||
def test_chat_groq_secret() -> None:
|
||||
"""Test that secret is not printed."""
|
||||
secret = "secretKey" # noqa: S105
|
||||
|
||||
Reference in New Issue
Block a user