diff --git a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py index ac1a01fa50f..3fa3aaedfa5 100644 --- a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py +++ b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py @@ -1,9 +1,12 @@ """Human in the loop middleware.""" -from typing import Any, Literal, Protocol +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, Protocol from langchain_core.messages import AIMessage, ToolCall, ToolMessage -from langgraph.runtime import Runtime +from langgraph.config import get_config +from langgraph.prebuilt.tool_node import ToolRuntime from langgraph.types import interrupt from typing_extensions import NotRequired, TypedDict @@ -13,8 +16,14 @@ from langchain.agents.middleware.types import ( ContextT, ResponseT, StateT, + ToolCallRequest, ) +if TYPE_CHECKING: + from collections.abc import Callable + + from langgraph.runtime import Runtime + class Action(TypedDict): """Represents an action with a name and args.""" @@ -178,6 +187,27 @@ class InterruptOnConfig(TypedDict): args_schema: NotRequired[dict[str, Any]] """JSON schema for the args associated with the action, if edits are allowed.""" + when: NotRequired[Callable[[ToolCallRequest], bool]] + """Optional predicate controlling whether to interrupt for a given tool call. + + Receives a `ToolCallRequest` and returns `True` to interrupt or `False` to + auto-approve. Works in both `"batch"` and `"per_call"` modes. + + In `"batch"` mode the request is constructed with `tool=None` and + `runtime` set to the node-level `Runtime` (not a `ToolRuntime`), so + `request.runtime.tool_call_id` and `request.runtime.tools` are not available. + In `"per_call"` mode the full `ToolCallRequest` from `wrap_tool_call` is passed. + + Example: + ```python + # Only interrupt delete_file calls targeting /etc + config = InterruptOnConfig( + allowed_decisions=["approve", "reject"], + when=lambda req: req.tool_call["args"].get("path", "").startswith("/etc"), + ) + ``` + """ + class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]): """Human in the loop middleware.""" @@ -203,6 +233,9 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]): The `InterruptOnConfig` can include a `description` field (`str` or `Callable`) for custom formatting of the interrupt description. + + A `when` predicate can also be provided to dynamically control + whether a tool call triggers an interrupt. description_prefix: The prefix to use when constructing action requests. This is used to provide context about the tool call and the action being @@ -310,6 +343,39 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]): ) raise ValueError(msg) + def _should_interrupt( + self, + tool_call: ToolCall, + config: InterruptOnConfig, + state: AgentState[Any], + runtime: Runtime[ContextT], + ) -> bool: + """Return False if the `when` predicate rejects this tool call, True otherwise.""" + when = config.get("when") + if when is None: + return True + try: + runnable_config = get_config() + except RuntimeError: + runnable_config = {} + tool_runtime = ToolRuntime( + state=state, + context=runtime.context, + config=runnable_config, + stream_writer=runtime.stream_writer, + tool_call_id=tool_call["id"], + store=runtime.store, + execution_info=runtime.execution_info, + server_info=runtime.server_info, + ) + req = ToolCallRequest( + tool_call=tool_call, + tool=None, + state=state, + runtime=tool_runtime, # type: ignore[arg-type] + ) + return when(req) + def after_model( self, state: AgentState[Any], runtime: Runtime[ContextT] ) -> dict[str, Any] | None: @@ -341,6 +407,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]): for idx, tool_call in enumerate(last_ai_msg.tool_calls): if (config := self.interrupt_on.get(tool_call["name"])) is not None: + if not self._should_interrupt(tool_call, config, state, runtime): + continue action_request, review_config = self._create_action_and_config( tool_call, config, state, runtime ) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_human_in_the_loop.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_human_in_the_loop.py index 5e03b0bed55..720bdeba89e 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_human_in_the_loop.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_human_in_the_loop.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage +from langgraph.prebuilt.tool_node import ToolRuntime from langgraph.runtime import Runtime from langchain.agents.middleware import InterruptOnConfig @@ -11,7 +12,7 @@ from langchain.agents.middleware.human_in_the_loop import ( Action, HumanInTheLoopMiddleware, ) -from langchain.agents.middleware.types import AgentState +from langchain.agents.middleware.types import AgentState, ToolCallRequest def test_human_in_the_loop_middleware_initialization() -> None: @@ -883,3 +884,107 @@ def test_human_in_the_loop_middleware_preserves_order_with_rejections() -> None: assert isinstance(tool_message, ToolMessage) assert tool_message.content == "Rejected tool B" assert tool_message.tool_call_id == "id_b" + + +# --------------------------------------------------------------------------- +# when predicate +# --------------------------------------------------------------------------- + + +def test_when_predicate_batch_skips_interrupt_when_false() -> None: + """`when` returning False prevents the tool call from joining the batch interrupt.""" + middleware = HumanInTheLoopMiddleware( + interrupt_on={ + "test_tool": InterruptOnConfig( + allowed_decisions=["approve"], + when=lambda req: req.tool_call["args"].get("risky", False), + ) + } + ) + ai_message = AIMessage( + content="...", + tool_calls=[{"name": "test_tool", "args": {"risky": False}, "id": "1"}], + ) + state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message]) + + with ( + patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}), + patch("langchain.agents.middleware.human_in_the_loop.interrupt") as mock_interrupt, + ): + result = middleware.after_model(state, Runtime()) + mock_interrupt.assert_not_called() + + assert result is None + + +def test_when_predicate_batch_fires_interrupt_when_true() -> None: + """`when` returning True allows the tool call to trigger the batch interrupt.""" + middleware = HumanInTheLoopMiddleware( + interrupt_on={ + "test_tool": InterruptOnConfig( + allowed_decisions=["approve"], + when=lambda req: req.tool_call["args"].get("risky", False), + ) + } + ) + ai_message = AIMessage( + content="...", + tool_calls=[{"name": "test_tool", "args": {"risky": True}, "id": "1"}], + ) + state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message]) + + with ( + patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}), + patch( + "langchain.agents.middleware.human_in_the_loop.interrupt", + return_value={"decisions": [{"type": "approve"}]}, + ), + ): + result = middleware.after_model(state, Runtime()) + + assert result is not None + + +def test_when_predicate_receives_correct_args() -> None: + """The when predicate receives a ToolCallRequest with correct values and a ToolRuntime.""" + captured: list[Any] = [] + + def capture_when(req: ToolCallRequest) -> bool: + captured.append(req) + return True + + middleware = HumanInTheLoopMiddleware( + interrupt_on={ + "test_tool": InterruptOnConfig( + allowed_decisions=["approve"], + when=capture_when, + ) + } + ) + ai_message = AIMessage( + content="...", + tool_calls=[{"name": "test_tool", "args": {"val": 42}, "id": "tc-1"}], + ) + state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message]) + runtime = Runtime() + + with ( + patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}), + patch( + "langchain.agents.middleware.human_in_the_loop.interrupt", + return_value={"decisions": [{"type": "approve"}]}, + ), + ): + middleware.after_model(state, runtime) + + assert len(captured) == 1 + req = captured[0] + assert req.tool_call["name"] == "test_tool" + assert req.tool_call["args"] == {"val": 42} + assert req.tool is None + assert req.state is state + assert isinstance(req.runtime, ToolRuntime) + assert req.runtime.tool_call_id == "tc-1" + assert req.runtime.state is state + assert req.runtime.context is runtime.context + assert req.runtime.store is runtime.store