diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index da90992eff0..e7ad919ddbb 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -526,6 +526,19 @@ class ChildTool(BaseTool): arbitrary_types_allowed=True, ) + _SCHEMA_INVALIDATING_FIELDS: frozenset[str] = frozenset( + {"args_schema", "name", "description"} + ) + + def __setattr__(self, name: str, value: object) -> None: + if name in self._SCHEMA_INVALIDATING_FIELDS: + self.__dict__.pop("tool_call_schema", None) + self.__dict__.pop("args", None) + self.__dict__.pop("_inferred_input_schema", None) + # _approximate_schema_chars (added in PR 6) is also invalidated here + self.__dict__.pop("_approximate_schema_chars", None) + super().__setattr__(name, value) + @property def is_single_input(self) -> bool: """Check if the tool accepts only a single input argument. @@ -536,7 +549,7 @@ class ChildTool(BaseTool): keys = {k for k in self.args if k != "kwargs"} return len(keys) == 1 - @property + @functools.cached_property def args(self) -> dict: """Get the tool's input arguments schema. @@ -555,7 +568,7 @@ class ChildTool(BaseTool): json_schema = input_schema.model_json_schema() return cast("dict", json_schema["properties"]) - @property + @functools.cached_property def tool_call_schema(self) -> ArgsSchema: """Get the schema for tool calls, excluding injected arguments. @@ -601,6 +614,11 @@ class ChildTool(BaseTool): if isinstance(self.args_schema, dict): return super().get_input_schema(config) return self.args_schema + return self._inferred_input_schema + + @functools.cached_property + def _inferred_input_schema(self) -> type[BaseModel]: + """Schema inferred from `_run` signature; computed once.""" return create_schema_from_function(self.name, self._run) @override diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 4e224e1726f..af6a3ffdcc2 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -3780,6 +3780,87 @@ def test_parse_input_annotation_walk_called_once() -> None: assert len(calls_for_my_input) <= 1 +def test_tool_call_schema_is_cached() -> None: + """tool_call_schema must return the same object on repeated access.""" + from langchain_core.tools import tool + + @tool + def my_tool(x: int) -> int: + """A tool.""" + return x + + schema1 = my_tool.tool_call_schema + schema2 = my_tool.tool_call_schema + assert schema1 is schema2 + + +def test_args_is_cached() -> None: + """args must return the same object on repeated access.""" + from langchain_core.tools import tool + + @tool + def my_tool(x: int) -> int: + """A tool.""" + return x + + args1 = my_tool.args + args2 = my_tool.args + assert args1 is args2 + + +def test_tool_call_schema_invalidated_on_name_change() -> None: + """Cache must be invalidated when `name` is mutated.""" + from langchain_core.tools import tool + + @tool + def my_tool(x: int) -> int: + """A tool.""" + return x + + schema_before = my_tool.tool_call_schema + my_tool.name = "new_name" + schema_after = my_tool.tool_call_schema + assert schema_before is not schema_after + + +def test_tool_call_schema_invalidated_on_description_change() -> None: + """Cache must be invalidated when `description` is mutated.""" + from langchain_core.tools import tool + + @tool + def my_tool(x: int) -> int: + """A tool.""" + return x + + schema_before = my_tool.tool_call_schema + my_tool.description = "new description" + schema_after = my_tool.tool_call_schema + assert schema_before is not schema_after + + +def test_get_input_schema_cached() -> None: + """get_input_schema must not call create_schema_from_function more than once.""" + from unittest.mock import patch + + from langchain_core.tools import tool + from langchain_core.tools import base as base_module + + @tool + def my_tool(x: int) -> int: + """A tool.""" + return x + + with patch.object( + base_module, + "create_schema_from_function", + wraps=base_module.create_schema_from_function, + ) as mock_create: + my_tool.get_input_schema() + my_tool.get_input_schema() + my_tool.get_input_schema() + assert mock_create.call_count <= 1 + + def test_filter_injected_args_no_annotation_walk_on_run() -> None: """_filter_injected_args must not call get_all_basemodel_annotations on each run.""" from unittest.mock import patch