From 76dd656f2a15935c8dbfcc93acbbcd61a91b0661 Mon Sep 17 00:00:00 2001
From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Date: Wed, 29 Oct 2025 15:20:53 -0700
Subject: [PATCH] fix: filter out injected args from tracing (#33729)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
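This change is CC-generated; I still want to do a thorough review and update
the tests, but it should be able to ship today.
Before (eek): injected args were passed through to the tool-start callbacks.
Now (woo): injected args are filtered out before the callbacks see them.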
---
libs/core/langchain_core/tools/base.py | 78 +++++--
libs/core/tests/unit_tests/test_tools.py | 253 +++++++++++++++++++++++
2 files changed, 319 insertions(+), 12 deletions(-)
diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py
index def3b227089..588d71efdc8 100644
--- a/libs/core/langchain_core/tools/base.py
+++ b/libs/core/langchain_core/tools/base.py
@@ -707,6 +707,35 @@ class ChildTool(BaseTool):
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
return await run_in_executor(None, self._run, *args, **kwargs)
+    def _filter_injected_args(self, tool_input: dict) -> dict:
+        """Return a copy of `tool_input` with injected arguments removed.
+
+        Injected arguments are those annotated with InjectedToolArg or its
+        subclasses, or arguments in FILTERED_ARGS like run_manager and callbacks.
+
+        Args:
+            tool_input: The tool input dictionary to filter.
+
+        Returns:
+            A new dictionary; `tool_input` itself is left unmodified.
+        """
+        filtered_keys: set[str] = set(FILTERED_ARGS)
+        if self.args_schema is not None:
+            try:
+                annotations = get_all_basemodel_annotations(self.args_schema)
+                for field_name, field_type in annotations.items():
+                    if _is_injected_arg_type(field_type):
+                        filtered_keys.add(field_name)
+            except Exception:
+                # Best effort: keep the FILTERED_ARGS fallback, but leave a trace.
+                import logging
+
+                logging.getLogger(__name__).debug(
+                    "Could not inspect args_schema for injected args", exc_info=True
+                )
+
+        return {k: v for k, v in tool_input.items() if k not in filtered_keys}
+
def _to_args_and_kwargs(
self, tool_input: str | dict, tool_call_id: str | None
) -> tuple[tuple, dict]:
@@ -794,17 +823,29 @@ class ChildTool(BaseTool):
self.metadata,
)
+ # Filter out injected arguments from callback inputs
+ filtered_tool_input = (
+ self._filter_injected_args(tool_input)
+ if isinstance(tool_input, dict)
+ else None
+ )
+
+ # Use filtered inputs for the input_str parameter as well
+ tool_input_str = (
+ tool_input
+ if isinstance(tool_input, str)
+ else str(
+ filtered_tool_input if filtered_tool_input is not None else tool_input
+ )
+ )
+
run_manager = callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
- tool_input if isinstance(tool_input, str) else str(tool_input),
+ tool_input_str,
color=start_color,
name=run_name,
run_id=run_id,
- # Inputs by definition should always be dicts.
- # For now, it's unclear whether this assumption is ever violated,
- # but if it is we will send a `None` value to the callback instead
- # TODO: will need to address issue via a patch.
- inputs=tool_input if isinstance(tool_input, dict) else None,
+ inputs=filtered_tool_input,
**kwargs,
)
@@ -905,17 +946,30 @@ class ChildTool(BaseTool):
metadata,
self.metadata,
)
+
+ # Filter out injected arguments from callback inputs
+ filtered_tool_input = (
+ self._filter_injected_args(tool_input)
+ if isinstance(tool_input, dict)
+ else None
+ )
+
+ # Use filtered inputs for the input_str parameter as well
+ tool_input_str = (
+ tool_input
+ if isinstance(tool_input, str)
+ else str(
+ filtered_tool_input if filtered_tool_input is not None else tool_input
+ )
+ )
+
run_manager = await callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
- tool_input if isinstance(tool_input, str) else str(tool_input),
+ tool_input_str,
color=start_color,
name=run_name,
run_id=run_id,
- # Inputs by definition should always be dicts.
- # For now, it's unclear whether this assumption is ever violated,
- # but if it is we will send a `None` value to the callback instead
- # TODO: will need to address issue via a patch.
- inputs=tool_input if isinstance(tool_input, dict) else None,
+ inputs=filtered_tool_input,
**kwargs,
)
content = None
diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py
index 9fdd2afc819..b6e6020e841 100644
--- a/libs/core/tests/unit_tests/test_tools.py
+++ b/libs/core/tests/unit_tests/test_tools.py
@@ -70,6 +70,13 @@ from langchain_core.utils.pydantic import (
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
+try:
+ from langgraph.prebuilt import ToolRuntime # type: ignore[import-not-found]
+
+ HAS_LANGGRAPH = True
+except ImportError:
+ HAS_LANGGRAPH = False
+
def _get_tool_call_json_schema(tool: BaseTool) -> dict:
tool_schema = tool.tool_call_schema
@@ -2773,3 +2780,249 @@ def test_tool_args_schema_with_annotated_type() -> None:
"type": "array",
}
}
+
+
+class CallbackHandlerWithInputCapture(FakeCallbackHandler):
+ """Fake handler that records the `inputs` passed to each on_tool_start call."""
+
+ captured_inputs: list[dict | None] = []  # one entry per on_tool_start call
+
+ def on_tool_start(
+ self,
+ serialized: dict[str, Any],
+ input_str: str,
+ *,
+ run_id: Any,
+ parent_run_id: Any | None = None,
+ tags: list[str] | None = None,
+ metadata: dict[str, Any] | None = None,
+ inputs: dict[str, Any] | None = None,
+ **kwargs: Any,
+ ) -> Any:
+ """Record `inputs`, then forward every argument to the parent handler."""
+ self.captured_inputs.append(inputs)
+ return super().on_tool_start(
+ serialized,
+ input_str,
+ run_id=run_id,
+ parent_run_id=parent_run_id,
+ tags=tags,
+ metadata=metadata,
+ inputs=inputs,
+ **kwargs,
+ )
+
+
+def test_filter_injected_args_from_callbacks() -> None:
+ """Test that an InjectedToolArg parameter is dropped from callback `inputs`."""
+
+ @tool
+ def search_tool(
+ query: str,
+ state: Annotated[dict, InjectedToolArg()],
+ ) -> str:
+ """Search with injected state.
+
+ Args:
+ query: The search query.
+ state: Injected state context.
+ """
+ return f"Results for: {query}"
+
+ handler = CallbackHandlerWithInputCapture(captured_inputs=[])  # empty capture
+ result = search_tool.invoke(
+ {"query": "test query", "state": {"user_id": 123}},
+ config={"callbacks": [handler]},
+ )
+
+ assert result == "Results for: test query"
+ assert handler.tool_starts == 1
+ assert len(handler.captured_inputs) == 1
+
+ # The callback must see `query` unchanged and must not see the injected `state`
+ captured = handler.captured_inputs[0]
+ assert captured is not None
+ assert "query" in captured
+ assert "state" not in captured
+ assert captured["query"] == "test query"
+
+
+def test_filter_run_manager_from_callbacks() -> None:
+ """Test that run_manager does not appear in the callback `inputs`."""
+
+ @tool
+ def tool_with_run_manager(
+ message: str,
+ run_manager: CallbackManagerForToolRun | None = None,
+ ) -> str:
+ """Tool with run_manager parameter.
+
+ Args:
+ message: The message to process.
+ run_manager: The callback manager.
+ """
+ return f"Processed: {message}"
+
+ handler = CallbackHandlerWithInputCapture(captured_inputs=[])  # empty capture
+ result = tool_with_run_manager.invoke(
+ {"message": "hello"},
+ config={"callbacks": [handler]},
+ )
+
+ assert result == "Processed: hello"
+ assert handler.tool_starts == 1
+ assert len(handler.captured_inputs) == 1
+
+ # NOTE(review): run_manager is never in the input dict; weak check - confirm
+ captured = handler.captured_inputs[0]
+ assert captured is not None
+ assert "message" in captured
+ assert "run_manager" not in captured  # key was never passed by the caller
+
+
+def test_filter_multiple_injected_args() -> None:
+ """Test that several injected arguments are all dropped from callback `inputs`."""
+
+ @tool
+ def complex_tool(
+ query: str,
+ limit: int,
+ state: Annotated[dict, InjectedToolArg()],
+ context: Annotated[str, InjectedToolArg()],
+ run_manager: CallbackManagerForToolRun | None = None,
+ ) -> str:
+ """Complex tool with multiple injected args.
+
+ Args:
+ query: The search query.
+ limit: Maximum number of results.
+ state: Injected state.
+ context: Injected context.
+ run_manager: The callback manager.
+ """
+ return f"Query: {query}, Limit: {limit}"
+
+ handler = CallbackHandlerWithInputCapture(captured_inputs=[])  # empty capture
+ result = complex_tool.invoke(
+ {
+ "query": "test",
+ "limit": 10,
+ "state": {"foo": "bar"},
+ "context": "some context",
+ },
+ config={"callbacks": [handler]},
+ )
+
+ assert result == "Query: test, Limit: 10"
+ assert handler.tool_starts == 1
+ assert len(handler.captured_inputs) == 1
+
+ # Equality pins the exact dict; the `not in` asserts below only restate it
+ captured = handler.captured_inputs[0]
+ assert captured is not None
+ assert captured == {"query": "test", "limit": 10}
+ assert "state" not in captured
+ assert "context" not in captured
+ assert "run_manager" not in captured  # never passed in the input dict above
+
+
+def test_no_filtering_for_string_input() -> None:
+ """Test that a plain string input reaches the callback as `inputs=None`."""
+
+ @tool
+ def simple_tool(query: str) -> str:
+ """Simple tool with string input.
+
+ Args:
+ query: The query string.
+ """
+ return f"Result: {query}"
+
+ handler = CallbackHandlerWithInputCapture(captured_inputs=[])  # empty capture
+ result = simple_tool.invoke("test query", config={"callbacks": [handler]})
+
+ assert result == "Result: test query"
+ assert handler.tool_starts == 1
+ assert len(handler.captured_inputs) == 1
+
+ # A str tool input is not a dict, so there is nothing to filter: inputs is None
+ assert handler.captured_inputs[0] is None
+
+
+async def test_filter_injected_args_async() -> None:
+ """Test that injected args are dropped from callback `inputs` via `ainvoke`."""
+
+ @tool
+ async def async_search_tool(
+ query: str,
+ state: Annotated[dict, InjectedToolArg()],
+ ) -> str:
+ """Async search with injected state.
+
+ Args:
+ query: The search query.
+ state: Injected state context.
+ """
+ return f"Async results for: {query}"
+
+ handler = CallbackHandlerWithInputCapture(captured_inputs=[])  # empty capture
+ result = await async_search_tool.ainvoke(
+ {"query": "async test", "state": {"user_id": 456}},
+ config={"callbacks": [handler]},
+ )
+
+ assert result == "Async results for: async test"
+ assert handler.tool_starts == 1
+ assert len(handler.captured_inputs) == 1
+
+ # Same expectations as the sync test: `query` kept, injected `state` removed
+ captured = handler.captured_inputs[0]
+ assert captured is not None
+ assert "query" in captured
+ assert "state" not in captured
+ assert captured["query"] == "async test"
+
+
+@pytest.mark.skipif(not HAS_LANGGRAPH, reason="langgraph not installed")
+def test_filter_tool_runtime_directly_injected_arg() -> None:
+ """Test that ToolRuntime (a _DirectlyInjectedToolArg) is filtered."""
+
+ @tool
+ def tool_with_runtime(query: str, limit: int, runtime: ToolRuntime) -> str:
+ """Tool with ToolRuntime parameter.
+
+ Args:
+ query: The search query.
+ limit: Max results.
+ runtime: The tool runtime (directly injected).
+ """
+ return f"Query: {query}, Limit: {limit}"
+
+ handler = CallbackHandlerWithInputCapture(captured_inputs=[])  # empty capture
+
+ # Stand-in object: a plain class, not an actual ToolRuntime instance
+ class MockRuntime:
+ """Mock ToolRuntime for testing."""
+
+ agent_name = "test_agent"
+ context: dict[str, Any] = {}
+ state: dict[str, Any] = {}
+
+ result = tool_with_runtime.invoke(
+ {
+ "query": "test",
+ "limit": 5,
+ "runtime": MockRuntime(),
+ },
+ config={"callbacks": [handler]},
+ )
+
+ assert result == "Query: test, Limit: 5"
+ assert handler.tool_starts == 1
+ assert len(handler.captured_inputs) == 1
+
+ # The `runtime` key is annotated ToolRuntime and must not reach the callback
+ captured = handler.captured_inputs[0]
+ assert captured is not None
+ assert captured == {"query": "test", "limit": 5}
+ assert "runtime" not in captured  # restates the equality check above