mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-17 10:34:27 +00:00
fix(core): preserve default_factory when generating tool call schema (#35550)
This commit is contained in:
committed by
GitHub
parent
e015fb2267
commit
b21c0a8062
@@ -242,7 +242,12 @@ def _create_subset_model_v2(
|
||||
for field_name in field_names:
|
||||
field = model.model_fields[field_name]
|
||||
description = descriptions_.get(field_name, field.description)
|
||||
field_info = FieldInfoV2(description=description, default=field.default)
|
||||
field_kwargs: dict[str, Any] = {"description": description}
|
||||
if field.default_factory is not None:
|
||||
field_kwargs["default_factory"] = field.default_factory
|
||||
else:
|
||||
field_kwargs["default"] = field.default
|
||||
field_info = FieldInfoV2(**field_kwargs)
|
||||
if field.metadata:
|
||||
field_info.metadata = field.metadata
|
||||
fields[field_name] = (field.annotation, field_info)
|
||||
|
||||
@@ -3636,3 +3636,20 @@ def test_tool_args_schema_falsy_defaults() -> None:
|
||||
# Invoke with only required argument - falsy defaults should be applied
|
||||
result = config_tool.invoke({"name": "test"})
|
||||
assert result == "name=test, enabled=False, count=0, prefix=''"
|
||||
|
||||
|
||||
def test_tool_default_factory_not_required() -> None:
|
||||
"""Fields with default_factory should not appear in required."""
|
||||
|
||||
class Args(BaseModel):
|
||||
"""Hello."""
|
||||
|
||||
names: list[str] = Field(default_factory=list, description="Some names")
|
||||
|
||||
@tool(args_schema=Args)
|
||||
def some_func(names: list[str] | None = None) -> None:
|
||||
"""Do something."""
|
||||
|
||||
schema = convert_to_openai_tool(some_func)
|
||||
params = schema["function"]["parameters"]
|
||||
assert "names" not in params.get("required", [])
|
||||
|
||||
@@ -186,3 +186,22 @@ def test_create_model_v2() -> None:
|
||||
foo.model_json_schema()
|
||||
|
||||
assert list(record) == []
|
||||
|
||||
|
||||
def test_create_subset_model_v2_preserves_default_factory() -> None:
|
||||
"""Fields with default_factory should not be marked as required."""
|
||||
|
||||
class Original(BaseModel):
|
||||
required_field: str
|
||||
names: list[str] = Field(default_factory=list, description="Some names")
|
||||
mapping: dict[str, int] = Field(default_factory=dict, description="A mapping")
|
||||
|
||||
subset = _create_subset_model_v2(
|
||||
"Subset",
|
||||
Original,
|
||||
["required_field", "names", "mapping"],
|
||||
)
|
||||
schema = subset.model_json_schema()
|
||||
assert schema.get("required") == ["required_field"]
|
||||
assert "names" not in schema.get("required", [])
|
||||
assert "mapping" not in schema.get("required", [])
|
||||
|
||||
Reference in New Issue
Block a user