diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index c16cffa1abf..be4c1ca19e7 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -2196,44 +2196,32 @@ def _resize(width: int, height: int) -> Tuple[int, int]: def _update_schema_with_optional_fields(input_dict: dict) -> dict: - + """Convert optional fields to required fields allowing 'null' type.""" def _update_properties(schema: dict): - if schema.get('type') != 'object': + if schema.get("type") != "object": return - properties = schema.get('properties', {}) - required_fields = set(schema.get('required', [])) + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) for field, field_schema in properties.items(): - # Remove the 'default' key if it exists - field_schema.pop('default', None) + field_schema.pop("default", None) - if field_schema.get('type') == 'object': + if field_schema.get("type") == "object": _update_properties(field_schema) - if 'allOf' in field_schema: - for sub_schema in field_schema['allOf']: - if sub_schema.get('type') == 'object': - _update_properties(sub_schema) - - # if 'anyOf' in field_schema: - # # Check if 'anyOf' contains a 'null' type - # types = [sub_schema.get('type') for sub_schema in field_schema['anyOf']] - # if 'null' not in types: - # field_schema['anyOf'].append({'type': 'null'}) - if field not in required_fields: - original_type = field_schema.get('type') + original_type = field_schema.get("type") if isinstance(original_type, str): - field_schema['type'] = [original_type, 'null'] - elif isinstance(original_type, list) and 'null' not in original_type: - field_schema['type'].append('null') + field_schema["type"] = [original_type, "null"] + elif isinstance(original_type, list) and "null" not in original_type: + field_schema["type"].append("null") - required_fields.add(field) + required_fields.append(field) - schema['required'] = list(required_fields) + schema["required"] = required_fields - schema = input_dict.get('json_schema', {}).get('schema', {}) + schema = input_dict.get("json_schema", {}).get("schema", {}) _update_properties(schema) return input_dict diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 2e6cca0cd2d..d588cc82dd8 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -18,7 +18,7 @@ from langchain_core.messages import ( ) from langchain_core.messages.ai import UsageMetadata from pydantic import BaseModel, Field -from typing_extensions import TypedDict +from typing_extensions import Annotated, TypedDict from langchain_openai import ChatOpenAI from langchain_openai.chat_models.base import ( @@ -822,6 +822,69 @@ def test__convert_to_openai_response_format() -> None: with pytest.raises(ValueError): _convert_to_openai_response_format(response_format, strict=False) + # Test handling of optional fields + ## TypedDict + class Entity(TypedDict): + """Extracted entity.""" + + animal: Annotated[str, ..., "The animal"] + color: Annotated[Optional[str], None, "The color"] + + actual = _convert_to_openai_response_format(Entity, strict=True) + expected = { + "type": "json_schema", + "json_schema": { + "name": "Entity", + "description": "Extracted entity.", + "strict": True, + "schema": { + "type": "object", + "properties": { + "animal": {"description": "The animal", "type": "string"}, + "color": {"description": "The color", "type": ["string", "null"]}, + }, + "required": ["animal", "color"], + "additionalProperties": False, + }, + }, + } + assert expected == actual + + ## JSON Schema + class Entity(BaseModel): + """Extracted entity.""" + + animal: str = Field(description="The animal") + color: Optional[str] = Field(default=None, description="The color") + + actual = _convert_to_openai_response_format(Entity.model_json_schema(), strict=True) + expected = { + "type": "json_schema", + "json_schema": { + "name": "Entity", + "description": "Extracted entity.", + "strict": True, + "schema": { + "properties": { + "animal": { + "description": "The animal", + "title": "Animal", + "type": "string", + }, + "color": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "description": "The color", + "title": "Color", + }, + }, + "required": ["animal", "color"], + "type": "object", + "additionalProperties": False, + }, + }, + } + assert expected == actual + @pytest.mark.parametrize("method", ["function_calling", "json_schema"]) @pytest.mark.parametrize("strict", [True, None])