mirror of
https://github.com/hwchase17/langchain.git
synced 2026-05-15 03:25:21 +00:00
perf(core): cache tool_call_schema, args, and inferred input schema with invalidation
This commit is contained in:
@@ -526,6 +526,19 @@ class ChildTool(BaseTool):
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
_SCHEMA_INVALIDATING_FIELDS: frozenset[str] = frozenset(
|
||||
{"args_schema", "name", "description"}
|
||||
)
|
||||
|
||||
def __setattr__(self, name: str, value: object) -> None:
|
||||
if name in self._SCHEMA_INVALIDATING_FIELDS:
|
||||
self.__dict__.pop("tool_call_schema", None)
|
||||
self.__dict__.pop("args", None)
|
||||
self.__dict__.pop("_inferred_input_schema", None)
|
||||
# _approximate_schema_chars (added in PR 6) is also invalidated here
|
||||
self.__dict__.pop("_approximate_schema_chars", None)
|
||||
super().__setattr__(name, value)
|
||||
|
||||
@property
|
||||
def is_single_input(self) -> bool:
|
||||
"""Check if the tool accepts only a single input argument.
|
||||
@@ -536,7 +549,7 @@ class ChildTool(BaseTool):
|
||||
keys = {k for k in self.args if k != "kwargs"}
|
||||
return len(keys) == 1
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def args(self) -> dict:
|
||||
"""Get the tool's input arguments schema.
|
||||
|
||||
@@ -555,7 +568,7 @@ class ChildTool(BaseTool):
|
||||
json_schema = input_schema.model_json_schema()
|
||||
return cast("dict", json_schema["properties"])
|
||||
|
||||
@property
|
||||
@functools.cached_property
|
||||
def tool_call_schema(self) -> ArgsSchema:
|
||||
"""Get the schema for tool calls, excluding injected arguments.
|
||||
|
||||
@@ -601,6 +614,11 @@ class ChildTool(BaseTool):
|
||||
if isinstance(self.args_schema, dict):
|
||||
return super().get_input_schema(config)
|
||||
return self.args_schema
|
||||
return self._inferred_input_schema
|
||||
|
||||
@functools.cached_property
|
||||
def _inferred_input_schema(self) -> type[BaseModel]:
|
||||
"""Schema inferred from `_run` signature; computed once."""
|
||||
return create_schema_from_function(self.name, self._run)
|
||||
|
||||
@override
|
||||
|
||||
@@ -3780,6 +3780,87 @@ def test_parse_input_annotation_walk_called_once() -> None:
|
||||
assert len(calls_for_my_input) <= 1
|
||||
|
||||
|
||||
def test_tool_call_schema_is_cached() -> None:
|
||||
"""tool_call_schema must return the same object on repeated access."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
def my_tool(x: int) -> int:
|
||||
"""A tool."""
|
||||
return x
|
||||
|
||||
schema1 = my_tool.tool_call_schema
|
||||
schema2 = my_tool.tool_call_schema
|
||||
assert schema1 is schema2
|
||||
|
||||
|
||||
def test_args_is_cached() -> None:
|
||||
"""args must return the same object on repeated access."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
def my_tool(x: int) -> int:
|
||||
"""A tool."""
|
||||
return x
|
||||
|
||||
args1 = my_tool.args
|
||||
args2 = my_tool.args
|
||||
assert args1 is args2
|
||||
|
||||
|
||||
def test_tool_call_schema_invalidated_on_name_change() -> None:
|
||||
"""Cache must be invalidated when `name` is mutated."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
def my_tool(x: int) -> int:
|
||||
"""A tool."""
|
||||
return x
|
||||
|
||||
schema_before = my_tool.tool_call_schema
|
||||
my_tool.name = "new_name"
|
||||
schema_after = my_tool.tool_call_schema
|
||||
assert schema_before is not schema_after
|
||||
|
||||
|
||||
def test_tool_call_schema_invalidated_on_description_change() -> None:
|
||||
"""Cache must be invalidated when `description` is mutated."""
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
def my_tool(x: int) -> int:
|
||||
"""A tool."""
|
||||
return x
|
||||
|
||||
schema_before = my_tool.tool_call_schema
|
||||
my_tool.description = "new description"
|
||||
schema_after = my_tool.tool_call_schema
|
||||
assert schema_before is not schema_after
|
||||
|
||||
|
||||
def test_get_input_schema_cached() -> None:
|
||||
"""get_input_schema must not call create_schema_from_function more than once."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.tools import base as base_module
|
||||
|
||||
@tool
|
||||
def my_tool(x: int) -> int:
|
||||
"""A tool."""
|
||||
return x
|
||||
|
||||
with patch.object(
|
||||
base_module,
|
||||
"create_schema_from_function",
|
||||
wraps=base_module.create_schema_from_function,
|
||||
) as mock_create:
|
||||
my_tool.get_input_schema()
|
||||
my_tool.get_input_schema()
|
||||
my_tool.get_input_schema()
|
||||
assert mock_create.call_count <= 1
|
||||
|
||||
|
||||
def test_filter_injected_args_no_annotation_walk_on_run() -> None:
|
||||
"""_filter_injected_args must not call get_all_basemodel_annotations on each run."""
|
||||
from unittest.mock import patch
|
||||
|
||||
Reference in New Issue
Block a user