Compare commits

...

4 Commits

Author SHA1 Message Date
Chester Curme
3c8328e9b7 lint 2025-01-07 15:00:23 -05:00
Chester Curme
ea0d921e72 update 2025-01-07 14:34:12 -05:00
Chester Curme
d18e1c67ff add test 2025-01-07 14:26:17 -05:00
Chester Curme
c318ab6152 update 2025-01-07 12:34:10 -05:00
3 changed files with 119 additions and 1 deletion

View File

@@ -2195,6 +2195,39 @@ def _resize(width: int, height: int) -> Tuple[int, int]:
return width, height
def _update_schema_with_optional_fields(input_dict: dict) -> dict:
"""Convert optional fields to required fields allowing 'null' type."""
def _update_properties(schema: dict) -> None:
if schema.get("type") != "object":
return
properties = schema.get("properties", {})
required_fields = schema.get("required", [])
for field, field_schema in properties.items():
field_schema.pop("default", None)
if field_schema.get("type") == "object":
_update_properties(field_schema)
if field not in required_fields:
original_type = field_schema.get("type")
if isinstance(original_type, str):
field_schema["type"] = [original_type, "null"]
elif isinstance(original_type, list) and "null" not in original_type:
field_schema["type"].append("null")
required_fields.append(field)
schema["required"] = required_fields
schema = input_dict.get("json_schema", {}).get("schema", {})
_update_properties(schema)
return input_dict
def _convert_to_openai_response_format(
schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
) -> Union[Dict, TypeBaseModel]:
@@ -2225,6 +2258,8 @@ def _convert_to_openai_response_format(
f"'strict' is only specified in one place."
)
raise ValueError(msg)
if strict:
_update_schema_with_optional_fields(response_format)
return response_format

View File

@@ -18,7 +18,7 @@ from langchain_core.messages import (
)
from langchain_core.messages.ai import UsageMetadata
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from typing_extensions import Annotated, TypedDict
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (
@@ -822,6 +822,71 @@ def test__convert_to_openai_response_format() -> None:
with pytest.raises(ValueError):
_convert_to_openai_response_format(response_format, strict=False)
# Test handling of optional fields
## TypedDict
class Entity(TypedDict):
"""Extracted entity."""
animal: Annotated[str, ..., "The animal"]
color: Annotated[Optional[str], None, "The color"]
actual = _convert_to_openai_response_format(Entity, strict=True)
expected = {
"type": "json_schema",
"json_schema": {
"name": "Entity",
"description": "Extracted entity.",
"strict": True,
"schema": {
"type": "object",
"properties": {
"animal": {"description": "The animal", "type": "string"},
"color": {"description": "The color", "type": ["string", "null"]},
},
"required": ["animal", "color"],
"additionalProperties": False,
},
},
}
assert expected == actual
## JSON Schema
class EntityModel(BaseModel):
"""Extracted entity."""
animal: str = Field(description="The animal")
color: Optional[str] = Field(default=None, description="The color")
actual = _convert_to_openai_response_format(
EntityModel.model_json_schema(), strict=True
)
expected = {
"type": "json_schema",
"json_schema": {
"name": "EntityModel",
"description": "Extracted entity.",
"strict": True,
"schema": {
"properties": {
"animal": {
"description": "The animal",
"title": "Animal",
"type": "string",
},
"color": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"description": "The color",
"title": "Color",
},
},
"required": ["animal", "color"],
"type": "object",
"additionalProperties": False,
},
},
}
assert expected == actual
@pytest.mark.parametrize("method", ["function_calling", "json_schema"])
@pytest.mark.parametrize("strict", [True, None])

View File

@@ -21,6 +21,7 @@ from langchain_core.utils.function_calling import tool_example_to_messages
from pydantic import BaseModel, Field
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import Field as FieldV1
from typing_extensions import Annotated, TypedDict
from langchain_tests.unit_tests.chat_models import (
ChatModelTests,
@@ -1293,6 +1294,7 @@ class ChatModelIntegrationTests(ChatModelTests):
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
# Pydantic
class Joke(BaseModel):
"""Joke to tell user."""
@@ -1310,6 +1312,22 @@ class ChatModelIntegrationTests(ChatModelTests):
joke_result = chat.invoke("Give me a joke about cats, include the punchline.")
assert isinstance(joke_result, Joke)
# Schema
chat = model.with_structured_output(Joke.model_json_schema())
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
# TypedDict
class JokeDict(TypedDict):
"""Joke to tell user."""
setup: Annotated[str, ..., "question to set up a joke"]
punchline: Annotated[Optional[str], None, "answer to resolve the joke"]
chat = model.with_structured_output(JokeDict)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
def test_json_mode(self, model: BaseChatModel) -> None:
"""Test structured output via `JSON mode. <https://python.langchain.com/docs/concepts/structured_outputs/#json-mode>`_