perf(core): memoize BaseTool.tool_call_schema subset model and cache model_json_schema (#38073)

This commit is contained in:
Nick Hollon
2026-06-17 17:17:14 -04:00
committed by GitHub
parent ae1c9418b5
commit 138727c008
3 changed files with 271 additions and 3 deletions

View File

@@ -9,7 +9,7 @@ import logging
import typing
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from collections.abc import Callable, Mapping, Sequence
from inspect import signature
from typing import (
TYPE_CHECKING,
@@ -28,6 +28,7 @@ from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
PydanticDeprecationWarning,
SkipValidation,
ValidationError,
@@ -37,7 +38,7 @@ from pydantic.fields import FieldInfo
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from pydantic.v1 import validate_arguments as validate_arguments_v1
from typing_extensions import override
from typing_extensions import Self, override
from langchain_core.callbacks import (
AsyncCallbackManager,
@@ -390,6 +391,39 @@ content is normalized to the content of a `ToolMessage` with `status="error"`.
_EMPTY_SET: frozenset[str] = frozenset()
# Consulted by the `__setattr__` and `model_copy` overrides below to decide
# whether an assignment or copy-update makes the memoized schema stale.
_TOOL_CALL_SCHEMA_FIELDS = frozenset({"name", "description", "args_schema"})
"""Fields the memoized `tool_call_schema` is built from; reassignment clears it."""
def _patch_json_schema_cache(model_cls: type) -> None:
"""Patch `model_json_schema` (or `schema` for pydantic v1) to cache.
Pydantic regenerates the full JSON-schema dict on every
`model_json_schema()` call — there is no per-class cache. When the
model class is stable (memoized on a `BaseTool` instance), this patch
caches the dict on the class so repeated calls return instantly.
Only calls with all-default arguments are cached; any explicit arguments
bypass the cache and delegate to the original method.
"""
method_name = (
"model_json_schema" if hasattr(model_cls, "model_json_schema") else "schema"
)
orig = getattr(model_cls, method_name)
def _cached_json_schema(cls: type, *args: Any, **kwargs: Any) -> dict[str, Any]:
if not args and not kwargs:
cached = cls.__dict__.get("_json_schema_cache")
if cached is not None:
return cast("dict[str, Any]", cached)
result = orig(*args, **kwargs)
if not args and not kwargs:
cls._json_schema_cache = result # type: ignore[attr-defined]
return cast("dict[str, Any]", result)
setattr(model_cls, method_name, classmethod(_cached_json_schema))
class BaseTool(RunnableSerializable[str | dict[str, Any] | ToolCall, Any]):
"""Base class for all LangChain tools.
@@ -581,12 +615,70 @@ class ChildTool(BaseTool):
json_schema = model_json_schema(input_schema)
return cast("dict[str, Any]", json_schema["properties"])
# Declared with `PrivateAttr`, so this is per-instance private state rather than
# a model field. `__getstate__` replaces it with `None` in pickled state, and
# the `tool_call_schema` property repopulates it on the next access.
_tool_call_schema_memo: ArgsSchema | None = PrivateAttr(default=None)
"""Memoized `tool_call_schema` result.
Building the subset model is expensive, and pydantic does not cache
`model_json_schema()` per class, so agent loops would otherwise pay full
schema generation for every tool on every model call. The subset model
class is memoized here and its `model_json_schema`/`schema` method is
patched to cache the generated dict, so both costs are paid only once per
tool instance.
Cleared whenever `name`, `description`, or `args_schema` is reassigned (see
`__setattr__` and `model_copy`).
"""
@override
def __setattr__(self, name: str, value: Any) -> None:
    """Assign the attribute, then drop the schema memo if it became stale.

    The assignment is always delegated to the parent class first. The memo
    is reset only when `name` is one of the fields the tool-call schema is
    derived from (`_TOOL_CALL_SCHEMA_FIELDS`).

    Args:
        name: Attribute being assigned.
        value: New value for the attribute.
    """
    super().__setattr__(name, value)
    if name not in _TOOL_CALL_SCHEMA_FIELDS:
        return
    # Private-attribute storage can be `None`; there is no memo to reset then.
    if self.__pydantic_private__ is None:
        return
    self._tool_call_schema_memo = None
@override
def model_copy(
    self, *, update: Mapping[str, Any] | None = None, deep: bool = False
) -> Self:
    """Copy the tool, resetting the copy's schema memo when `update` affects it.

    `model_copy` writes `update` straight into the copy's `__dict__`, so
    `__setattr__` never runs for those fields, and private attributes
    (the memo included) are carried over to the copy. The memo is
    therefore reset here whenever `update` names a field the schema is
    built from.

    Args:
        update: Field values to override on the copy.
        deep: Whether to make a deep copy.

    Returns:
        The copied tool; the original instance is left untouched.
    """
    duplicate = super().model_copy(update=update, deep=deep)
    if not update:
        return duplicate
    if any(key in _TOOL_CALL_SCHEMA_FIELDS for key in update):
        duplicate._tool_call_schema_memo = None  # noqa: SLF001
    return duplicate
def __getstate__(self) -> dict[Any, Any]:
    """Return pickle state with the tool-call schema memo blanked out.

    The memoized subset model is a dynamically created class that cannot be
    pickled by reference, so it is replaced with `None` in the returned
    state and rebuilt lazily on next access. The live instance's private
    attributes are not modified: both the state mapping and the private
    mapping are copied before the memo entry is overwritten.

    Returns:
        The parent's state unchanged when no memo is set, otherwise a
        shallow copy of it whose private attributes carry a `None` memo.
    """
    state = super().__getstate__()
    private = state.get("__pydantic_private__")
    if not private or private.get("_tool_call_schema_memo") is None:
        return state
    scrubbed = dict(private)
    scrubbed["_tool_call_schema_memo"] = None
    return {**state, "__pydantic_private__": scrubbed}
@property
def tool_call_schema(self) -> ArgsSchema:
"""Get the schema for tool calls, excluding injected arguments.
Returns:
The schema that should be used for tool calls from language models.
The returned model class is memoized per tool instance (invalidated
when `name`, `description`, or `args_schema` is reassigned) so
repeated access does not regenerate the class. The class's
`model_json_schema` method is also patched to cache the generated
schema dict, since pydantic does not cache it per class.
"""
if isinstance(self.args_schema, dict):
if self.description:
@@ -597,14 +689,20 @@ class ChildTool(BaseTool):
return self.args_schema
if (memo := self._tool_call_schema_memo) is not None:
return memo
full_schema = self.get_input_schema()
fields = []
for name, type_ in get_all_basemodel_annotations(full_schema).items():
if not _is_injected_arg_type(type_):
fields.append(name)
return _create_subset_model(
subset_model = _create_subset_model(
self.name, full_schema, fields, fn_description=self.description
)
_patch_json_schema_cache(subset_model)
self._tool_call_schema_memo = subset_model
return subset_model
@functools.cached_property
def _injected_args_keys(self) -> frozenset[str]:

View File

@@ -0,0 +1,69 @@
"""Benchmarks for tool-to-OpenAI schema conversion.
Agent loops convert every bound tool on every model call (token counting and
`bind_tools`), so conversion cost is paid per step. The warm benchmark measures
the steady state with the `tool_call_schema` memo populated; the cold benchmark
measures first-time conversion including subset-model creation.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import pytest
from pydantic import BaseModel, Field, create_model
from langchain_core.tools import StructuredTool
from langchain_core.utils.function_calling import convert_to_openai_tool
if TYPE_CHECKING:
from pytest_benchmark.fixture import BenchmarkFixture
_NUM_TOOLS = 20
_NUM_FIELDS = 8
def _make_tools(num_tools: int) -> list[StructuredTool]:
    """Build `num_tools` structured tools, each with its own input model.

    Every input model has `_NUM_FIELDS` optional string parameters named
    `param_<j>` plus one `target` string without a default. The wrapped
    callable ignores its arguments and returns `"ok"`.

    Args:
        num_tools: Number of tools to create.

    Returns:
        The tools, in creation order, named `bench_tool_<i>`.
    """

    def _build(index: int) -> StructuredTool:
        field_defs: dict[str, Any] = {}
        for j in range(_NUM_FIELDS):
            field_defs[f"param_{j}"] = (
                str | None,
                Field(default=None, description=f"Parameter {j} of action {index}."),
            )
        field_defs["target"] = (str, Field(description="Primary target identifier."))
        input_model: type[BaseModel] = create_model(
            f"BenchTool{index}Input", **field_defs
        )
        return StructuredTool.from_function(
            func=lambda **_kwargs: "ok",
            name=f"bench_tool_{index}",
            description=f"Benchmark tool {index} with several configurable options.",
            args_schema=input_model,
        )

    return [_build(i) for i in range(num_tools)]
@pytest.mark.benchmark
def test_convert_tools_warm(benchmark: BenchmarkFixture) -> None:
    """Benchmark converting an already-converted toolset (memoized schema path).

    The tools are built and converted once outside the timed function, so the
    benchmark measures only repeat conversions of the same tool instances.
    """
    toolset = _make_tools(_NUM_TOOLS)
    # Untimed first pass: populate the schema memo on every tool.
    for warm_tool in toolset:
        convert_to_openai_tool(warm_tool)

    @benchmark  # type: ignore[untyped-decorator]
    def convert_warm() -> None:
        for warm_tool in toolset:
            convert_to_openai_tool(warm_tool)
@pytest.mark.benchmark
def test_convert_tools_cold(benchmark: BenchmarkFixture) -> None:
    """Benchmark first-time conversion, including subset-model creation per tool.

    A fresh toolset is built inside the timed function on every round, so the
    measurement also includes the cost of constructing the tools themselves.
    """

    @benchmark  # type: ignore[untyped-decorator]
    def convert_cold() -> None:
        fresh_tools = _make_tools(_NUM_TOOLS)
        for cold_tool in fresh_tools:
            convert_to_openai_tool(cold_tool)

View File

@@ -3,6 +3,7 @@
import inspect
import json
import logging
import pickle
import sys
import textwrap
import threading
@@ -3932,3 +3933,103 @@ def test_tool_invoke_returns_list_of_mixin() -> None:
assert isinstance(result, list)
assert len(result) == 3
assert all(isinstance(m, ToolMessage) for m in result)
# Args schema for `_MemoSchemaTool`: a single `query` string field.
# Documented with a comment rather than a class docstring so the model
# definition itself stays exactly as the memoization tests expect.
class _MemoSchemaInput(BaseModel):
    query: str = Field(description="Query to run.")
class _MemoSchemaTool(BaseTool):
    """Module-level tool so pickling by class reference works."""

    # `name`, `description` and `args_schema` are the three fields whose
    # reassignment the memoization tests use to check invalidation.
    name: str = "memo_schema_tool"
    description: str = "Tool for tool_call_schema memoization tests."
    args_schema: type[BaseModel] = _MemoSchemaInput

    def _run(self, query: str) -> str:
        # Echo the input back; the tests only inspect the schema.
        return query
def test_tool_call_schema_memoized_across_accesses() -> None:
    """Two consecutive reads of `tool_call_schema` yield the same class object."""
    tool = _MemoSchemaTool()
    initial = tool.tool_call_schema
    repeat = tool.tool_call_schema
    assert repeat is initial
def test_tool_call_schema_memo_invalidated_on_reassignment() -> None:
    """Reassigning `description` or `args_schema` yields a freshly built schema."""
    tool = _MemoSchemaTool()
    before = tool.tool_call_schema

    # Changing the description must produce a new class that reflects it.
    tool.description = "Updated description."
    after_desc = tool.tool_call_schema
    assert after_desc is not before
    desc_json = cast("type[BaseModel]", after_desc).model_json_schema()
    assert desc_json["description"] == "Updated description."

    class OtherInput(BaseModel):
        other_field: str = Field(description="A different field.")

    # Swapping the args schema must likewise rebuild from the new model.
    tool.args_schema = OtherInput
    after_args = tool.tool_call_schema
    assert after_args is not after_desc
    args_json = cast("type[BaseModel]", after_args).model_json_schema()
    assert "other_field" in args_json["properties"]
def test_tool_picklable_after_tool_call_schema_access() -> None:
    """A tool still round-trips through pickle once its schema memo is populated.

    The memoized schema is a dynamic class and must not leak into pickles.
    """
    tool = _MemoSchemaTool()
    # Baseline: the tool pickles before the memo has ever been populated.
    pickle.loads(pickle.dumps(tool))

    live_schema = tool.tool_call_schema
    payload = pickle.dumps(tool)
    revived = pickle.loads(payload)

    revived_schema = cast("type[BaseModel]", revived.tool_call_schema)
    revived_json = revived_schema.model_json_schema()
    assert revived_json["title"] == "memo_schema_tool"
    # Pickling must not clear the live instance's memo.
    assert tool.tool_call_schema is live_schema
def test_tool_call_schema_memo_not_stale_after_model_copy() -> None:
    """A copy made with `model_copy(update=...)` must not reuse the source's memo.

    `model_copy(update=...)` bypasses `__setattr__`; the memo must still clear.
    """
    source = _MemoSchemaTool()
    source_schema = source.tool_call_schema

    clone = source.model_copy(update={"description": "Copied description."})
    clone_schema = cast("type[BaseModel]", clone.tool_call_schema)
    clone_json = clone_schema.model_json_schema()
    assert clone_json["description"] == "Copied description."

    # The original keeps its memo and is unaffected by the copy.
    assert source.tool_call_schema is source_schema
def test_tool_call_schema_json_schema_cached() -> None:
    """`model_json_schema()` is cached on the memoized subset model class.

    Covers three behaviors: default-argument calls return the same dict
    object, explicit-argument calls bypass the cache, and a bypass call does
    not replace the cached dict.
    """
    tool = _MemoSchemaTool()
    schema_cls = cast("type[BaseModel]", tool.tool_call_schema)
    first = schema_cls.model_json_schema()
    second = schema_cls.model_json_schema()
    assert first is second  # same dict object, not a regenerate
    # Non-default arguments bypass the cache and delegate to pydantic.
    by_alias_off = schema_cls.model_json_schema(by_alias=False)
    assert by_alias_off is not first
    # The bypass call must leave the cached default-argument dict in place.
    assert schema_cls.model_json_schema() is first
def test_tool_call_schema_json_schema_cache_invalidated_on_reassignment() -> None:
    """Reassigning an input field yields a new class with its own schema cache."""
    tool = _MemoSchemaTool()
    stale_cls = cast("type[BaseModel]", tool.tool_call_schema)
    stale_json = stale_cls.model_json_schema()

    tool.description = "New description for cache test."

    fresh_cls = cast("type[BaseModel]", tool.tool_call_schema)
    assert fresh_cls is not stale_cls
    fresh_json = fresh_cls.model_json_schema()
    assert fresh_json is not stale_json
    assert fresh_json["description"] == "New description for cache test."