mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
fix: filter out injected args from tracing (#33729)
this is CC generated and I want to do a thorough review + update the tests. but should be able to ship today. before eek <img width="637" height="485" alt="Screenshot 2025-10-29 at 12 34 52 PM" src="https://github.com/user-attachments/assets/121def87-fb7b-4847-b9e2-74f37b3b4763" /> now, woo <img width="651" height="158" alt="Screenshot 2025-10-29 at 12 36 09 PM" src="https://github.com/user-attachments/assets/1fc0e19e-a83f-417c-81e2-3aa0028630d6" />
This commit is contained in:
@@ -707,6 +707,35 @@ class ChildTool(BaseTool):
|
||||
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
||||
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||
|
||||
def _filter_injected_args(self, tool_input: dict) -> dict:
|
||||
"""Filter out injected tool arguments from the input dictionary.
|
||||
|
||||
Injected arguments are those annotated with InjectedToolArg or its
|
||||
subclasses, or arguments in FILTERED_ARGS like run_manager and callbacks.
|
||||
|
||||
Args:
|
||||
tool_input: The tool input dictionary to filter.
|
||||
|
||||
Returns:
|
||||
A filtered dictionary with injected arguments removed.
|
||||
"""
|
||||
# Start with filtered args from the constant
|
||||
filtered_keys = set[str](FILTERED_ARGS)
|
||||
|
||||
# If we have an args_schema, use it to identify injected args
|
||||
if self.args_schema is not None:
|
||||
try:
|
||||
annotations = get_all_basemodel_annotations(self.args_schema)
|
||||
for field_name, field_type in annotations.items():
|
||||
if _is_injected_arg_type(field_type):
|
||||
filtered_keys.add(field_name)
|
||||
except Exception: # noqa: S110
|
||||
# If we can't get annotations, just use FILTERED_ARGS
|
||||
pass
|
||||
|
||||
# Filter out the injected keys from tool_input
|
||||
return {k: v for k, v in tool_input.items() if k not in filtered_keys}
|
||||
|
||||
def _to_args_and_kwargs(
|
||||
self, tool_input: str | dict, tool_call_id: str | None
|
||||
) -> tuple[tuple, dict]:
|
||||
@@ -794,17 +823,29 @@ class ChildTool(BaseTool):
|
||||
self.metadata,
|
||||
)
|
||||
|
||||
# Filter out injected arguments from callback inputs
|
||||
filtered_tool_input = (
|
||||
self._filter_injected_args(tool_input)
|
||||
if isinstance(tool_input, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
# Use filtered inputs for the input_str parameter as well
|
||||
tool_input_str = (
|
||||
tool_input
|
||||
if isinstance(tool_input, str)
|
||||
else str(
|
||||
filtered_tool_input if filtered_tool_input is not None else tool_input
|
||||
)
|
||||
)
|
||||
|
||||
run_manager = callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
tool_input_str,
|
||||
color=start_color,
|
||||
name=run_name,
|
||||
run_id=run_id,
|
||||
# Inputs by definition should always be dicts.
|
||||
# For now, it's unclear whether this assumption is ever violated,
|
||||
# but if it is we will send a `None` value to the callback instead
|
||||
# TODO: will need to address issue via a patch.
|
||||
inputs=tool_input if isinstance(tool_input, dict) else None,
|
||||
inputs=filtered_tool_input,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -905,17 +946,30 @@ class ChildTool(BaseTool):
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
|
||||
# Filter out injected arguments from callback inputs
|
||||
filtered_tool_input = (
|
||||
self._filter_injected_args(tool_input)
|
||||
if isinstance(tool_input, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
# Use filtered inputs for the input_str parameter as well
|
||||
tool_input_str = (
|
||||
tool_input
|
||||
if isinstance(tool_input, str)
|
||||
else str(
|
||||
filtered_tool_input if filtered_tool_input is not None else tool_input
|
||||
)
|
||||
)
|
||||
|
||||
run_manager = await callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
tool_input_str,
|
||||
color=start_color,
|
||||
name=run_name,
|
||||
run_id=run_id,
|
||||
# Inputs by definition should always be dicts.
|
||||
# For now, it's unclear whether this assumption is ever violated,
|
||||
# but if it is we will send a `None` value to the callback instead
|
||||
# TODO: will need to address issue via a patch.
|
||||
inputs=tool_input if isinstance(tool_input, dict) else None,
|
||||
inputs=filtered_tool_input,
|
||||
**kwargs,
|
||||
)
|
||||
content = None
|
||||
|
||||
@@ -70,6 +70,13 @@ from langchain_core.utils.pydantic import (
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
|
||||
|
||||
try:
|
||||
from langgraph.prebuilt import ToolRuntime # type: ignore[import-not-found]
|
||||
|
||||
HAS_LANGGRAPH = True
|
||||
except ImportError:
|
||||
HAS_LANGGRAPH = False
|
||||
|
||||
|
||||
def _get_tool_call_json_schema(tool: BaseTool) -> dict:
|
||||
tool_schema = tool.tool_call_schema
|
||||
@@ -2773,3 +2780,249 @@ def test_tool_args_schema_with_annotated_type() -> None:
|
||||
"type": "array",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class CallbackHandlerWithInputCapture(FakeCallbackHandler):
|
||||
"""Callback handler that captures inputs passed to on_tool_start."""
|
||||
|
||||
captured_inputs: list[dict | None] = []
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: Any,
|
||||
parent_run_id: Any | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Capture the inputs passed to on_tool_start."""
|
||||
self.captured_inputs.append(inputs)
|
||||
return super().on_tool_start(
|
||||
serialized,
|
||||
input_str,
|
||||
run_id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
inputs=inputs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def test_filter_injected_args_from_callbacks() -> None:
|
||||
"""Test that injected tool arguments are filtered from callback inputs."""
|
||||
|
||||
@tool
|
||||
def search_tool(
|
||||
query: str,
|
||||
state: Annotated[dict, InjectedToolArg()],
|
||||
) -> str:
|
||||
"""Search with injected state.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
state: Injected state context.
|
||||
"""
|
||||
return f"Results for: {query}"
|
||||
|
||||
handler = CallbackHandlerWithInputCapture(captured_inputs=[])
|
||||
result = search_tool.invoke(
|
||||
{"query": "test query", "state": {"user_id": 123}},
|
||||
config={"callbacks": [handler]},
|
||||
)
|
||||
|
||||
assert result == "Results for: test query"
|
||||
assert handler.tool_starts == 1
|
||||
assert len(handler.captured_inputs) == 1
|
||||
|
||||
# Verify that injected 'state' arg is filtered out
|
||||
captured = handler.captured_inputs[0]
|
||||
assert captured is not None
|
||||
assert "query" in captured
|
||||
assert "state" not in captured
|
||||
assert captured["query"] == "test query"
|
||||
|
||||
|
||||
def test_filter_run_manager_from_callbacks() -> None:
|
||||
"""Test that run_manager is filtered from callback inputs."""
|
||||
|
||||
@tool
|
||||
def tool_with_run_manager(
|
||||
message: str,
|
||||
run_manager: CallbackManagerForToolRun | None = None,
|
||||
) -> str:
|
||||
"""Tool with run_manager parameter.
|
||||
|
||||
Args:
|
||||
message: The message to process.
|
||||
run_manager: The callback manager.
|
||||
"""
|
||||
return f"Processed: {message}"
|
||||
|
||||
handler = CallbackHandlerWithInputCapture(captured_inputs=[])
|
||||
result = tool_with_run_manager.invoke(
|
||||
{"message": "hello"},
|
||||
config={"callbacks": [handler]},
|
||||
)
|
||||
|
||||
assert result == "Processed: hello"
|
||||
assert handler.tool_starts == 1
|
||||
assert len(handler.captured_inputs) == 1
|
||||
|
||||
# Verify that run_manager is filtered out
|
||||
captured = handler.captured_inputs[0]
|
||||
assert captured is not None
|
||||
assert "message" in captured
|
||||
assert "run_manager" not in captured
|
||||
|
||||
|
||||
def test_filter_multiple_injected_args() -> None:
|
||||
"""Test filtering multiple injected arguments from callback inputs."""
|
||||
|
||||
@tool
|
||||
def complex_tool(
|
||||
query: str,
|
||||
limit: int,
|
||||
state: Annotated[dict, InjectedToolArg()],
|
||||
context: Annotated[str, InjectedToolArg()],
|
||||
run_manager: CallbackManagerForToolRun | None = None,
|
||||
) -> str:
|
||||
"""Complex tool with multiple injected args.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results.
|
||||
state: Injected state.
|
||||
context: Injected context.
|
||||
run_manager: The callback manager.
|
||||
"""
|
||||
return f"Query: {query}, Limit: {limit}"
|
||||
|
||||
handler = CallbackHandlerWithInputCapture(captured_inputs=[])
|
||||
result = complex_tool.invoke(
|
||||
{
|
||||
"query": "test",
|
||||
"limit": 10,
|
||||
"state": {"foo": "bar"},
|
||||
"context": "some context",
|
||||
},
|
||||
config={"callbacks": [handler]},
|
||||
)
|
||||
|
||||
assert result == "Query: test, Limit: 10"
|
||||
assert handler.tool_starts == 1
|
||||
assert len(handler.captured_inputs) == 1
|
||||
|
||||
# Verify that only non-injected args remain
|
||||
captured = handler.captured_inputs[0]
|
||||
assert captured is not None
|
||||
assert captured == {"query": "test", "limit": 10}
|
||||
assert "state" not in captured
|
||||
assert "context" not in captured
|
||||
assert "run_manager" not in captured
|
||||
|
||||
|
||||
def test_no_filtering_for_string_input() -> None:
|
||||
"""Test that string inputs are not filtered (passed as None)."""
|
||||
|
||||
@tool
|
||||
def simple_tool(query: str) -> str:
|
||||
"""Simple tool with string input.
|
||||
|
||||
Args:
|
||||
query: The query string.
|
||||
"""
|
||||
return f"Result: {query}"
|
||||
|
||||
handler = CallbackHandlerWithInputCapture(captured_inputs=[])
|
||||
result = simple_tool.invoke("test query", config={"callbacks": [handler]})
|
||||
|
||||
assert result == "Result: test query"
|
||||
assert handler.tool_starts == 1
|
||||
assert len(handler.captured_inputs) == 1
|
||||
|
||||
# String inputs should result in None for the inputs parameter
|
||||
assert handler.captured_inputs[0] is None
|
||||
|
||||
|
||||
async def test_filter_injected_args_async() -> None:
|
||||
"""Test that injected args are filtered in async tool execution."""
|
||||
|
||||
@tool
|
||||
async def async_search_tool(
|
||||
query: str,
|
||||
state: Annotated[dict, InjectedToolArg()],
|
||||
) -> str:
|
||||
"""Async search with injected state.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
state: Injected state context.
|
||||
"""
|
||||
return f"Async results for: {query}"
|
||||
|
||||
handler = CallbackHandlerWithInputCapture(captured_inputs=[])
|
||||
result = await async_search_tool.ainvoke(
|
||||
{"query": "async test", "state": {"user_id": 456}},
|
||||
config={"callbacks": [handler]},
|
||||
)
|
||||
|
||||
assert result == "Async results for: async test"
|
||||
assert handler.tool_starts == 1
|
||||
assert len(handler.captured_inputs) == 1
|
||||
|
||||
# Verify filtering in async execution
|
||||
captured = handler.captured_inputs[0]
|
||||
assert captured is not None
|
||||
assert "query" in captured
|
||||
assert "state" not in captured
|
||||
assert captured["query"] == "async test"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_LANGGRAPH, reason="langgraph not installed")
|
||||
def test_filter_tool_runtime_directly_injected_arg() -> None:
|
||||
"""Test that ToolRuntime (a _DirectlyInjectedToolArg) is filtered."""
|
||||
|
||||
@tool
|
||||
def tool_with_runtime(query: str, limit: int, runtime: ToolRuntime) -> str:
|
||||
"""Tool with ToolRuntime parameter.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Max results.
|
||||
runtime: The tool runtime (directly injected).
|
||||
"""
|
||||
return f"Query: {query}, Limit: {limit}"
|
||||
|
||||
handler = CallbackHandlerWithInputCapture(captured_inputs=[])
|
||||
|
||||
# Create a mock ToolRuntime instance
|
||||
class MockRuntime:
|
||||
"""Mock ToolRuntime for testing."""
|
||||
|
||||
agent_name = "test_agent"
|
||||
context: dict[str, Any] = {}
|
||||
state: dict[str, Any] = {}
|
||||
|
||||
result = tool_with_runtime.invoke(
|
||||
{
|
||||
"query": "test",
|
||||
"limit": 5,
|
||||
"runtime": MockRuntime(),
|
||||
},
|
||||
config={"callbacks": [handler]},
|
||||
)
|
||||
|
||||
assert result == "Query: test, Limit: 5"
|
||||
assert handler.tool_starts == 1
|
||||
assert len(handler.captured_inputs) == 1
|
||||
|
||||
# Verify that ToolRuntime is filtered out
|
||||
captured = handler.captured_inputs[0]
|
||||
assert captured is not None
|
||||
assert captured == {"query": "test", "limit": 5}
|
||||
assert "runtime" not in captured
|
||||
|
||||
Reference in New Issue
Block a user