mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
feat(langchain,openai): add strict flag to ProviderStrategy structured output (#34149)
This commit is contained in:
@@ -255,21 +255,32 @@ class ProviderStrategy(Generic[SchemaT]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
schema: type[SchemaT],
|
schema: type[SchemaT],
|
||||||
|
*,
|
||||||
|
strict: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize ProviderStrategy with schema."""
|
"""Initialize ProviderStrategy with schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema: Schema to enforce via the provider's native structured output.
|
||||||
|
strict: Whether to request strict provider-side schema enforcement.
|
||||||
|
"""
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.schema_spec = _SchemaSpec(schema)
|
self.schema_spec = _SchemaSpec(schema, strict=strict)
|
||||||
|
|
||||||
def to_model_kwargs(self) -> dict[str, Any]:
|
def to_model_kwargs(self) -> dict[str, Any]:
|
||||||
"""Convert to kwargs to bind to a model to force structured output."""
|
"""Convert to kwargs to bind to a model to force structured output."""
|
||||||
# OpenAI:
|
# OpenAI:
|
||||||
# - see https://platform.openai.com/docs/guides/structured-outputs
|
# - see https://platform.openai.com/docs/guides/structured-outputs
|
||||||
response_format = {
|
json_schema: dict[str, Any] = {
|
||||||
|
"name": self.schema_spec.name,
|
||||||
|
"schema": self.schema_spec.json_schema,
|
||||||
|
}
|
||||||
|
if self.schema_spec.strict:
|
||||||
|
json_schema["strict"] = True
|
||||||
|
|
||||||
|
response_format: dict[str, Any] = {
|
||||||
"type": "json_schema",
|
"type": "json_schema",
|
||||||
"json_schema": {
|
"json_schema": json_schema,
|
||||||
"name": self.schema_spec.name,
|
|
||||||
"schema": self.schema_spec.json_schema,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
return {"response_format": response_format}
|
return {"response_format": response_format}
|
||||||
|
|
||||||
|
|||||||
@@ -737,6 +737,18 @@ class TestResponseFormatAsProviderStrategy:
|
|||||||
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
assert response["structured_response"] == EXPECTED_WEATHER_DICT
|
||||||
assert len(response["messages"]) == 4
|
assert len(response["messages"]) == 4
|
||||||
|
|
||||||
|
def test_provider_strategy_strict_flag(self) -> None:
|
||||||
|
"""ProviderStrategy should pass through strict flag for provider schemas."""
|
||||||
|
# Default should not set strict
|
||||||
|
strategy_default = ProviderStrategy(WeatherBaseModel)
|
||||||
|
kwargs_default = strategy_default.to_model_kwargs()
|
||||||
|
assert "strict" not in kwargs_default["response_format"]["json_schema"]
|
||||||
|
|
||||||
|
# Explicit strict True should include the flag
|
||||||
|
strategy_strict = ProviderStrategy(WeatherBaseModel, strict=True)
|
||||||
|
kwargs_strict = strategy_strict.to_model_kwargs()
|
||||||
|
assert kwargs_strict["response_format"]["json_schema"]["strict"] is True
|
||||||
|
|
||||||
|
|
||||||
class TestDynamicModelWithResponseFormat:
|
class TestDynamicModelWithResponseFormat:
|
||||||
"""Test response_format with middleware that modifies the model."""
|
"""Test response_format with middleware that modifies the model."""
|
||||||
|
|||||||
@@ -1886,9 +1886,10 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
):
|
):
|
||||||
# compat with langchain.agents.create_agent response_format, which is
|
# compat with langchain.agents.create_agent response_format, which is
|
||||||
# an approximation of OpenAI format
|
# an approximation of OpenAI format
|
||||||
|
strict = response_format["json_schema"].get("strict", None)
|
||||||
response_format = cast(dict, response_format["json_schema"]["schema"])
|
response_format = cast(dict, response_format["json_schema"]["schema"])
|
||||||
kwargs["response_format"] = _convert_to_openai_response_format(
|
kwargs["response_format"] = _convert_to_openai_response_format(
|
||||||
response_format
|
response_format, strict=strict
|
||||||
)
|
)
|
||||||
return super().bind(tools=formatted_tools, **kwargs)
|
return super().bind(tools=formatted_tools, **kwargs)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user