From 76dd656f2a15935c8dbfcc93acbbcd61a91b0661 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Wed, 29 Oct 2025 15:20:53 -0700 Subject: [PATCH] fix: filter out injected args from tracing (#33729) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is CC-generated, and I want to do a thorough review and update the tests, but it should be able to ship today. Before (eek): Screenshot 2025-10-29 at 12 34 52 PM; now (woo): Screenshot 2025-10-29 at 12 36 09 PM --- libs/core/langchain_core/tools/base.py | 78 +++++-- libs/core/tests/unit_tests/test_tools.py | 253 +++++++++++++++++++++++ 2 files changed, 319 insertions(+), 12 deletions(-) diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index def3b227089..588d71efdc8 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -707,6 +707,35 @@ class ChildTool(BaseTool): kwargs["run_manager"] = kwargs["run_manager"].get_sync() return await run_in_executor(None, self._run, *args, **kwargs) + def _filter_injected_args(self, tool_input: dict) -> dict: + """Filter out injected tool arguments from the input dictionary. + + Injected arguments are those annotated with InjectedToolArg or its + subclasses, or arguments in FILTERED_ARGS like run_manager and callbacks. + + Args: + tool_input: The tool input dictionary to filter. + + Returns: + A filtered dictionary with injected arguments removed. 
+ """ + # Start with filtered args from the constant + filtered_keys = set[str](FILTERED_ARGS) + + # If we have an args_schema, use it to identify injected args + if self.args_schema is not None: + try: + annotations = get_all_basemodel_annotations(self.args_schema) + for field_name, field_type in annotations.items(): + if _is_injected_arg_type(field_type): + filtered_keys.add(field_name) + except Exception: # noqa: S110 + # If we can't get annotations, just use FILTERED_ARGS + pass + + # Filter out the injected keys from tool_input + return {k: v for k, v in tool_input.items() if k not in filtered_keys} + def _to_args_and_kwargs( self, tool_input: str | dict, tool_call_id: str | None ) -> tuple[tuple, dict]: @@ -794,17 +823,29 @@ class ChildTool(BaseTool): self.metadata, ) + # Filter out injected arguments from callback inputs + filtered_tool_input = ( + self._filter_injected_args(tool_input) + if isinstance(tool_input, dict) + else None + ) + + # Use filtered inputs for the input_str parameter as well + tool_input_str = ( + tool_input + if isinstance(tool_input, str) + else str( + filtered_tool_input if filtered_tool_input is not None else tool_input + ) + ) + run_manager = callback_manager.on_tool_start( {"name": self.name, "description": self.description}, - tool_input if isinstance(tool_input, str) else str(tool_input), + tool_input_str, color=start_color, name=run_name, run_id=run_id, - # Inputs by definition should always be dicts. - # For now, it's unclear whether this assumption is ever violated, - # but if it is we will send a `None` value to the callback instead - # TODO: will need to address issue via a patch. 
- inputs=tool_input if isinstance(tool_input, dict) else None, + inputs=filtered_tool_input, **kwargs, ) @@ -905,17 +946,30 @@ class ChildTool(BaseTool): metadata, self.metadata, ) + + # Filter out injected arguments from callback inputs + filtered_tool_input = ( + self._filter_injected_args(tool_input) + if isinstance(tool_input, dict) + else None + ) + + # Use filtered inputs for the input_str parameter as well + tool_input_str = ( + tool_input + if isinstance(tool_input, str) + else str( + filtered_tool_input if filtered_tool_input is not None else tool_input + ) + ) + run_manager = await callback_manager.on_tool_start( {"name": self.name, "description": self.description}, - tool_input if isinstance(tool_input, str) else str(tool_input), + tool_input_str, color=start_color, name=run_name, run_id=run_id, - # Inputs by definition should always be dicts. - # For now, it's unclear whether this assumption is ever violated, - # but if it is we will send a `None` value to the callback instead - # TODO: will need to address issue via a patch. 
- inputs=tool_input if isinstance(tool_input, dict) else None, + inputs=filtered_tool_input, **kwargs, ) content = None diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 9fdd2afc819..b6e6020e841 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -70,6 +70,13 @@ from langchain_core.utils.pydantic import ( from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.pydantic_utils import _normalize_schema, _schema +try: + from langgraph.prebuilt import ToolRuntime # type: ignore[import-not-found] + + HAS_LANGGRAPH = True +except ImportError: + HAS_LANGGRAPH = False + def _get_tool_call_json_schema(tool: BaseTool) -> dict: tool_schema = tool.tool_call_schema @@ -2773,3 +2780,249 @@ def test_tool_args_schema_with_annotated_type() -> None: "type": "array", } } + + +class CallbackHandlerWithInputCapture(FakeCallbackHandler): + """Callback handler that captures inputs passed to on_tool_start.""" + + captured_inputs: list[dict | None] = [] + + def on_tool_start( + self, + serialized: dict[str, Any], + input_str: str, + *, + run_id: Any, + parent_run_id: Any | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + inputs: dict[str, Any] | None = None, + **kwargs: Any, + ) -> Any: + """Capture the inputs passed to on_tool_start.""" + self.captured_inputs.append(inputs) + return super().on_tool_start( + serialized, + input_str, + run_id=run_id, + parent_run_id=parent_run_id, + tags=tags, + metadata=metadata, + inputs=inputs, + **kwargs, + ) + + +def test_filter_injected_args_from_callbacks() -> None: + """Test that injected tool arguments are filtered from callback inputs.""" + + @tool + def search_tool( + query: str, + state: Annotated[dict, InjectedToolArg()], + ) -> str: + """Search with injected state. + + Args: + query: The search query. + state: Injected state context. 
+ """ + return f"Results for: {query}" + + handler = CallbackHandlerWithInputCapture(captured_inputs=[]) + result = search_tool.invoke( + {"query": "test query", "state": {"user_id": 123}}, + config={"callbacks": [handler]}, + ) + + assert result == "Results for: test query" + assert handler.tool_starts == 1 + assert len(handler.captured_inputs) == 1 + + # Verify that injected 'state' arg is filtered out + captured = handler.captured_inputs[0] + assert captured is not None + assert "query" in captured + assert "state" not in captured + assert captured["query"] == "test query" + + +def test_filter_run_manager_from_callbacks() -> None: + """Test that run_manager is filtered from callback inputs.""" + + @tool + def tool_with_run_manager( + message: str, + run_manager: CallbackManagerForToolRun | None = None, + ) -> str: + """Tool with run_manager parameter. + + Args: + message: The message to process. + run_manager: The callback manager. + """ + return f"Processed: {message}" + + handler = CallbackHandlerWithInputCapture(captured_inputs=[]) + result = tool_with_run_manager.invoke( + {"message": "hello"}, + config={"callbacks": [handler]}, + ) + + assert result == "Processed: hello" + assert handler.tool_starts == 1 + assert len(handler.captured_inputs) == 1 + + # Verify that run_manager is filtered out + captured = handler.captured_inputs[0] + assert captured is not None + assert "message" in captured + assert "run_manager" not in captured + + +def test_filter_multiple_injected_args() -> None: + """Test filtering multiple injected arguments from callback inputs.""" + + @tool + def complex_tool( + query: str, + limit: int, + state: Annotated[dict, InjectedToolArg()], + context: Annotated[str, InjectedToolArg()], + run_manager: CallbackManagerForToolRun | None = None, + ) -> str: + """Complex tool with multiple injected args. + + Args: + query: The search query. + limit: Maximum number of results. + state: Injected state. + context: Injected context. 
+ run_manager: The callback manager. + """ + return f"Query: {query}, Limit: {limit}" + + handler = CallbackHandlerWithInputCapture(captured_inputs=[]) + result = complex_tool.invoke( + { + "query": "test", + "limit": 10, + "state": {"foo": "bar"}, + "context": "some context", + }, + config={"callbacks": [handler]}, + ) + + assert result == "Query: test, Limit: 10" + assert handler.tool_starts == 1 + assert len(handler.captured_inputs) == 1 + + # Verify that only non-injected args remain + captured = handler.captured_inputs[0] + assert captured is not None + assert captured == {"query": "test", "limit": 10} + assert "state" not in captured + assert "context" not in captured + assert "run_manager" not in captured + + +def test_no_filtering_for_string_input() -> None: + """Test that string inputs are not filtered (passed as None).""" + + @tool + def simple_tool(query: str) -> str: + """Simple tool with string input. + + Args: + query: The query string. + """ + return f"Result: {query}" + + handler = CallbackHandlerWithInputCapture(captured_inputs=[]) + result = simple_tool.invoke("test query", config={"callbacks": [handler]}) + + assert result == "Result: test query" + assert handler.tool_starts == 1 + assert len(handler.captured_inputs) == 1 + + # String inputs should result in None for the inputs parameter + assert handler.captured_inputs[0] is None + + +async def test_filter_injected_args_async() -> None: + """Test that injected args are filtered in async tool execution.""" + + @tool + async def async_search_tool( + query: str, + state: Annotated[dict, InjectedToolArg()], + ) -> str: + """Async search with injected state. + + Args: + query: The search query. + state: Injected state context. 
+ """ + return f"Async results for: {query}" + + handler = CallbackHandlerWithInputCapture(captured_inputs=[]) + result = await async_search_tool.ainvoke( + {"query": "async test", "state": {"user_id": 456}}, + config={"callbacks": [handler]}, + ) + + assert result == "Async results for: async test" + assert handler.tool_starts == 1 + assert len(handler.captured_inputs) == 1 + + # Verify filtering in async execution + captured = handler.captured_inputs[0] + assert captured is not None + assert "query" in captured + assert "state" not in captured + assert captured["query"] == "async test" + + +@pytest.mark.skipif(not HAS_LANGGRAPH, reason="langgraph not installed") +def test_filter_tool_runtime_directly_injected_arg() -> None: + """Test that ToolRuntime (a _DirectlyInjectedToolArg) is filtered.""" + + @tool + def tool_with_runtime(query: str, limit: int, runtime: ToolRuntime) -> str: + """Tool with ToolRuntime parameter. + + Args: + query: The search query. + limit: Max results. + runtime: The tool runtime (directly injected). + """ + return f"Query: {query}, Limit: {limit}" + + handler = CallbackHandlerWithInputCapture(captured_inputs=[]) + + # Create a mock ToolRuntime instance + class MockRuntime: + """Mock ToolRuntime for testing.""" + + agent_name = "test_agent" + context: dict[str, Any] = {} + state: dict[str, Any] = {} + + result = tool_with_runtime.invoke( + { + "query": "test", + "limit": 5, + "runtime": MockRuntime(), + }, + config={"callbacks": [handler]}, + ) + + assert result == "Query: test, Limit: 5" + assert handler.tool_starts == 1 + assert len(handler.captured_inputs) == 1 + + # Verify that ToolRuntime is filtered out + captured = handler.captured_inputs[0] + assert captured is not None + assert captured == {"query": "test", "limit": 5} + assert "runtime" not in captured