feat(langchain): add interrupt_mode and when predicate to HumanInTheLoopMiddleware (#37579)

Adds an optional `when` predicate to `InterruptOnConfig`, allowing
dynamic per-tool-call control over whether a HITL interrupt fires.

---

**`when` predicate in `InterruptOnConfig`**

```python
class InterruptOnConfig(TypedDict):
    allowed_decisions: list[DecisionType]
    description: NotRequired[str | _DescriptionFactory]
    args_schema: NotRequired[dict[str, Any]]
    when: NotRequired[Callable[[ToolCallRequest], bool]]  # new
```

When provided, `when` is called before adding a tool call to the batch
interrupt. If it returns `False`, the call is auto-approved and
excluded. If it returns `True` (or `when` is absent), existing behaviour
is unchanged.

The predicate receives a `ToolCallRequest` with:
- `tool_call` — the raw tool call dict (name, args, id)
- `tool` — `None` (no `BaseTool` instance is available at the
`after_model` stage)
- `state` — current agent state
- `runtime` — a `ToolRuntime` constructed from the node-level `Runtime`,
with `tool_call_id` populated

Example:

```python
HumanInTheLoopMiddleware(
    interrupt_on={
        "delete_file": InterruptOnConfig(
            allowed_decisions=["approve", "reject"],
            when=lambda req: req.tool_call["args"].get("path", "").startswith("/etc"),
        )
    }
)
```

This change is fully backwards-compatible — `when` is `NotRequired` and
existing configs without it behave identically.

> This PR was developed with AI-agent assistance.
This commit is contained in:
Sydney Runkle
2026-06-01 16:05:35 -04:00
committed by GitHub
parent bc5f1517cf
commit 36be77b0f1
2 changed files with 176 additions and 3 deletions

View File

@@ -1,9 +1,12 @@
"""Human in the loop middleware."""
from typing import Any, Literal, Protocol
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, Protocol
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langgraph.runtime import Runtime
from langgraph.config import get_config
from langgraph.prebuilt.tool_node import ToolRuntime
from langgraph.types import interrupt
from typing_extensions import NotRequired, TypedDict
@@ -13,8 +16,14 @@ from langchain.agents.middleware.types import (
ContextT,
ResponseT,
StateT,
ToolCallRequest,
)
if TYPE_CHECKING:
from collections.abc import Callable
from langgraph.runtime import Runtime
class Action(TypedDict):
"""Represents an action with a name and args."""
@@ -178,6 +187,27 @@ class InterruptOnConfig(TypedDict):
args_schema: NotRequired[dict[str, Any]]
"""JSON schema for the args associated with the action, if edits are allowed."""
when: NotRequired[Callable[[ToolCallRequest], bool]]
"""Optional predicate controlling whether to interrupt for a given tool call.
Receives a `ToolCallRequest` and returns `True` to interrupt or `False` to
auto-approve. Works in both `"batch"` and `"per_call"` modes.
In `"batch"` mode the request is constructed with `tool=None` and
`runtime` set to the node-level `Runtime` (not a `ToolRuntime`), so
`request.runtime.tool_call_id` and `request.runtime.tools` are not available.
In `"per_call"` mode the full `ToolCallRequest` from `wrap_tool_call` is passed.
Example:
```python
# Only interrupt delete_file calls targeting /etc
config = InterruptOnConfig(
allowed_decisions=["approve", "reject"],
when=lambda req: req.tool_call["args"].get("path", "").startswith("/etc"),
)
```
"""
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
"""Human in the loop middleware."""
@@ -203,6 +233,9 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
The `InterruptOnConfig` can include a `description` field (`str` or
`Callable`) for custom formatting of the interrupt description.
A `when` predicate can also be provided to dynamically control
whether a tool call triggers an interrupt.
description_prefix: The prefix to use when constructing action requests.
This is used to provide context about the tool call and the action being
@@ -310,6 +343,39 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
)
raise ValueError(msg)
def _should_interrupt(
self,
tool_call: ToolCall,
config: InterruptOnConfig,
state: AgentState[Any],
runtime: Runtime[ContextT],
) -> bool:
"""Return False if the `when` predicate rejects this tool call, True otherwise."""
when = config.get("when")
if when is None:
return True
try:
runnable_config = get_config()
except RuntimeError:
runnable_config = {}
tool_runtime = ToolRuntime(
state=state,
context=runtime.context,
config=runnable_config,
stream_writer=runtime.stream_writer,
tool_call_id=tool_call["id"],
store=runtime.store,
execution_info=runtime.execution_info,
server_info=runtime.server_info,
)
req = ToolCallRequest(
tool_call=tool_call,
tool=None,
state=state,
runtime=tool_runtime, # type: ignore[arg-type]
)
return when(req)
def after_model(
self, state: AgentState[Any], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
@@ -341,6 +407,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
if (config := self.interrupt_on.get(tool_call["name"])) is not None:
if not self._should_interrupt(tool_call, config, state, runtime):
continue
action_request, review_config = self._create_action_and_config(
tool_call, config, state, runtime
)

View File

@@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
from langgraph.prebuilt.tool_node import ToolRuntime
from langgraph.runtime import Runtime
from langchain.agents.middleware import InterruptOnConfig
@@ -11,7 +12,7 @@ from langchain.agents.middleware.human_in_the_loop import (
Action,
HumanInTheLoopMiddleware,
)
from langchain.agents.middleware.types import AgentState
from langchain.agents.middleware.types import AgentState, ToolCallRequest
def test_human_in_the_loop_middleware_initialization() -> None:
@@ -883,3 +884,107 @@ def test_human_in_the_loop_middleware_preserves_order_with_rejections() -> None:
assert isinstance(tool_message, ToolMessage)
assert tool_message.content == "Rejected tool B"
assert tool_message.tool_call_id == "id_b"
# ---------------------------------------------------------------------------
# when predicate
# ---------------------------------------------------------------------------
def test_when_predicate_batch_skips_interrupt_when_false() -> None:
"""`when` returning False prevents the tool call from joining the batch interrupt."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": InterruptOnConfig(
allowed_decisions=["approve"],
when=lambda req: req.tool_call["args"].get("risky", False),
)
}
)
ai_message = AIMessage(
content="...",
tool_calls=[{"name": "test_tool", "args": {"risky": False}, "id": "1"}],
)
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])
with (
patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}),
patch("langchain.agents.middleware.human_in_the_loop.interrupt") as mock_interrupt,
):
result = middleware.after_model(state, Runtime())
mock_interrupt.assert_not_called()
assert result is None
def test_when_predicate_batch_fires_interrupt_when_true() -> None:
"""`when` returning True allows the tool call to trigger the batch interrupt."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": InterruptOnConfig(
allowed_decisions=["approve"],
when=lambda req: req.tool_call["args"].get("risky", False),
)
}
)
ai_message = AIMessage(
content="...",
tool_calls=[{"name": "test_tool", "args": {"risky": True}, "id": "1"}],
)
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])
with (
patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}),
patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value={"decisions": [{"type": "approve"}]},
),
):
result = middleware.after_model(state, Runtime())
assert result is not None
def test_when_predicate_receives_correct_args() -> None:
"""The when predicate receives a ToolCallRequest with correct values and a ToolRuntime."""
captured: list[Any] = []
def capture_when(req: ToolCallRequest) -> bool:
captured.append(req)
return True
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"test_tool": InterruptOnConfig(
allowed_decisions=["approve"],
when=capture_when,
)
}
)
ai_message = AIMessage(
content="...",
tool_calls=[{"name": "test_tool", "args": {"val": 42}, "id": "tc-1"}],
)
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])
runtime = Runtime()
with (
patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}),
patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value={"decisions": [{"type": "approve"}]},
),
):
middleware.after_model(state, runtime)
assert len(captured) == 1
req = captured[0]
assert req.tool_call["name"] == "test_tool"
assert req.tool_call["args"] == {"val": 42}
assert req.tool is None
assert req.state is state
assert isinstance(req.runtime, ToolRuntime)
assert req.runtime.tool_call_id == "tc-1"
assert req.runtime.state is state
assert req.runtime.context is runtime.context
assert req.runtime.store is runtime.store