perf(core): memoize get_all_basemodel_annotations with lru_cache

This commit is contained in:
Sydney Runkle
2026-04-30 10:26:39 -04:00
parent 7ff6c96539
commit 6c719df31d
2 changed files with 21 additions and 7 deletions

View File

@@ -1447,11 +1447,15 @@ def _is_injected_arg_type(
)
@functools.lru_cache(maxsize=512)
def get_all_basemodel_annotations(
cls: TypeBaseModel | Any, *, default_to_bound: bool = True
cls: TypeBaseModel | Any, default_to_bound: bool = True
) -> dict[str, type | TypeVar]:
"""Get all annotations from a Pydantic `BaseModel` and its parents.
The result is cached per `(cls, default_to_bound)` — callers must not mutate
the returned dict.
Args:
cls: The Pydantic `BaseModel` class.
default_to_bound: Whether to default to the bound of a `TypeVar` if it exists.
@@ -1475,9 +1479,7 @@ def get_all_basemodel_annotations(
orig_bases: tuple = getattr(cls, "__orig_bases__", ())
# cls has subscript: cls = FooBar[int]
else:
annotations = get_all_basemodel_annotations(
get_origin(cls), default_to_bound=False
)
annotations = dict(get_all_basemodel_annotations(get_origin(cls), False))
orig_bases = (cls,)
# Pydantic v2 automatically resolves inherited generics, Pydantic v1 does not.
@@ -1488,9 +1490,7 @@ def get_all_basemodel_annotations(
for parent in orig_bases:
# if class = FooBar inherits from Baz, parent = Baz
if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
annotations.update(
get_all_basemodel_annotations(parent, default_to_bound=False)
)
annotations.update(get_all_basemodel_annotations(parent, False))
continue
parent_origin = get_origin(parent)

View File

@@ -3748,3 +3748,17 @@ def test_get_filtered_args_removed() -> None:
import langchain_core.tools.base as base_module
assert not hasattr(base_module, "_get_filtered_args")
def test_get_all_basemodel_annotations_is_memoized() -> None:
"""Repeated calls with the same class must return the cached result (same object)."""
from langchain_core.tools.base import get_all_basemodel_annotations
from pydantic import BaseModel
class Foo(BaseModel):
x: int
y: str
result1 = get_all_basemodel_annotations(Foo)
result2 = get_all_basemodel_annotations(Foo)
assert result1 is result2, "Expected identical object from cache"