perf(core): cache tool_call_schema, args, and inferred input schema with invalidation

This commit is contained in:
Sydney Runkle
2026-04-30 10:36:53 -04:00
parent ffa515cadf
commit 9042ab9f4e
2 changed files with 101 additions and 2 deletions

View File

@@ -526,6 +526,19 @@ class ChildTool(BaseTool):
arbitrary_types_allowed=True,
)
_SCHEMA_INVALIDATING_FIELDS: frozenset[str] = frozenset(
{"args_schema", "name", "description"}
)
def __setattr__(self, name: str, value: object) -> None:
if name in self._SCHEMA_INVALIDATING_FIELDS:
self.__dict__.pop("tool_call_schema", None)
self.__dict__.pop("args", None)
self.__dict__.pop("_inferred_input_schema", None)
# _approximate_schema_chars (added in PR 6) is also invalidated here
self.__dict__.pop("_approximate_schema_chars", None)
super().__setattr__(name, value)
@property
def is_single_input(self) -> bool:
"""Check if the tool accepts only a single input argument.
@@ -536,7 +549,7 @@ class ChildTool(BaseTool):
keys = {k for k in self.args if k != "kwargs"}
return len(keys) == 1
@property
@functools.cached_property
def args(self) -> dict:
"""Get the tool's input arguments schema.
@@ -555,7 +568,7 @@ class ChildTool(BaseTool):
json_schema = input_schema.model_json_schema()
return cast("dict", json_schema["properties"])
@property
@functools.cached_property
def tool_call_schema(self) -> ArgsSchema:
"""Get the schema for tool calls, excluding injected arguments.
@@ -601,6 +614,11 @@ class ChildTool(BaseTool):
if isinstance(self.args_schema, dict):
return super().get_input_schema(config)
return self.args_schema
return self._inferred_input_schema
@functools.cached_property
def _inferred_input_schema(self) -> type[BaseModel]:
"""Schema inferred from `_run` signature; computed once."""
return create_schema_from_function(self.name, self._run)
@override

View File

@@ -3780,6 +3780,87 @@ def test_parse_input_annotation_walk_called_once() -> None:
assert len(calls_for_my_input) <= 1
def test_tool_call_schema_is_cached() -> None:
"""tool_call_schema must return the same object on repeated access."""
from langchain_core.tools import tool
@tool
def my_tool(x: int) -> int:
"""A tool."""
return x
schema1 = my_tool.tool_call_schema
schema2 = my_tool.tool_call_schema
assert schema1 is schema2
def test_args_is_cached() -> None:
"""args must return the same object on repeated access."""
from langchain_core.tools import tool
@tool
def my_tool(x: int) -> int:
"""A tool."""
return x
args1 = my_tool.args
args2 = my_tool.args
assert args1 is args2
def test_tool_call_schema_invalidated_on_name_change() -> None:
"""Cache must be invalidated when `name` is mutated."""
from langchain_core.tools import tool
@tool
def my_tool(x: int) -> int:
"""A tool."""
return x
schema_before = my_tool.tool_call_schema
my_tool.name = "new_name"
schema_after = my_tool.tool_call_schema
assert schema_before is not schema_after
def test_tool_call_schema_invalidated_on_description_change() -> None:
"""Cache must be invalidated when `description` is mutated."""
from langchain_core.tools import tool
@tool
def my_tool(x: int) -> int:
"""A tool."""
return x
schema_before = my_tool.tool_call_schema
my_tool.description = "new description"
schema_after = my_tool.tool_call_schema
assert schema_before is not schema_after
def test_get_input_schema_cached() -> None:
"""get_input_schema must not call create_schema_from_function more than once."""
from unittest.mock import patch
from langchain_core.tools import tool
from langchain_core.tools import base as base_module
@tool
def my_tool(x: int) -> int:
"""A tool."""
return x
with patch.object(
base_module,
"create_schema_from_function",
wraps=base_module.create_schema_from_function,
) as mock_create:
my_tool.get_input_schema()
my_tool.get_input_schema()
my_tool.get_input_schema()
assert mock_create.call_count <= 1
def test_filter_injected_args_no_annotation_walk_on_run() -> None:
"""_filter_injected_args must not call get_all_basemodel_annotations on each run."""
from unittest.mock import patch