mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 14:47:02 +00:00
perf(core): memoize BaseTool.tool_call_schema subset model and cache model_json_schema (#38073)
This commit is contained in:
@@ -9,7 +9,7 @@ import logging
|
||||
import typing
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -28,6 +28,7 @@ from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
PydanticDeprecationWarning,
|
||||
SkipValidation,
|
||||
ValidationError,
|
||||
@@ -37,7 +38,7 @@ from pydantic.fields import FieldInfo
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1 import ValidationError as ValidationErrorV1
|
||||
from pydantic.v1 import validate_arguments as validate_arguments_v1
|
||||
from typing_extensions import override
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
@@ -390,6 +391,39 @@ content is normalized to the content of a `ToolMessage` with `status="error"`.
|
||||
_EMPTY_SET: frozenset[str] = frozenset()
|
||||
|
||||
|
||||
_TOOL_CALL_SCHEMA_FIELDS = frozenset({"name", "description", "args_schema"})
|
||||
"""Fields the memoized `tool_call_schema` is built from; reassignment clears it."""
|
||||
|
||||
|
||||
def _patch_json_schema_cache(model_cls: type) -> None:
|
||||
"""Patch `model_json_schema` (or `schema` for pydantic v1) to cache.
|
||||
|
||||
Pydantic regenerates the full JSON-schema dict on every
|
||||
`model_json_schema()` call — there is no per-class cache. When the
|
||||
model class is stable (memoized on a `BaseTool` instance), this patch
|
||||
caches the dict on the class so repeated calls return instantly.
|
||||
|
||||
Only calls with all-default arguments are cached; any explicit arguments
|
||||
bypass the cache and delegate to the original method.
|
||||
"""
|
||||
method_name = (
|
||||
"model_json_schema" if hasattr(model_cls, "model_json_schema") else "schema"
|
||||
)
|
||||
orig = getattr(model_cls, method_name)
|
||||
|
||||
def _cached_json_schema(cls: type, *args: Any, **kwargs: Any) -> dict[str, Any]:
|
||||
if not args and not kwargs:
|
||||
cached = cls.__dict__.get("_json_schema_cache")
|
||||
if cached is not None:
|
||||
return cast("dict[str, Any]", cached)
|
||||
result = orig(*args, **kwargs)
|
||||
if not args and not kwargs:
|
||||
cls._json_schema_cache = result # type: ignore[attr-defined]
|
||||
return cast("dict[str, Any]", result)
|
||||
|
||||
setattr(model_cls, method_name, classmethod(_cached_json_schema))
|
||||
|
||||
|
||||
class BaseTool(RunnableSerializable[str | dict[str, Any] | ToolCall, Any]):
|
||||
"""Base class for all LangChain tools.
|
||||
|
||||
@@ -581,12 +615,70 @@ class ChildTool(BaseTool):
|
||||
json_schema = model_json_schema(input_schema)
|
||||
return cast("dict[str, Any]", json_schema["properties"])
|
||||
|
||||
_tool_call_schema_memo: ArgsSchema | None = PrivateAttr(default=None)
|
||||
"""Memoized `tool_call_schema` result.
|
||||
|
||||
Building the subset model is expensive, and pydantic does not cache
|
||||
`model_json_schema()` per class, so agent loops would otherwise pay full
|
||||
schema generation for every tool on every model call. The subset model
|
||||
class is memoized here and its `model_json_schema`/`schema` method is
|
||||
patched to cache the generated dict, so both costs are paid only once per
|
||||
tool instance.
|
||||
Cleared whenever `name`, `description`, or `args_schema` is reassigned (see
|
||||
`__setattr__` and `model_copy`).
|
||||
"""
|
||||
|
||||
@override
def __setattr__(self, name: str, value: Any) -> None:
    """Assign the attribute, then drop the tool-call schema memo if stale.

    The memo is reset only when `name` is one of `_TOOL_CALL_SCHEMA_FIELDS`
    and `__pydantic_private__` is not `None`.
    """
    super().__setattr__(name, value)
    if name not in _TOOL_CALL_SCHEMA_FIELDS:
        return
    if self.__pydantic_private__ is None:
        # Private attribute storage is unset, so there is no memo to reset.
        return
    self._tool_call_schema_memo = None
|
||||
|
||||
@override
def model_copy(
    self, *, update: Mapping[str, Any] | None = None, deep: bool = False
) -> Self:
    """Return a copy of the tool whose schema memo is consistent with `update`.

    `model_copy` applies `update` to the copy's `__dict__` directly, so
    `__setattr__` never sees those assignments, while private attributes
    (the memo among them) are carried over. The memo on the copy is
    therefore reset here whenever `update` names one of the fields the
    tool-call schema is built from.
    """
    clone = super().model_copy(update=update, deep=deep)
    if not update:
        return clone
    if _TOOL_CALL_SCHEMA_FIELDS.isdisjoint(update):
        # None of the updated keys feed the schema; the carried memo is valid.
        return clone
    clone._tool_call_schema_memo = None  # noqa: SLF001
    return clone
|
||||
|
||||
def __getstate__(self) -> dict[Any, Any]:
    """Return the pickling state with the tool-call schema memo blanked.

    The memo holds a dynamically created class, which cannot be pickled by
    reference. It is replaced by `None` in a copy of the state, so the live
    instance keeps its memo and the restored one rebuilds it lazily.
    """
    state = super().__getstate__()
    private_attrs = state.get("__pydantic_private__")
    if not private_attrs or private_attrs.get("_tool_call_schema_memo") is None:
        # Nothing memoized: the parent state can be used unchanged.
        return state
    scrubbed = dict(private_attrs)
    scrubbed["_tool_call_schema_memo"] = None
    return {**state, "__pydantic_private__": scrubbed}
|
||||
|
||||
@property
|
||||
def tool_call_schema(self) -> ArgsSchema:
|
||||
"""Get the schema for tool calls, excluding injected arguments.
|
||||
|
||||
Returns:
|
||||
The schema that should be used for tool calls from language models.
|
||||
|
||||
The returned model class is memoized per tool instance (invalidated
|
||||
when `name`, `description`, or `args_schema` is reassigned) so
|
||||
repeated access does not regenerate the class. The class's
|
||||
`model_json_schema` method is also patched to cache the generated
|
||||
schema dict, since pydantic does not cache it per class.
|
||||
"""
|
||||
if isinstance(self.args_schema, dict):
|
||||
if self.description:
|
||||
@@ -597,14 +689,20 @@ class ChildTool(BaseTool):
|
||||
|
||||
return self.args_schema
|
||||
|
||||
if (memo := self._tool_call_schema_memo) is not None:
|
||||
return memo
|
||||
|
||||
full_schema = self.get_input_schema()
|
||||
fields = []
|
||||
for name, type_ in get_all_basemodel_annotations(full_schema).items():
|
||||
if not _is_injected_arg_type(type_):
|
||||
fields.append(name)
|
||||
return _create_subset_model(
|
||||
subset_model = _create_subset_model(
|
||||
self.name, full_schema, fields, fn_description=self.description
|
||||
)
|
||||
_patch_json_schema_cache(subset_model)
|
||||
self._tool_call_schema_memo = subset_model
|
||||
return subset_model
|
||||
|
||||
@functools.cached_property
|
||||
def _injected_args_keys(self) -> frozenset[str]:
|
||||
|
||||
69
libs/core/tests/benchmarks/test_tool_schema_conversion.py
Normal file
69
libs/core/tests/benchmarks/test_tool_schema_conversion.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Benchmarks for tool-to-OpenAI schema conversion.
|
||||
|
||||
Agent loops convert every bound tool on every model call (token counting and
|
||||
`bind_tools`), so conversion cost is paid per step. The warm benchmark measures
|
||||
the steady state with the `tool_call_schema` memo populated; the cold benchmark
|
||||
measures first-time conversion including subset-model creation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pytest_benchmark.fixture import BenchmarkFixture
|
||||
|
||||
_NUM_TOOLS = 20
|
||||
_NUM_FIELDS = 8
|
||||
|
||||
|
||||
def _make_tools(num_tools: int) -> list[StructuredTool]:
    """Build `num_tools` structured tools, each with its own generated schema.

    Every schema has `_NUM_FIELDS` optional string parameters followed by one
    required `target` string.
    """

    def _build(i: int) -> StructuredTool:
        fields: dict[str, Any] = {
            **{
                f"param_{j}": (
                    str | None,
                    Field(default=None, description=f"Parameter {j} of action {i}."),
                )
                for j in range(_NUM_FIELDS)
            },
            "target": (str, Field(description="Primary target identifier.")),
        }
        schema: type[BaseModel] = create_model(f"BenchTool{i}Input", **fields)
        return StructuredTool.from_function(
            func=lambda **_kwargs: "ok",
            name=f"bench_tool_{i}",
            description=f"Benchmark tool {i} with several configurable options.",
            args_schema=schema,
        )

    return [_build(i) for i in range(num_tools)]
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_convert_tools_warm(benchmark: BenchmarkFixture) -> None:
    """Benchmark converting an already-converted toolset (memo populated)."""
    warmed_tools = _make_tools(_NUM_TOOLS)
    # One untimed pass so every tool has its schema memo populated.
    for warm_tool in warmed_tools:
        convert_to_openai_tool(warm_tool)

    @benchmark  # type: ignore[untyped-decorator]
    def convert_warm() -> None:
        for tool in warmed_tools:
            convert_to_openai_tool(tool)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_convert_tools_cold(benchmark: BenchmarkFixture) -> None:
    """Benchmark first-time conversion, with tools rebuilt inside each run."""

    @benchmark  # type: ignore[untyped-decorator]
    def convert_cold() -> None:
        # Fresh tools on every run, so no schema memo is populated yet.
        fresh_tools = _make_tools(_NUM_TOOLS)
        for tool in fresh_tools:
            convert_to_openai_tool(tool)
|
||||
@@ -3,6 +3,7 @@
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import sys
|
||||
import textwrap
|
||||
import threading
|
||||
@@ -3932,3 +3933,103 @@ def test_tool_invoke_returns_list_of_mixin() -> None:
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(m, ToolMessage) for m in result)
|
||||
|
||||
|
||||
# Single-field args schema used by `_MemoSchemaTool`. Documented with a comment
# rather than a docstring so the model's own description stays unset.
class _MemoSchemaInput(BaseModel):
    query: str = Field(description="Query to run.")
|
||||
|
||||
|
||||
class _MemoSchemaTool(BaseTool):
    """Module-level tool so pickling by class reference works."""

    # Fixed values the memoization tests assert against (title, description).
    name: str = "memo_schema_tool"
    description: str = "Tool for tool_call_schema memoization tests."
    args_schema: type[BaseModel] = _MemoSchemaInput

    def _run(self, query: str) -> str:
        # Echo the input back unchanged.
        return query
|
||||
|
||||
|
||||
def test_tool_call_schema_memoized_across_accesses() -> None:
    """Two consecutive reads of `tool_call_schema` return the same object."""
    tool = _MemoSchemaTool()
    initial = tool.tool_call_schema
    repeated = tool.tool_call_schema
    assert repeated is initial
|
||||
|
||||
|
||||
def test_tool_call_schema_memo_invalidated_on_reassignment() -> None:
    """Reassigning `description` or `args_schema` yields a freshly built schema."""
    tool = _MemoSchemaTool()
    before = tool.tool_call_schema

    # A new description must produce a new class carrying the new text.
    tool.description = "Updated description."
    after_description = tool.tool_call_schema
    assert after_description is not before
    described = cast("type[BaseModel]", after_description)
    assert described.model_json_schema()["description"] == "Updated description."

    class OtherInput(BaseModel):
        other_field: str = Field(description="A different field.")

    # A new args schema must produce a class exposing the new field.
    tool.args_schema = OtherInput
    after_args = tool.tool_call_schema
    assert after_args is not after_description
    properties = cast("type[BaseModel]", after_args).model_json_schema()["properties"]
    assert "other_field" in properties
|
||||
|
||||
|
||||
def test_tool_picklable_after_tool_call_schema_access() -> None:
    """The memoized schema is a dynamic class and must not leak into pickles."""
    tool = _MemoSchemaTool()
    # Round-trip once before the schema has been accessed.
    pickle.loads(pickle.dumps(tool))

    live_schema = tool.tool_call_schema
    payload = pickle.dumps(tool)
    restored = pickle.loads(payload)
    restored_schema = cast("type[BaseModel]", restored.tool_call_schema)
    assert restored_schema.model_json_schema()["title"] == "memo_schema_tool"
    # Pickling must not clear the live instance's memo.
    assert tool.tool_call_schema is live_schema
|
||||
|
||||
|
||||
def test_tool_call_schema_memo_not_stale_after_model_copy() -> None:
    """`model_copy(update=...)` bypasses `__setattr__`; the memo must still clear."""
    tool = _MemoSchemaTool()
    source_schema = tool.tool_call_schema

    clone = tool.model_copy(update={"description": "Copied description."})
    clone_schema = cast("type[BaseModel]", clone.tool_call_schema)
    clone_json = clone_schema.model_json_schema()
    assert clone_json["description"] == "Copied description."

    # The original keeps its memo and is unaffected by the copy.
    assert tool.tool_call_schema is source_schema
|
||||
|
||||
|
||||
def test_tool_call_schema_json_schema_cached() -> None:
    """`model_json_schema()` is cached on the memoized subset model class."""
    tool = _MemoSchemaTool()
    subset_cls = cast("type[BaseModel]", tool.tool_call_schema)

    generated = subset_cls.model_json_schema()
    again = subset_cls.model_json_schema()
    assert generated is again  # same dict object, not a regenerate

    # Non-default arguments bypass the cache and delegate to pydantic.
    uncached = subset_cls.model_json_schema(by_alias=False)
    assert uncached is not generated
|
||||
|
||||
|
||||
def test_tool_call_schema_json_schema_cache_invalidated_on_reassignment() -> None:
    """Reassigning an input field creates a fresh class with a fresh cache."""
    tool = _MemoSchemaTool()
    stale_cls = cast("type[BaseModel]", tool.tool_call_schema)
    stale_json = stale_cls.model_json_schema()

    tool.description = "New description for cache test."
    fresh_cls = cast("type[BaseModel]", tool.tool_call_schema)
    assert fresh_cls is not stale_cls

    # The new class must generate its own dict rather than reuse the old one.
    fresh_json = fresh_cls.model_json_schema()
    assert fresh_json is not stale_json
    assert fresh_json["description"] == "New description for cache test."
|
||||
|
||||
Reference in New Issue
Block a user