mirror of
https://github.com/hwchase17/langchain.git
synced 2026-05-17 04:45:11 +00:00
perf(core): memoize get_all_basemodel_annotations with lru_cache
This commit is contained in:
@@ -1447,11 +1447,15 @@ def _is_injected_arg_type(
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=512)
|
||||
def get_all_basemodel_annotations(
|
||||
cls: TypeBaseModel | Any, *, default_to_bound: bool = True
|
||||
cls: TypeBaseModel | Any, default_to_bound: bool = True
|
||||
) -> dict[str, type | TypeVar]:
|
||||
"""Get all annotations from a Pydantic `BaseModel` and its parents.
|
||||
|
||||
The result is cached per `(cls, default_to_bound)` — callers must not mutate
|
||||
the returned dict.
|
||||
|
||||
Args:
|
||||
cls: The Pydantic `BaseModel` class.
|
||||
default_to_bound: Whether to default to the bound of a `TypeVar` if it exists.
|
||||
@@ -1475,9 +1479,7 @@ def get_all_basemodel_annotations(
|
||||
orig_bases: tuple = getattr(cls, "__orig_bases__", ())
|
||||
# cls has subscript: cls = FooBar[int]
|
||||
else:
|
||||
annotations = get_all_basemodel_annotations(
|
||||
get_origin(cls), default_to_bound=False
|
||||
)
|
||||
annotations = dict(get_all_basemodel_annotations(get_origin(cls), False))
|
||||
orig_bases = (cls,)
|
||||
|
||||
# Pydantic v2 automatically resolves inherited generics, Pydantic v1 does not.
|
||||
@@ -1488,9 +1490,7 @@ def get_all_basemodel_annotations(
|
||||
for parent in orig_bases:
|
||||
# if class = FooBar inherits from Baz, parent = Baz
|
||||
if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
|
||||
annotations.update(
|
||||
get_all_basemodel_annotations(parent, default_to_bound=False)
|
||||
)
|
||||
annotations.update(get_all_basemodel_annotations(parent, False))
|
||||
continue
|
||||
|
||||
parent_origin = get_origin(parent)
|
||||
|
||||
@@ -3748,3 +3748,17 @@ def test_get_filtered_args_removed() -> None:
|
||||
import langchain_core.tools.base as base_module
|
||||
|
||||
assert not hasattr(base_module, "_get_filtered_args")
|
||||
|
||||
|
||||
def test_get_all_basemodel_annotations_is_memoized() -> None:
|
||||
"""Repeated calls with the same class must return the cached result (same object)."""
|
||||
from langchain_core.tools.base import get_all_basemodel_annotations
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Foo(BaseModel):
|
||||
x: int
|
||||
y: str
|
||||
|
||||
result1 = get_all_basemodel_annotations(Foo)
|
||||
result2 = get_all_basemodel_annotations(Foo)
|
||||
assert result1 is result2, "Expected identical object from cache"
|
||||
|
||||
Reference in New Issue
Block a user