diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 142e7eca1a8..c16cffa1abf 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -2195,6 +2195,50 @@ def _resize(width: int, height: int) -> Tuple[int, int]: return width, height +def _update_schema_with_optional_fields(input_dict: dict) -> dict: + + def _update_properties(schema: dict): + if schema.get('type') != 'object': + return + + properties = schema.get('properties', {}) + required_fields = set(schema.get('required', [])) + + for field, field_schema in properties.items(): + # Remove the 'default' key if it exists + field_schema.pop('default', None) + + if field_schema.get('type') == 'object': + _update_properties(field_schema) + + if 'allOf' in field_schema: + for sub_schema in field_schema['allOf']: + if sub_schema.get('type') == 'object': + _update_properties(sub_schema) + + # if 'anyOf' in field_schema: + # # Check if 'anyOf' contains a 'null' type + # types = [sub_schema.get('type') for sub_schema in field_schema['anyOf']] + # if 'null' not in types: + # field_schema['anyOf'].append({'type': 'null'}) + + if field not in required_fields: + original_type = field_schema.get('type') + if isinstance(original_type, str): + field_schema['type'] = [original_type, 'null'] + elif isinstance(original_type, list) and 'null' not in original_type: + field_schema['type'].append('null') + + required_fields.add(field) + + schema['required'] = list(required_fields) + + schema = input_dict.get('json_schema', {}).get('schema', {}) + _update_properties(schema) + + return input_dict + + def _convert_to_openai_response_format( schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None ) -> Union[Dict, TypeBaseModel]: @@ -2225,6 +2269,8 @@ def _convert_to_openai_response_format( f"'strict' is only specified in one place." ) raise ValueError(msg) + if strict: + _update_schema_with_optional_fields(response_format) return response_format