mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 20:58:25 +00:00
update
This commit is contained in:
parent
d18e1c67ff
commit
ea0d921e72
@ -2196,44 +2196,32 @@ def _resize(width: int, height: int) -> Tuple[int, int]:
|
||||
|
||||
|
||||
def _update_schema_with_optional_fields(input_dict: dict) -> dict:
|
||||
|
||||
"""Convert optional fields to required fields allowing 'null' type."""
|
||||
def _update_properties(schema: dict):
|
||||
if schema.get('type') != 'object':
|
||||
if schema.get("type") != "object":
|
||||
return
|
||||
|
||||
properties = schema.get('properties', {})
|
||||
required_fields = set(schema.get('required', []))
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = schema.get("required", [])
|
||||
|
||||
for field, field_schema in properties.items():
|
||||
# Remove the 'default' key if it exists
|
||||
field_schema.pop('default', None)
|
||||
field_schema.pop("default", None)
|
||||
|
||||
if field_schema.get('type') == 'object':
|
||||
if field_schema.get("type") == "object":
|
||||
_update_properties(field_schema)
|
||||
|
||||
if 'allOf' in field_schema:
|
||||
for sub_schema in field_schema['allOf']:
|
||||
if sub_schema.get('type') == 'object':
|
||||
_update_properties(sub_schema)
|
||||
|
||||
# if 'anyOf' in field_schema:
|
||||
# # Check if 'anyOf' contains a 'null' type
|
||||
# types = [sub_schema.get('type') for sub_schema in field_schema['anyOf']]
|
||||
# if 'null' not in types:
|
||||
# field_schema['anyOf'].append({'type': 'null'})
|
||||
|
||||
if field not in required_fields:
|
||||
original_type = field_schema.get('type')
|
||||
original_type = field_schema.get("type")
|
||||
if isinstance(original_type, str):
|
||||
field_schema['type'] = [original_type, 'null']
|
||||
elif isinstance(original_type, list) and 'null' not in original_type:
|
||||
field_schema['type'].append('null')
|
||||
field_schema["type"] = [original_type, "null"]
|
||||
elif isinstance(original_type, list) and "null" not in original_type:
|
||||
field_schema["type"].append("null")
|
||||
|
||||
required_fields.add(field)
|
||||
required_fields.append(field)
|
||||
|
||||
schema['required'] = list(required_fields)
|
||||
schema["required"] = required_fields
|
||||
|
||||
schema = input_dict.get('json_schema', {}).get('schema', {})
|
||||
schema = input_dict.get("json_schema", {}).get("schema", {})
|
||||
_update_properties(schema)
|
||||
|
||||
return input_dict
|
||||
|
@ -18,7 +18,7 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.messages.ai import UsageMetadata
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai.chat_models.base import (
|
||||
@ -822,6 +822,69 @@ def test__convert_to_openai_response_format() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
_convert_to_openai_response_format(response_format, strict=False)
|
||||
|
||||
# Test handling of optional fields
|
||||
## TypedDict
|
||||
class Entity(TypedDict):
|
||||
"""Extracted entity."""
|
||||
|
||||
animal: Annotated[str, ..., "The animal"]
|
||||
color: Annotated[Optional[str], None, "The color"]
|
||||
|
||||
actual = _convert_to_openai_response_format(Entity, strict=True)
|
||||
expected = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "Entity",
|
||||
"description": "Extracted entity.",
|
||||
"strict": True,
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"animal": {"description": "The animal", "type": "string"},
|
||||
"color": {"description": "The color", "type": ["string", "null"]},
|
||||
},
|
||||
"required": ["animal", "color"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
assert expected == actual
|
||||
|
||||
## JSON Schema
|
||||
class Entity(BaseModel):
|
||||
"""Extracted entity."""
|
||||
|
||||
animal: str = Field(description="The animal")
|
||||
color: Optional[str] = Field(default=None, description="The color")
|
||||
|
||||
actual = _convert_to_openai_response_format(Entity.model_json_schema(), strict=True)
|
||||
expected = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "Entity",
|
||||
"description": "Extracted entity.",
|
||||
"strict": True,
|
||||
"schema": {
|
||||
"properties": {
|
||||
"animal": {
|
||||
"description": "The animal",
|
||||
"title": "Animal",
|
||||
"type": "string",
|
||||
},
|
||||
"color": {
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"description": "The color",
|
||||
"title": "Color",
|
||||
},
|
||||
},
|
||||
"required": ["animal", "color"],
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", ["function_calling", "json_schema"])
|
||||
@pytest.mark.parametrize("strict", [True, None])
|
||||
|
Loading…
Reference in New Issue
Block a user