perf(core): cache _create_subset_model results with lru_cache

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Sydney Runkle
2026-04-21 10:01:39 -04:00
parent 0f042990e3
commit 3a36d4cd2f
2 changed files with 21 additions and 15 deletions

View File

@@ -199,18 +199,21 @@ class _IgnoreUnserializable(GenerateJsonSchema):
return {}
@lru_cache(maxsize=256)
def _create_subset_model_v1(
name: str,
model: type[BaseModelV1],
field_names: list,
field_names: tuple[str, ...],
*,
descriptions: dict | None = None,
descriptions: tuple[tuple[str, str], ...] | None = None,
fn_description: str | None = None,
) -> type[BaseModelV1]:
"""Create a Pydantic model with only a subset of model's fields."""
descriptions_ = dict(descriptions) if descriptions else {}
field_names_list = list(field_names)
fields = {}
for field_name in field_names:
for field_name in field_names_list:
# Using pydantic v1 so can access __fields__ as a dict.
field = model.__fields__[field_name]
t = (
@@ -219,8 +222,8 @@ def _create_subset_model_v1(
if field.required and not field.allow_none
else field.outer_type_ | None
)
if descriptions and field_name in descriptions:
field.field_info.description = descriptions[field_name]
if descriptions_ and field_name in descriptions_:
field.field_info.description = descriptions_[field_name]
fields[field_name] = (t, field.field_info)
rtn = cast("type[BaseModelV1]", create_model_v1(name, **fields)) # type: ignore[call-overload]
@@ -228,18 +231,20 @@ def _create_subset_model_v1(
return rtn
@lru_cache(maxsize=256)
def _create_subset_model_v2(
name: str,
model: type[BaseModel],
field_names: list[str],
field_names: tuple[str, ...],
*,
descriptions: dict | None = None,
descriptions: tuple[tuple[str, str], ...] | None = None,
fn_description: str | None = None,
) -> type[BaseModel]:
"""Create a Pydantic model with a subset of the model fields."""
descriptions_ = descriptions or {}
descriptions_ = dict(descriptions) if descriptions else {}
field_names_list = list(field_names)
fields = {}
for field_name in field_names:
for field_name in field_names_list:
field = model.model_fields[field_name]
description = descriptions_.get(field_name, field.description)
field_kwargs: dict[str, Any] = {"description": description}
@@ -291,19 +296,20 @@ def _create_subset_model(
Returns:
The created subset model.
"""
descriptions_tuple = tuple(descriptions.items()) if descriptions else None
if issubclass(model, BaseModelV1):
return _create_subset_model_v1(
name,
model,
field_names,
descriptions=descriptions,
tuple(field_names),
descriptions=descriptions_tuple,
fn_description=fn_description,
)
return _create_subset_model_v2(
name,
model,
field_names,
descriptions=descriptions,
tuple(field_names),
descriptions=descriptions_tuple,
fn_description=fn_description,
)

View File

@@ -115,7 +115,7 @@ def test_with_field_metadata() -> None:
description="List of integers", min_length=10, max_length=15
)
subset_model = _create_subset_model_v2("Foo", Foo, ["x"])
subset_model = _create_subset_model_v2("Foo", Foo, ("x",))
assert subset_model.model_json_schema() == {
"properties": {
"x": {
@@ -199,7 +199,7 @@ def test_create_subset_model_v2_preserves_default_factory() -> None:
subset = _create_subset_model_v2(
"Subset",
Original,
["required_field", "names", "mapping"],
("required_field", "names", "mapping"),
)
schema = subset.model_json_schema()
assert schema.get("required") == ["required_field"]