This commit is contained in:
Chester Curme 2025-01-07 14:31:39 -05:00
parent d18e1c67ff
commit ea0d921e72
2 changed files with 77 additions and 26 deletions

View File

@ -2196,44 +2196,32 @@ def _resize(width: int, height: int) -> Tuple[int, int]:
def _update_schema_with_optional_fields(input_dict: dict) -> dict:
"""Convert optional fields to required fields allowing 'null' type."""
def _update_properties(schema: dict):
if schema.get('type') != 'object':
if schema.get("type") != "object":
return
properties = schema.get('properties', {})
required_fields = set(schema.get('required', []))
properties = schema.get("properties", {})
required_fields = schema.get("required", [])
for field, field_schema in properties.items():
# Remove the 'default' key if it exists
field_schema.pop('default', None)
field_schema.pop("default", None)
if field_schema.get('type') == 'object':
if field_schema.get("type") == "object":
_update_properties(field_schema)
if 'allOf' in field_schema:
for sub_schema in field_schema['allOf']:
if sub_schema.get('type') == 'object':
_update_properties(sub_schema)
# if 'anyOf' in field_schema:
# # Check if 'anyOf' contains a 'null' type
# types = [sub_schema.get('type') for sub_schema in field_schema['anyOf']]
# if 'null' not in types:
# field_schema['anyOf'].append({'type': 'null'})
if field not in required_fields:
original_type = field_schema.get('type')
original_type = field_schema.get("type")
if isinstance(original_type, str):
field_schema['type'] = [original_type, 'null']
elif isinstance(original_type, list) and 'null' not in original_type:
field_schema['type'].append('null')
field_schema["type"] = [original_type, "null"]
elif isinstance(original_type, list) and "null" not in original_type:
field_schema["type"].append("null")
required_fields.add(field)
required_fields.append(field)
schema['required'] = list(required_fields)
schema["required"] = required_fields
schema = input_dict.get('json_schema', {}).get('schema', {})
schema = input_dict.get("json_schema", {}).get("schema", {})
_update_properties(schema)
return input_dict

View File

@ -18,7 +18,7 @@ from langchain_core.messages import (
)
from langchain_core.messages.ai import UsageMetadata
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from typing_extensions import Annotated, TypedDict
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (
@ -822,6 +822,69 @@ def test__convert_to_openai_response_format() -> None:
with pytest.raises(ValueError):
_convert_to_openai_response_format(response_format, strict=False)
# Test handling of optional fields
## TypedDict
class Entity(TypedDict):
"""Extracted entity."""
animal: Annotated[str, ..., "The animal"]
color: Annotated[Optional[str], None, "The color"]
actual = _convert_to_openai_response_format(Entity, strict=True)
expected = {
"type": "json_schema",
"json_schema": {
"name": "Entity",
"description": "Extracted entity.",
"strict": True,
"schema": {
"type": "object",
"properties": {
"animal": {"description": "The animal", "type": "string"},
"color": {"description": "The color", "type": ["string", "null"]},
},
"required": ["animal", "color"],
"additionalProperties": False,
},
},
}
assert expected == actual
## JSON Schema
class Entity(BaseModel):
"""Extracted entity."""
animal: str = Field(description="The animal")
color: Optional[str] = Field(default=None, description="The color")
actual = _convert_to_openai_response_format(Entity.model_json_schema(), strict=True)
expected = {
"type": "json_schema",
"json_schema": {
"name": "Entity",
"description": "Extracted entity.",
"strict": True,
"schema": {
"properties": {
"animal": {
"description": "The animal",
"title": "Animal",
"type": "string",
},
"color": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"description": "The color",
"title": "Color",
},
},
"required": ["animal", "color"],
"type": "object",
"additionalProperties": False,
},
},
}
assert expected == actual
@pytest.mark.parametrize("method", ["function_calling", "json_schema"])
@pytest.mark.parametrize("strict", [True, None])