diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index b221360a166..841f5c0a14b 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -75,6 +75,12 @@ from langchain_groq.data._profiles import _PROFILES from langchain_groq.version import __version__ _MODEL_PROFILES = cast("ModelProfileRegistry", _PROFILES) +_STRICT_STRUCTURED_OUTPUT_MODELS = frozenset( + { + "openai/gpt-oss-20b", + "openai/gpt-oss-120b", + } +) def _get_default_model_profile(model_name: str) -> ModelProfile: @@ -907,6 +913,7 @@ class ChatGroq(BaseChatModel): "function_calling", "json_mode", "json_schema" ] = "function_calling", include_raw: bool = False, + strict: bool | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, dict | BaseModel]: r"""Model wrapper that returns outputs formatted to match the given schema. @@ -979,6 +986,16 @@ class ChatGroq(BaseChatModel): The final output is always a `dict` with keys `'raw'`, `'parsed'`, and `'parsing_error'`. + strict: + Only used with `method="json_schema"`. When `True`, Groq's Structured + Output API uses constrained decoding to guarantee schema compliance. + This requires every object to set `additionalProperties: false` and + all properties to be listed in `required`. When `False`, schema + adherence is best-effort. If `None`, the argument is omitted. + + Strict mode is only supported for `openai/gpt-oss-20b` and + `openai/gpt-oss-120b`. For other models, `strict=True` is ignored. + kwargs: Any additional parameters to pass to the `langchain.runnable.Runnable` constructor. @@ -1172,7 +1189,6 @@ class ChatGroq(BaseChatModel): ``` """ # noqa: E501 - _ = kwargs.pop("strict", None) is_pydantic_schema = _is_pydantic_class(schema) if method == "function_calling": if schema is None: @@ -1210,14 +1226,25 @@ class ChatGroq(BaseChatModel): "Received None." ) raise ValueError(msg) - json_schema = convert_to_json_schema(schema) + if ( + strict is True + and self.model_name not in _STRICT_STRUCTURED_OUTPUT_MODELS + ): + # Ignore unsupported strict=True to preserve backward compatibility. + strict = None + json_schema = convert_to_json_schema(schema, strict=strict) schema_name = json_schema.get("title", "") - response_format = { + response_format: dict[str, Any] = { "type": "json_schema", "json_schema": {"name": schema_name, "schema": json_schema}, } + if strict is not None: + response_format["json_schema"]["strict"] = strict + ls_format_kwargs: dict[str, Any] = {"method": "json_schema"} + if strict is not None: + ls_format_kwargs["strict"] = strict ls_format_info = { - "kwargs": {"method": "json_schema"}, + "kwargs": ls_format_kwargs, "schema": json_schema, } llm = self.bind( @@ -1247,8 +1274,9 @@ class ChatGroq(BaseChatModel): ) else: msg = ( - f"Unrecognized method argument. Expected one of 'function_calling' or " - f"'json_mode'. Received: '{method}'" + "Unrecognized method argument. Expected one of " + "'function_calling', 'json_mode', or 'json_schema'. " + f"Received: '{method}'" ) raise ValueError(msg) diff --git a/libs/partners/groq/tests/unit_tests/test_chat_models.py b/libs/partners/groq/tests/unit_tests/test_chat_models.py index b4edc6bf0ac..1848f5f7ff3 100644 --- a/libs/partners/groq/tests/unit_tests/test_chat_models.py +++ b/libs/partners/groq/tests/unit_tests/test_chat_models.py @@ -16,6 +16,8 @@ from langchain_core.messages import ( SystemMessage, ToolCall, ) +from langchain_core.runnables import RunnableBinding, RunnableSequence +from pydantic import BaseModel from langchain_groq.chat_models import ( ChatGroq, @@ -256,6 +258,49 @@ def test_chat_groq_invalid_streaming_params() -> None: ) +def test_with_structured_output_json_schema_strict() -> None: + class Response(BaseModel): + """Response schema.""" + + foo: str + + structured_model = ChatGroq(model="openai/gpt-oss-20b").with_structured_output( + Response, method="json_schema", strict=True + ) + + assert isinstance(structured_model, RunnableSequence) + first_step = structured_model.steps[0] + assert isinstance(first_step, RunnableBinding) + response_format = first_step.kwargs["response_format"] + assert response_format["type"] == "json_schema" + json_schema = response_format["json_schema"] + assert json_schema["strict"] is True + assert json_schema["name"] == "Response" + assert json_schema["schema"]["properties"]["foo"]["type"] == "string" + assert "foo" in json_schema["schema"]["required"] + assert json_schema["schema"]["additionalProperties"] is False + + +def test_with_structured_output_json_schema_strict_ignored_on_unsupported_model() -> ( + None +): + class Response(BaseModel): + """Response schema.""" + + foo: str + + structured_model = ChatGroq(model="llama-3.1-8b-instant").with_structured_output( + Response, method="json_schema", strict=True + ) + + assert isinstance(structured_model, RunnableSequence) + first_step = structured_model.steps[0] + assert isinstance(first_step, RunnableBinding) + response_format = first_step.kwargs["response_format"] + assert response_format["type"] == "json_schema" + assert "strict" not in response_format["json_schema"] + + def test_chat_groq_secret() -> None: """Test that secret is not printed.""" secret = "secretKey" # noqa: S105