mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(langchain): add interrupt_mode and when predicate to HumanInTheLoopMiddleware (#37579)
Adds an optional `when` predicate to `InterruptOnConfig`, allowing
dynamic per-tool-call control over whether a HITL interrupt fires.
---
**`when` predicate in `InterruptOnConfig`**
```python
class InterruptOnConfig(TypedDict):
allowed_decisions: list[DecisionType]
description: NotRequired[str | _DescriptionFactory]
args_schema: NotRequired[dict[str, Any]]
when: NotRequired[Callable[[ToolCallRequest], bool]] # new
```
When provided, `when` is called before adding a tool call to the batch
interrupt. If it returns `False`, the call is auto-approved and
excluded. If it returns `True` (or `when` is absent), existing behaviour
is unchanged.
The predicate receives a `ToolCallRequest` with:
- `tool_call` — the raw tool call dict (name, args, id)
- `tool` — `None` (no `BaseTool` instance is available at the
`after_model` stage)
- `state` — current agent state
- `runtime` — a `ToolRuntime` constructed from the node-level `Runtime`,
with `tool_call_id` populated
Example:
```python
HumanInTheLoopMiddleware(
interrupt_on={
"delete_file": InterruptOnConfig(
allowed_decisions=["approve", "reject"],
when=lambda req: req.tool_call["args"].get("path", "").startswith("/etc"),
)
}
)
```
This change is fully backwards-compatible — `when` is `NotRequired` and
existing configs without it behave identically.
> This PR was developed with AI-agent assistance.
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
"""Human in the loop middleware."""
|
||||
|
||||
from typing import Any, Literal, Protocol
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Literal, Protocol
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.config import get_config
|
||||
from langgraph.prebuilt.tool_node import ToolRuntime
|
||||
from langgraph.types import interrupt
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
@@ -13,8 +16,14 @@ from langchain.agents.middleware.types import (
|
||||
ContextT,
|
||||
ResponseT,
|
||||
StateT,
|
||||
ToolCallRequest,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
class Action(TypedDict):
|
||||
"""Represents an action with a name and args."""
|
||||
@@ -178,6 +187,27 @@ class InterruptOnConfig(TypedDict):
|
||||
args_schema: NotRequired[dict[str, Any]]
|
||||
"""JSON schema for the args associated with the action, if edits are allowed."""
|
||||
|
||||
when: NotRequired[Callable[[ToolCallRequest], bool]]
|
||||
"""Optional predicate controlling whether to interrupt for a given tool call.
|
||||
|
||||
Receives a `ToolCallRequest` and returns `True` to interrupt or `False` to
|
||||
auto-approve. Works in both `"batch"` and `"per_call"` modes.
|
||||
|
||||
In `"batch"` mode the request is constructed with `tool=None` and
|
||||
`runtime` set to the node-level `Runtime` (not a `ToolRuntime`), so
|
||||
`request.runtime.tool_call_id` and `request.runtime.tools` are not available.
|
||||
In `"per_call"` mode the full `ToolCallRequest` from `wrap_tool_call` is passed.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Only interrupt delete_file calls targeting /etc
|
||||
config = InterruptOnConfig(
|
||||
allowed_decisions=["approve", "reject"],
|
||||
when=lambda req: req.tool_call["args"].get("path", "").startswith("/etc"),
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
|
||||
"""Human in the loop middleware."""
|
||||
@@ -203,6 +233,9 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
|
||||
|
||||
The `InterruptOnConfig` can include a `description` field (`str` or
|
||||
`Callable`) for custom formatting of the interrupt description.
|
||||
|
||||
A `when` predicate can also be provided to dynamically control
|
||||
whether a tool call triggers an interrupt.
|
||||
description_prefix: The prefix to use when constructing action requests.
|
||||
|
||||
This is used to provide context about the tool call and the action being
|
||||
@@ -310,6 +343,39 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def _should_interrupt(
|
||||
self,
|
||||
tool_call: ToolCall,
|
||||
config: InterruptOnConfig,
|
||||
state: AgentState[Any],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> bool:
|
||||
"""Return False if the `when` predicate rejects this tool call, True otherwise."""
|
||||
when = config.get("when")
|
||||
if when is None:
|
||||
return True
|
||||
try:
|
||||
runnable_config = get_config()
|
||||
except RuntimeError:
|
||||
runnable_config = {}
|
||||
tool_runtime = ToolRuntime(
|
||||
state=state,
|
||||
context=runtime.context,
|
||||
config=runnable_config,
|
||||
stream_writer=runtime.stream_writer,
|
||||
tool_call_id=tool_call["id"],
|
||||
store=runtime.store,
|
||||
execution_info=runtime.execution_info,
|
||||
server_info=runtime.server_info,
|
||||
)
|
||||
req = ToolCallRequest(
|
||||
tool_call=tool_call,
|
||||
tool=None,
|
||||
state=state,
|
||||
runtime=tool_runtime, # type: ignore[arg-type]
|
||||
)
|
||||
return when(req)
|
||||
|
||||
def after_model(
|
||||
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
@@ -341,6 +407,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
|
||||
|
||||
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
|
||||
if (config := self.interrupt_on.get(tool_call["name"])) is not None:
|
||||
if not self._should_interrupt(tool_call, config, state, runtime):
|
||||
continue
|
||||
action_request, review_config = self._create_action_and_config(
|
||||
tool_call, config, state, runtime
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
|
||||
from langgraph.prebuilt.tool_node import ToolRuntime
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain.agents.middleware import InterruptOnConfig
|
||||
@@ -11,7 +12,7 @@ from langchain.agents.middleware.human_in_the_loop import (
|
||||
Action,
|
||||
HumanInTheLoopMiddleware,
|
||||
)
|
||||
from langchain.agents.middleware.types import AgentState
|
||||
from langchain.agents.middleware.types import AgentState, ToolCallRequest
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_initialization() -> None:
|
||||
@@ -883,3 +884,107 @@ def test_human_in_the_loop_middleware_preserves_order_with_rejections() -> None:
|
||||
assert isinstance(tool_message, ToolMessage)
|
||||
assert tool_message.content == "Rejected tool B"
|
||||
assert tool_message.tool_call_id == "id_b"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# when predicate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_when_predicate_batch_skips_interrupt_when_false() -> None:
|
||||
"""`when` returning False prevents the tool call from joining the batch interrupt."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"test_tool": InterruptOnConfig(
|
||||
allowed_decisions=["approve"],
|
||||
when=lambda req: req.tool_call["args"].get("risky", False),
|
||||
)
|
||||
}
|
||||
)
|
||||
ai_message = AIMessage(
|
||||
content="...",
|
||||
tool_calls=[{"name": "test_tool", "args": {"risky": False}, "id": "1"}],
|
||||
)
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])
|
||||
|
||||
with (
|
||||
patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}),
|
||||
patch("langchain.agents.middleware.human_in_the_loop.interrupt") as mock_interrupt,
|
||||
):
|
||||
result = middleware.after_model(state, Runtime())
|
||||
mock_interrupt.assert_not_called()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_when_predicate_batch_fires_interrupt_when_true() -> None:
|
||||
"""`when` returning True allows the tool call to trigger the batch interrupt."""
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"test_tool": InterruptOnConfig(
|
||||
allowed_decisions=["approve"],
|
||||
when=lambda req: req.tool_call["args"].get("risky", False),
|
||||
)
|
||||
}
|
||||
)
|
||||
ai_message = AIMessage(
|
||||
content="...",
|
||||
tool_calls=[{"name": "test_tool", "args": {"risky": True}, "id": "1"}],
|
||||
)
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])
|
||||
|
||||
with (
|
||||
patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}),
|
||||
patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value={"decisions": [{"type": "approve"}]},
|
||||
),
|
||||
):
|
||||
result = middleware.after_model(state, Runtime())
|
||||
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_when_predicate_receives_correct_args() -> None:
|
||||
"""The when predicate receives a ToolCallRequest with correct values and a ToolRuntime."""
|
||||
captured: list[Any] = []
|
||||
|
||||
def capture_when(req: ToolCallRequest) -> bool:
|
||||
captured.append(req)
|
||||
return True
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"test_tool": InterruptOnConfig(
|
||||
allowed_decisions=["approve"],
|
||||
when=capture_when,
|
||||
)
|
||||
}
|
||||
)
|
||||
ai_message = AIMessage(
|
||||
content="...",
|
||||
tool_calls=[{"name": "test_tool", "args": {"val": 42}, "id": "tc-1"}],
|
||||
)
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])
|
||||
runtime = Runtime()
|
||||
|
||||
with (
|
||||
patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}),
|
||||
patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value={"decisions": [{"type": "approve"}]},
|
||||
),
|
||||
):
|
||||
middleware.after_model(state, runtime)
|
||||
|
||||
assert len(captured) == 1
|
||||
req = captured[0]
|
||||
assert req.tool_call["name"] == "test_tool"
|
||||
assert req.tool_call["args"] == {"val": 42}
|
||||
assert req.tool is None
|
||||
assert req.state is state
|
||||
assert isinstance(req.runtime, ToolRuntime)
|
||||
assert req.runtime.tool_call_id == "tc-1"
|
||||
assert req.runtime.state is state
|
||||
assert req.runtime.context is runtime.context
|
||||
assert req.runtime.store is runtime.store
|
||||
|
||||
Reference in New Issue
Block a user