feat(langchain): add interrupt_mode and when predicate to HumanInTheLoopMiddleware (#37579)

Adds an optional `when` predicate to `InterruptOnConfig`, allowing
dynamic per-tool-call control over whether a HITL interrupt fires.

---

**`when` predicate in `InterruptOnConfig`**

```python
class InterruptOnConfig(TypedDict):
    allowed_decisions: list[DecisionType]
    description: NotRequired[str | _DescriptionFactory]
    args_schema: NotRequired[dict[str, Any]]
    when: NotRequired[Callable[[ToolCallRequest], bool]]  # new
```

When provided, `when` is called before adding a tool call to the batch
interrupt. If it returns `False`, the call is auto-approved and
excluded. If it returns `True` (or `when` is absent), existing behaviour
is unchanged.

The predicate receives a `ToolCallRequest` with:
- `tool_call` — the raw tool call dict (name, args, id)
- `tool` — `None` (no `BaseTool` instance is available at the
`after_model` stage)
- `state` — current agent state
- `runtime` — a `ToolRuntime` constructed from the node-level `Runtime`,
with `tool_call_id` populated

Example:

```python
HumanInTheLoopMiddleware(
    interrupt_on={
        "delete_file": InterruptOnConfig(
            allowed_decisions=["approve", "reject"],
            when=lambda req: req.tool_call["args"].get("path", "").startswith("/etc"),
        )
    }
)
```

This change is fully backwards-compatible — `when` is `NotRequired` and
existing configs without it behave identically.

> This PR was developed with AI-agent assistance.
This commit is contained in:
Sydney Runkle
2026-06-01 16:05:35 -04:00
committed by GitHub
parent bc5f1517cf
commit 36be77b0f1
2 changed files with 176 additions and 3 deletions

View File

@@ -1,9 +1,12 @@
"""Human in the loop middleware.""" """Human in the loop middleware."""
from typing import Any, Literal, Protocol from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, Protocol
from langchain_core.messages import AIMessage, ToolCall, ToolMessage from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langgraph.runtime import Runtime from langgraph.config import get_config
from langgraph.prebuilt.tool_node import ToolRuntime
from langgraph.types import interrupt from langgraph.types import interrupt
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict
@@ -13,8 +16,14 @@ from langchain.agents.middleware.types import (
ContextT, ContextT,
ResponseT, ResponseT,
StateT, StateT,
ToolCallRequest,
) )
if TYPE_CHECKING:
from collections.abc import Callable
from langgraph.runtime import Runtime
class Action(TypedDict): class Action(TypedDict):
"""Represents an action with a name and args.""" """Represents an action with a name and args."""
@@ -178,6 +187,27 @@ class InterruptOnConfig(TypedDict):
args_schema: NotRequired[dict[str, Any]] args_schema: NotRequired[dict[str, Any]]
"""JSON schema for the args associated with the action, if edits are allowed.""" """JSON schema for the args associated with the action, if edits are allowed."""
when: NotRequired[Callable[[ToolCallRequest], bool]]
"""Optional predicate controlling whether to interrupt for a given tool call.
Receives a `ToolCallRequest` and returns `True` to interrupt or `False` to
auto-approve. Works in both `"batch"` and `"per_call"` modes.
In `"batch"` mode the request is constructed with `tool=None` and
`runtime` set to the node-level `Runtime` (not a `ToolRuntime`), so
`request.runtime.tool_call_id` and `request.runtime.tools` are not available.
In `"per_call"` mode the full `ToolCallRequest` from `wrap_tool_call` is passed.
Example:
```python
# Only interrupt delete_file calls targeting /etc
config = InterruptOnConfig(
allowed_decisions=["approve", "reject"],
when=lambda req: req.tool_call["args"].get("path", "").startswith("/etc"),
)
```
"""
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]): class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
"""Human in the loop middleware.""" """Human in the loop middleware."""
@@ -203,6 +233,9 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
The `InterruptOnConfig` can include a `description` field (`str` or The `InterruptOnConfig` can include a `description` field (`str` or
`Callable`) for custom formatting of the interrupt description. `Callable`) for custom formatting of the interrupt description.
A `when` predicate can also be provided to dynamically control
whether a tool call triggers an interrupt.
description_prefix: The prefix to use when constructing action requests. description_prefix: The prefix to use when constructing action requests.
This is used to provide context about the tool call and the action being This is used to provide context about the tool call and the action being
@@ -310,6 +343,39 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
) )
raise ValueError(msg) raise ValueError(msg)
def _should_interrupt(
self,
tool_call: ToolCall,
config: InterruptOnConfig,
state: AgentState[Any],
runtime: Runtime[ContextT],
) -> bool:
"""Return False if the `when` predicate rejects this tool call, True otherwise."""
when = config.get("when")
if when is None:
return True
try:
runnable_config = get_config()
except RuntimeError:
runnable_config = {}
tool_runtime = ToolRuntime(
state=state,
context=runtime.context,
config=runnable_config,
stream_writer=runtime.stream_writer,
tool_call_id=tool_call["id"],
store=runtime.store,
execution_info=runtime.execution_info,
server_info=runtime.server_info,
)
req = ToolCallRequest(
tool_call=tool_call,
tool=None,
state=state,
runtime=tool_runtime, # type: ignore[arg-type]
)
return when(req)
def after_model( def after_model(
self, state: AgentState[Any], runtime: Runtime[ContextT] self, state: AgentState[Any], runtime: Runtime[ContextT]
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
@@ -341,6 +407,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
for idx, tool_call in enumerate(last_ai_msg.tool_calls): for idx, tool_call in enumerate(last_ai_msg.tool_calls):
if (config := self.interrupt_on.get(tool_call["name"])) is not None: if (config := self.interrupt_on.get(tool_call["name"])) is not None:
if not self._should_interrupt(tool_call, config, state, runtime):
continue
action_request, review_config = self._create_action_and_config( action_request, review_config = self._create_action_and_config(
tool_call, config, state, runtime tool_call, config, state, runtime
) )

View File

@@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
from langgraph.prebuilt.tool_node import ToolRuntime
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from langchain.agents.middleware import InterruptOnConfig from langchain.agents.middleware import InterruptOnConfig
@@ -11,7 +12,7 @@ from langchain.agents.middleware.human_in_the_loop import (
Action, Action,
HumanInTheLoopMiddleware, HumanInTheLoopMiddleware,
) )
from langchain.agents.middleware.types import AgentState from langchain.agents.middleware.types import AgentState, ToolCallRequest
def test_human_in_the_loop_middleware_initialization() -> None: def test_human_in_the_loop_middleware_initialization() -> None:
@@ -883,3 +884,107 @@ def test_human_in_the_loop_middleware_preserves_order_with_rejections() -> None:
assert isinstance(tool_message, ToolMessage) assert isinstance(tool_message, ToolMessage)
assert tool_message.content == "Rejected tool B" assert tool_message.content == "Rejected tool B"
assert tool_message.tool_call_id == "id_b" assert tool_message.tool_call_id == "id_b"
# ---------------------------------------------------------------------------
# when predicate
# ---------------------------------------------------------------------------
def test_when_predicate_batch_skips_interrupt_when_false() -> None:
"""`when` returning False prevents the tool call from joining the batch interrupt."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": InterruptOnConfig(
allowed_decisions=["approve"],
when=lambda req: req.tool_call["args"].get("risky", False),
)
}
)
ai_message = AIMessage(
content="...",
tool_calls=[{"name": "test_tool", "args": {"risky": False}, "id": "1"}],
)
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])
with (
patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}),
patch("langchain.agents.middleware.human_in_the_loop.interrupt") as mock_interrupt,
):
result = middleware.after_model(state, Runtime())
mock_interrupt.assert_not_called()
assert result is None
def test_when_predicate_batch_fires_interrupt_when_true() -> None:
"""`when` returning True allows the tool call to trigger the batch interrupt."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": InterruptOnConfig(
allowed_decisions=["approve"],
when=lambda req: req.tool_call["args"].get("risky", False),
)
}
)
ai_message = AIMessage(
content="...",
tool_calls=[{"name": "test_tool", "args": {"risky": True}, "id": "1"}],
)
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])
with (
patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}),
patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value={"decisions": [{"type": "approve"}]},
),
):
result = middleware.after_model(state, Runtime())
assert result is not None
def test_when_predicate_receives_correct_args() -> None:
"""The when predicate receives a ToolCallRequest with correct values and a ToolRuntime."""
captured: list[Any] = []
def capture_when(req: ToolCallRequest) -> bool:
captured.append(req)
return True
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": InterruptOnConfig(
allowed_decisions=["approve"],
when=capture_when,
)
}
)
ai_message = AIMessage(
content="...",
tool_calls=[{"name": "test_tool", "args": {"val": 42}, "id": "tc-1"}],
)
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])
runtime = Runtime()
with (
patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}),
patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value={"decisions": [{"type": "approve"}]},
),
):
middleware.after_model(state, runtime)
assert len(captured) == 1
req = captured[0]
assert req.tool_call["name"] == "test_tool"
assert req.tool_call["args"] == {"val": 42}
assert req.tool is None
assert req.state is state
assert isinstance(req.runtime, ToolRuntime)
assert req.runtime.tool_call_id == "tc-1"
assert req.runtime.state is state
assert req.runtime.context is runtime.context
assert req.runtime.store is runtime.store