From 138727c008bec135ff8474f194278bf9c2cef09d Mon Sep 17 00:00:00 2001 From: Nick Hollon Date: Wed, 17 Jun 2026 17:17:14 -0400 Subject: [PATCH] perf(core): memoize `BaseTool.tool_call_schema` subset model and cache `model_json_schema` (#38073) --- libs/core/langchain_core/tools/base.py | 104 +++++++++++++++++- .../benchmarks/test_tool_schema_conversion.py | 69 ++++++++++++ libs/core/tests/unit_tests/test_tools.py | 101 +++++++++++++++++ 3 files changed, 271 insertions(+), 3 deletions(-) create mode 100644 libs/core/tests/benchmarks/test_tool_schema_conversion.py diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 4b26e1b39fd..2bf1f13187d 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -9,7 +9,7 @@ import logging import typing import warnings from abc import ABC, abstractmethod -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from inspect import signature from typing import ( TYPE_CHECKING, @@ -28,6 +28,7 @@ from pydantic import ( BaseModel, ConfigDict, Field, + PrivateAttr, PydanticDeprecationWarning, SkipValidation, ValidationError, @@ -37,7 +38,7 @@ from pydantic.fields import FieldInfo from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import ValidationError as ValidationErrorV1 from pydantic.v1 import validate_arguments as validate_arguments_v1 -from typing_extensions import override +from typing_extensions import Self, override from langchain_core.callbacks import ( AsyncCallbackManager, @@ -390,6 +391,39 @@ content is normalized to the content of a `ToolMessage` with `status="error"`. 
_EMPTY_SET: frozenset[str] = frozenset() +_TOOL_CALL_SCHEMA_FIELDS = frozenset({"name", "description", "args_schema"}) +"""Fields the memoized `tool_call_schema` is built from; reassignment clears it.""" + + +def _patch_json_schema_cache(model_cls: type) -> None: + """Patch `model_json_schema` (or `schema` for pydantic v1) to cache. + + Pydantic regenerates the full JSON-schema dict on every + `model_json_schema()` call — there is no per-class cache. When the + model class is stable (memoized on a `BaseTool` instance), this patch + caches the dict on the class so repeated calls return instantly. + + Only calls with all-default arguments are cached; any explicit arguments + bypass the cache and delegate to the original method. + """ + method_name = ( + "model_json_schema" if hasattr(model_cls, "model_json_schema") else "schema" + ) + orig = getattr(model_cls, method_name) + + def _cached_json_schema(cls: type, *args: Any, **kwargs: Any) -> dict[str, Any]: + if not args and not kwargs: + cached = cls.__dict__.get("_json_schema_cache") + if cached is not None: + return cast("dict[str, Any]", cached) + result = orig(*args, **kwargs) + if not args and not kwargs: + cls._json_schema_cache = result # type: ignore[attr-defined] + return cast("dict[str, Any]", result) + + setattr(model_cls, method_name, classmethod(_cached_json_schema)) + + class BaseTool(RunnableSerializable[str | dict[str, Any] | ToolCall, Any]): """Base class for all LangChain tools. @@ -581,12 +615,70 @@ class ChildTool(BaseTool): json_schema = model_json_schema(input_schema) return cast("dict[str, Any]", json_schema["properties"]) + _tool_call_schema_memo: ArgsSchema | None = PrivateAttr(default=None) + """Memoized `tool_call_schema` result. + + Building the subset model is expensive, and pydantic does not cache + `model_json_schema()` per class, so agent loops would otherwise pay full + schema generation for every tool on every model call. 
# Memo-maintenance methods of `BaseTool` (the class header lies outside this span).

@override
def __setattr__(self, name: str, value: Any) -> None:
    """Assign an attribute and drop the tool-call schema memo if it is now stale.

    The assignment is always delegated to the parent implementation first.
    The memo is then reset to `None` only when `name` is one of
    `_TOOL_CALL_SCHEMA_FIELDS` and `__pydantic_private__` is not `None`.
    """
    super().__setattr__(name, value)
    if name not in _TOOL_CALL_SCHEMA_FIELDS:
        return
    if self.__pydantic_private__ is None:
        return
    self._tool_call_schema_memo = None


@override
def model_copy(
    self, *, update: Mapping[str, Any] | None = None, deep: bool = False
) -> Self:
    """Return a copy of the tool whose schema memo is reset when `update` needs it.

    `model_copy` writes `update` directly to the copy's `__dict__` without
    going through `__setattr__`, and private attributes (including the
    memo) carry over to the copy. The copy's memo is therefore cleared here
    whenever `update` names one of the fields the schema is built from.

    Args:
        update: Field overrides forwarded to the parent `model_copy`.
        deep: Forwarded to the parent `model_copy`.

    Returns:
        The copied tool.
    """
    duplicate = super().model_copy(update=update, deep=deep)
    touches_schema_input = bool(update) and any(
        key in _TOOL_CALL_SCHEMA_FIELDS for key in update
    )
    if touches_schema_input:
        duplicate._tool_call_schema_memo = None  # noqa: SLF001
    return duplicate


def __getstate__(self) -> dict[Any, Any]:
    """Return the pickling state with the tool-call schema memo blanked out.

    The memoized subset model is a dynamically created class that cannot be
    pickled by reference; it is rebuilt lazily on next access. When a memo
    is present, both the state mapping and the private-attribute mapping
    are copied before the memo entry is set to `None`; otherwise the parent
    state is returned as is.
    """
    state = super().__getstate__()
    private = state.get("__pydantic_private__")
    if not private or private.get("_tool_call_schema_memo") is None:
        return state
    scrubbed_private = dict(private)
    scrubbed_private["_tool_call_schema_memo"] = None
    scrubbed_state = dict(state)
    scrubbed_state["__pydantic_private__"] = scrubbed_private
    return scrubbed_state
+ + The returned model class is memoized per tool instance (invalidated + when `name`, `description`, or `args_schema` is reassigned) so + repeated access does not regenerate the class. The class's + `model_json_schema` method is also patched to cache the generated + schema dict, since pydantic does not cache it per class. """ if isinstance(self.args_schema, dict): if self.description: @@ -597,14 +689,20 @@ class ChildTool(BaseTool): return self.args_schema + if (memo := self._tool_call_schema_memo) is not None: + return memo + full_schema = self.get_input_schema() fields = [] for name, type_ in get_all_basemodel_annotations(full_schema).items(): if not _is_injected_arg_type(type_): fields.append(name) - return _create_subset_model( + subset_model = _create_subset_model( self.name, full_schema, fields, fn_description=self.description ) + _patch_json_schema_cache(subset_model) + self._tool_call_schema_memo = subset_model + return subset_model @functools.cached_property def _injected_args_keys(self) -> frozenset[str]: diff --git a/libs/core/tests/benchmarks/test_tool_schema_conversion.py b/libs/core/tests/benchmarks/test_tool_schema_conversion.py new file mode 100644 index 00000000000..0657e5a637a --- /dev/null +++ b/libs/core/tests/benchmarks/test_tool_schema_conversion.py @@ -0,0 +1,69 @@ +"""Benchmarks for tool-to-OpenAI schema conversion. + +Agent loops convert every bound tool on every model call (token counting and +`bind_tools`), so conversion cost is paid per step. The warm benchmark measures +the steady state with the `tool_call_schema` memo populated; the cold benchmark +measures first-time conversion including subset-model creation. 
def _make_tools(num_tools: int) -> list[StructuredTool]:
    """Build `num_tools` structured tools, each with its own generated schema.

    Every schema has `_NUM_FIELDS` optional string parameters plus one
    required `target` field.
    """

    def _build_tool(index: int) -> StructuredTool:
        field_defs: dict[str, Any] = {}
        for j in range(_NUM_FIELDS):
            field_defs[f"param_{j}"] = (
                str | None,
                Field(default=None, description=f"Parameter {j} of action {index}."),
            )
        field_defs["target"] = (
            str,
            Field(description="Primary target identifier."),
        )
        schema: type[BaseModel] = create_model(
            f"BenchTool{index}Input", **field_defs
        )
        return StructuredTool.from_function(
            func=lambda **_kwargs: "ok",
            name=f"bench_tool_{index}",
            description=f"Benchmark tool {index} with several configurable options.",
            args_schema=schema,
        )

    return [_build_tool(i) for i in range(num_tools)]


@pytest.mark.benchmark
def test_convert_tools_warm(benchmark: BenchmarkFixture) -> None:
    """Steady-state conversion of a reused toolset (memoized schema path)."""
    toolset = _make_tools(_NUM_TOOLS)
    # Convert once up front so the schema memo is populated before timing.
    for warm_tool in toolset:
        convert_to_openai_tool(warm_tool)

    @benchmark  # type: ignore[untyped-decorator]
    def convert_warm() -> None:
        for tool in toolset:
            convert_to_openai_tool(tool)


@pytest.mark.benchmark
def test_convert_tools_cold(benchmark: BenchmarkFixture) -> None:
    """First-time conversion, including subset-model creation per tool."""

    @benchmark  # type: ignore[untyped-decorator]
    def convert_cold() -> None:
        # Tools are rebuilt inside the timed function so each is converted cold.
        fresh_tools = _make_tools(_NUM_TOOLS)
        for tool in fresh_tools:
            convert_to_openai_tool(tool)
# Args schema with a single `query` field, shared by the memoization tests.
class _MemoSchemaInput(BaseModel):
    query: str = Field(description="Query to run.")


class _MemoSchemaTool(BaseTool):
    """Module-level tool so pickling by class reference works."""

    name: str = "memo_schema_tool"
    description: str = "Tool for tool_call_schema memoization tests."
    args_schema: type[BaseModel] = _MemoSchemaInput

    def _run(self, query: str) -> str:
        return query


def test_tool_call_schema_memoized_across_accesses() -> None:
    """Two consecutive reads of `tool_call_schema` return the same object."""
    tool = _MemoSchemaTool()
    initial = tool.tool_call_schema
    repeat = tool.tool_call_schema
    assert repeat is initial


def test_tool_call_schema_memo_invalidated_on_reassignment() -> None:
    """Reassigning `description` or `args_schema` yields a fresh schema class."""
    tool = _MemoSchemaTool()
    before = tool.tool_call_schema

    tool.description = "Updated description."
    after_description = tool.tool_call_schema
    assert after_description is not before
    description_json = cast(
        "type[BaseModel]", after_description
    ).model_json_schema()
    assert description_json["description"] == "Updated description."

    class OtherInput(BaseModel):
        other_field: str = Field(description="A different field.")

    tool.args_schema = OtherInput
    after_args_schema = tool.tool_call_schema
    assert after_args_schema is not after_description
    args_json = cast("type[BaseModel]", after_args_schema).model_json_schema()
    assert "other_field" in args_json["properties"]


def test_tool_picklable_after_tool_call_schema_access() -> None:
    """The memoized schema is a dynamic class and must not leak into pickles."""
    tool = _MemoSchemaTool()
    # Round-trip once before `tool_call_schema` has ever been read.
    pickle.loads(pickle.dumps(tool))

    live_schema = tool.tool_call_schema
    payload = pickle.dumps(tool)
    restored = pickle.loads(payload)
    restored_json = cast(
        "type[BaseModel]", restored.tool_call_schema
    ).model_json_schema()
    assert restored_json["title"] == "memo_schema_tool"
    # Pickling must not clear the live instance's memo.
    assert tool.tool_call_schema is live_schema


def test_tool_call_schema_memo_not_stale_after_model_copy() -> None:
    """`model_copy(update=...)` bypasses `__setattr__`; the memo must still clear."""
    tool = _MemoSchemaTool()
    source_schema = tool.tool_call_schema

    overrides = {"description": "Copied description."}
    clone = tool.model_copy(update=overrides)
    clone_json = cast("type[BaseModel]", clone.tool_call_schema).model_json_schema()
    assert clone_json["description"] == "Copied description."

    # The original keeps its memo and is unaffected by the copy.
    assert tool.tool_call_schema is source_schema


def test_tool_call_schema_json_schema_cached() -> None:
    """`model_json_schema()` is cached on the memoized subset model class."""
    tool = _MemoSchemaTool()
    subset_cls = cast("type[BaseModel]", tool.tool_call_schema)

    default_schema = subset_cls.model_json_schema()
    # Same dict object on the second call, not a regenerated one.
    assert subset_cls.model_json_schema() is default_schema

    # Non-default arguments bypass the cache and delegate to pydantic.
    explicit_schema = subset_cls.model_json_schema(by_alias=False)
    assert explicit_schema is not default_schema


def test_tool_call_schema_json_schema_cache_invalidated_on_reassignment() -> None:
    """Reassigning an input field creates a fresh class with a fresh cache."""
    tool = _MemoSchemaTool()
    stale_cls = cast("type[BaseModel]", tool.tool_call_schema)
    stale_json = stale_cls.model_json_schema()

    tool.description = "New description for cache test."
    fresh_cls = cast("type[BaseModel]", tool.tool_call_schema)
    assert fresh_cls is not stale_cls

    fresh_json = fresh_cls.model_json_schema()
    assert fresh_json is not stale_json
    assert fresh_json["description"] == "New description for cache test."