fix: filter out injected args from tracing (#33729)

This is CC-generated, and I want to do a thorough review and update the
tests, but it should be able to ship today.

Before (eek):

<img width="637" height="485" alt="Screenshot 2025-10-29 at 12 34 52 PM"
src="https://github.com/user-attachments/assets/121def87-fb7b-4847-b9e2-74f37b3b4763"
/>

Now (woo):

<img width="651" height="158" alt="Screenshot 2025-10-29 at 12 36 09 PM"
src="https://github.com/user-attachments/assets/1fc0e19e-a83f-417c-81e2-3aa0028630d6"
/>
This commit is contained in:
Sydney Runkle
2025-10-29 15:20:53 -07:00
committed by GitHub
parent d218936763
commit 76dd656f2a
2 changed files with 319 additions and 12 deletions

View File

@@ -707,6 +707,35 @@ class ChildTool(BaseTool):
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
return await run_in_executor(None, self._run, *args, **kwargs)
def _filter_injected_args(self, tool_input: dict) -> dict:
    """Return a copy of `tool_input` without injected arguments.

    A key is dropped when it is listed in `FILTERED_ARGS`, or when the
    matching field of `self.args_schema` has an annotation that
    `_is_injected_arg_type` reports as injected.

    Args:
        tool_input: The tool input dictionary to filter.

    Returns:
        A new dictionary holding only the non-injected entries.
    """
    # Names that are always hidden, regardless of the schema.
    excluded: set[str] = set(FILTERED_ARGS)

    schema = self.args_schema
    if schema is not None:
        try:
            annotations = get_all_basemodel_annotations(schema)
            for name, annotation in annotations.items():
                if _is_injected_arg_type(annotation):
                    excluded.add(name)
        except Exception:  # noqa: S110
            # Best effort: if the schema cannot be introspected, carry on
            # with whatever names have been collected so far.
            pass

    return {key: value for key, value in tool_input.items() if key not in excluded}
def _to_args_and_kwargs(
self, tool_input: str | dict, tool_call_id: str | None
) -> tuple[tuple, dict]:
@@ -794,17 +823,29 @@ class ChildTool(BaseTool):
self.metadata,
)
# Filter out injected arguments from callback inputs
filtered_tool_input = (
self._filter_injected_args(tool_input)
if isinstance(tool_input, dict)
else None
)
# Use filtered inputs for the input_str parameter as well
tool_input_str = (
tool_input
if isinstance(tool_input, str)
else str(
filtered_tool_input if filtered_tool_input is not None else tool_input
)
)
run_manager = callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
tool_input_str,
color=start_color,
name=run_name,
run_id=run_id,
# Inputs by definition should always be dicts.
# For now, it's unclear whether this assumption is ever violated,
# but if it is we will send a `None` value to the callback instead
# TODO: will need to address issue via a patch.
inputs=tool_input if isinstance(tool_input, dict) else None,
inputs=filtered_tool_input,
**kwargs,
)
@@ -905,17 +946,30 @@ class ChildTool(BaseTool):
metadata,
self.metadata,
)
# Filter out injected arguments from callback inputs
filtered_tool_input = (
self._filter_injected_args(tool_input)
if isinstance(tool_input, dict)
else None
)
# Use filtered inputs for the input_str parameter as well
tool_input_str = (
tool_input
if isinstance(tool_input, str)
else str(
filtered_tool_input if filtered_tool_input is not None else tool_input
)
)
run_manager = await callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
tool_input_str,
color=start_color,
name=run_name,
run_id=run_id,
# Inputs by definition should always be dicts.
# For now, it's unclear whether this assumption is ever violated,
# but if it is we will send a `None` value to the callback instead
# TODO: will need to address issue via a patch.
inputs=tool_input if isinstance(tool_input, dict) else None,
inputs=filtered_tool_input,
**kwargs,
)
content = None

View File

@@ -70,6 +70,13 @@ from langchain_core.utils.pydantic import (
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
try:
from langgraph.prebuilt import ToolRuntime # type: ignore[import-not-found]
HAS_LANGGRAPH = True
except ImportError:
HAS_LANGGRAPH = False
def _get_tool_call_json_schema(tool: BaseTool) -> dict:
tool_schema = tool.tool_call_schema
@@ -2773,3 +2780,249 @@ def test_tool_args_schema_with_annotated_type() -> None:
"type": "array",
}
}
class CallbackHandlerWithInputCapture(FakeCallbackHandler):
    """Fake callback handler that also records the `inputs` of each tool start."""

    # One entry per `on_tool_start` call, in call order; an entry is `None`
    # when the callback was invoked without `inputs`.
    captured_inputs: list[dict | None] = []

    def on_tool_start(
        self,
        serialized: dict[str, Any],
        input_str: str,
        *,
        run_id: Any,
        parent_run_id: Any | None = None,
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        inputs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Record `inputs`, then delegate to the parent handler unchanged."""
        self.captured_inputs.append(inputs)

        # Every argument is forwarded as received. `kwargs` can never hold
        # one of the explicitly named keys, so merging it in is lossless.
        forwarded: dict[str, Any] = {
            "run_id": run_id,
            "parent_run_id": parent_run_id,
            "tags": tags,
            "metadata": metadata,
            "inputs": inputs,
            **kwargs,
        }
        return super().on_tool_start(serialized, input_str, **forwarded)
def test_filter_injected_args_from_callbacks() -> None:
    """An `InjectedToolArg` parameter must not reach the callback `inputs`."""

    @tool
    def search_tool(
        query: str,
        state: Annotated[dict, InjectedToolArg()],
    ) -> str:
        """Search with injected state.

        Args:
            query: The search query.
            state: Injected state context.
        """
        return f"Results for: {query}"

    capture = CallbackHandlerWithInputCapture(captured_inputs=[])
    payload = {"query": "test query", "state": {"user_id": 123}}
    output = search_tool.invoke(payload, config={"callbacks": [capture]})

    # The tool itself still runs normally and fires exactly one start event.
    assert output == "Results for: test query"
    assert capture.tool_starts == 1
    assert len(capture.captured_inputs) == 1

    # The injected `state` key is gone while `query` survives untouched.
    seen = capture.captured_inputs[0]
    assert seen is not None
    assert "query" in seen
    assert "state" not in seen
    assert seen["query"] == "test query"
def test_filter_run_manager_from_callbacks() -> None:
    """Test that `run_manager` never shows up in callback inputs.

    The caller does not supply `run_manager` here, so checking only for the
    key's absence would pass even with no filtering at all. The captured
    inputs are therefore also compared against the exact expected mapping.
    """

    @tool
    def tool_with_run_manager(
        message: str,
        run_manager: CallbackManagerForToolRun | None = None,
    ) -> str:
        """Tool with run_manager parameter.

        Args:
            message: The message to process.
            run_manager: The callback manager.
        """
        return f"Processed: {message}"

    handler = CallbackHandlerWithInputCapture(captured_inputs=[])
    result = tool_with_run_manager.invoke(
        {"message": "hello"},
        config={"callbacks": [handler]},
    )

    assert result == "Processed: hello"
    assert handler.tool_starts == 1
    assert len(handler.captured_inputs) == 1

    # Verify that run_manager is filtered out and nothing else leaked in.
    captured = handler.captured_inputs[0]
    assert captured is not None
    assert "message" in captured
    assert "run_manager" not in captured
    assert captured == {"message": "hello"}
def test_filter_multiple_injected_args() -> None:
    """Several injected parameters are all removed from the callback inputs."""

    @tool
    def complex_tool(
        query: str,
        limit: int,
        state: Annotated[dict, InjectedToolArg()],
        context: Annotated[str, InjectedToolArg()],
        run_manager: CallbackManagerForToolRun | None = None,
    ) -> str:
        """Complex tool with multiple injected args.

        Args:
            query: The search query.
            limit: Maximum number of results.
            state: Injected state.
            context: Injected context.
            run_manager: The callback manager.
        """
        return f"Query: {query}, Limit: {limit}"

    capture = CallbackHandlerWithInputCapture(captured_inputs=[])
    payload = {
        "query": "test",
        "limit": 10,
        "state": {"foo": "bar"},
        "context": "some context",
    }
    output = complex_tool.invoke(payload, config={"callbacks": [capture]})

    assert output == "Query: test, Limit: 10"
    assert capture.tool_starts == 1
    assert len(capture.captured_inputs) == 1

    # Only the two ordinary arguments may remain.
    seen = capture.captured_inputs[0]
    assert seen is not None
    assert seen == {"query": "test", "limit": 10}
    for hidden_key in ("state", "context", "run_manager"):
        assert hidden_key not in seen
def test_no_filtering_for_string_input() -> None:
    """A plain string input yields `inputs=None` in the callback."""

    @tool
    def simple_tool(query: str) -> str:
        """Simple tool with string input.

        Args:
            query: The query string.
        """
        return f"Result: {query}"

    capture = CallbackHandlerWithInputCapture(captured_inputs=[])
    callback_config = {"callbacks": [capture]}
    output = simple_tool.invoke("test query", config=callback_config)

    assert output == "Result: test query"
    assert capture.tool_starts == 1
    assert len(capture.captured_inputs) == 1

    # There is no dict to filter, so the handler is handed `None`.
    assert capture.captured_inputs[0] is None
async def test_filter_injected_args_async() -> None:
    """The async execution path filters injected arguments as well."""

    @tool
    async def async_search_tool(
        query: str,
        state: Annotated[dict, InjectedToolArg()],
    ) -> str:
        """Async search with injected state.

        Args:
            query: The search query.
            state: Injected state context.
        """
        return f"Async results for: {query}"

    capture = CallbackHandlerWithInputCapture(captured_inputs=[])
    payload = {"query": "async test", "state": {"user_id": 456}}
    output = await async_search_tool.ainvoke(
        payload,
        config={"callbacks": [capture]},
    )

    assert output == "Async results for: async test"
    assert capture.tool_starts == 1
    assert len(capture.captured_inputs) == 1

    # Same expectation as the sync path: `state` removed, `query` preserved.
    seen = capture.captured_inputs[0]
    assert seen is not None
    assert "query" in seen
    assert "state" not in seen
    assert seen["query"] == "async test"
@pytest.mark.skipif(not HAS_LANGGRAPH, reason="langgraph not installed")
def test_filter_tool_runtime_directly_injected_arg() -> None:
    """A `ToolRuntime`-annotated parameter is hidden from callback inputs."""

    @tool
    def tool_with_runtime(query: str, limit: int, runtime: ToolRuntime) -> str:
        """Tool with ToolRuntime parameter.

        Args:
            query: The search query.
            limit: Max results.
            runtime: The tool runtime (directly injected).
        """
        return f"Query: {query}, Limit: {limit}"

    # Lightweight stand-in passed as the `runtime` value for this test.
    class MockRuntime:
        """Mock ToolRuntime for testing."""

        agent_name = "test_agent"
        context: dict[str, Any] = {}
        state: dict[str, Any] = {}

    capture = CallbackHandlerWithInputCapture(captured_inputs=[])
    payload = {
        "query": "test",
        "limit": 5,
        "runtime": MockRuntime(),
    }
    output = tool_with_runtime.invoke(payload, config={"callbacks": [capture]})

    assert output == "Query: test, Limit: 5"
    assert capture.tool_starts == 1
    assert len(capture.captured_inputs) == 1

    # The runtime object must not be forwarded to the handler.
    seen = capture.captured_inputs[0]
    assert seen is not None
    assert seen == {"query": "test", "limit": 5}
    assert "runtime" not in seen