feat(langchain,openai): add strict flag to ProviderStrategy structured output (#34149)

This commit is contained in:
Towseef Altaf
2025-12-11 02:05:23 +05:30
committed by GitHub
parent 69dd39c461
commit d27fb0c432
3 changed files with 32 additions and 8 deletions

View File

@@ -255,21 +255,32 @@ class ProviderStrategy(Generic[SchemaT]):
def __init__( def __init__(
self, self,
schema: type[SchemaT], schema: type[SchemaT],
*,
strict: bool = False,
) -> None: ) -> None:
"""Initialize ProviderStrategy with schema.""" """Initialize ProviderStrategy with schema.
Args:
schema: Schema to enforce via the provider's native structured output.
strict: Whether to request strict provider-side schema enforcement.
"""
self.schema = schema self.schema = schema
self.schema_spec = _SchemaSpec(schema) self.schema_spec = _SchemaSpec(schema, strict=strict)
def to_model_kwargs(self) -> dict[str, Any]: def to_model_kwargs(self) -> dict[str, Any]:
"""Convert to kwargs to bind to a model to force structured output.""" """Convert to kwargs to bind to a model to force structured output."""
# OpenAI: # OpenAI:
# - see https://platform.openai.com/docs/guides/structured-outputs # - see https://platform.openai.com/docs/guides/structured-outputs
response_format = { json_schema: dict[str, Any] = {
"name": self.schema_spec.name,
"schema": self.schema_spec.json_schema,
}
if self.schema_spec.strict:
json_schema["strict"] = True
response_format: dict[str, Any] = {
"type": "json_schema", "type": "json_schema",
"json_schema": { "json_schema": json_schema,
"name": self.schema_spec.name,
"schema": self.schema_spec.json_schema,
},
} }
return {"response_format": response_format} return {"response_format": response_format}

View File

@@ -737,6 +737,18 @@ class TestResponseFormatAsProviderStrategy:
assert response["structured_response"] == EXPECTED_WEATHER_DICT assert response["structured_response"] == EXPECTED_WEATHER_DICT
assert len(response["messages"]) == 4 assert len(response["messages"]) == 4
def test_provider_strategy_strict_flag(self) -> None:
"""ProviderStrategy should pass through strict flag for provider schemas."""
# Default should not set strict
strategy_default = ProviderStrategy(WeatherBaseModel)
kwargs_default = strategy_default.to_model_kwargs()
assert "strict" not in kwargs_default["response_format"]["json_schema"]
# Explicit strict True should include the flag
strategy_strict = ProviderStrategy(WeatherBaseModel, strict=True)
kwargs_strict = strategy_strict.to_model_kwargs()
assert kwargs_strict["response_format"]["json_schema"]["strict"] is True
class TestDynamicModelWithResponseFormat: class TestDynamicModelWithResponseFormat:
"""Test response_format with middleware that modifies the model.""" """Test response_format with middleware that modifies the model."""

View File

@@ -1886,9 +1886,10 @@ class BaseChatOpenAI(BaseChatModel):
): ):
# compat with langchain.agents.create_agent response_format, which is # compat with langchain.agents.create_agent response_format, which is
# an approximation of OpenAI format # an approximation of OpenAI format
strict = response_format["json_schema"].get("strict", None)
response_format = cast(dict, response_format["json_schema"]["schema"]) response_format = cast(dict, response_format["json_schema"]["schema"])
kwargs["response_format"] = _convert_to_openai_response_format( kwargs["response_format"] = _convert_to_openai_response_format(
response_format response_format, strict=strict
) )
return super().bind(tools=formatted_tools, **kwargs) return super().bind(tools=formatted_tools, **kwargs)