mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(langchain): add interrupt_mode and when predicate to HumanInTheLoopMiddleware (#37579)
Adds an optional `when` predicate to `InterruptOnConfig`, allowing
dynamic per-tool-call control over whether a HITL interrupt fires.
---
**`when` predicate in `InterruptOnConfig`**
```python
class InterruptOnConfig(TypedDict):
allowed_decisions: list[DecisionType]
description: NotRequired[str | _DescriptionFactory]
args_schema: NotRequired[dict[str, Any]]
when: NotRequired[Callable[[ToolCallRequest], bool]] # new
```
When provided, `when` is called before adding a tool call to the batch
interrupt. If it returns `False`, the call is auto-approved and
excluded. If it returns `True` (or `when` is absent), existing behaviour
is unchanged.
The predicate receives a `ToolCallRequest` with:
- `tool_call` — the raw tool call dict (name, args, id)
- `tool` — `None` (no `BaseTool` instance is available at the
`after_model` stage)
- `state` — current agent state
- `runtime` — a `ToolRuntime` constructed from the node-level `Runtime`,
with `tool_call_id` populated
Example:
```python
HumanInTheLoopMiddleware(
interrupt_on={
"delete_file": InterruptOnConfig(
allowed_decisions=["approve", "reject"],
when=lambda req: req.tool_call["args"].get("path", "").startswith("/etc"),
)
}
)
```
This change is fully backwards-compatible — `when` is `NotRequired` and
existing configs without it behave identically.
> This PR was developed with AI-agent assistance.
This commit is contained in:
@@ -1,9 +1,12 @@
|
|||||||
"""Human in the loop middleware."""
|
"""Human in the loop middleware."""
|
||||||
|
|
||||||
from typing import Any, Literal, Protocol
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, Protocol
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.config import get_config
|
||||||
|
from langgraph.prebuilt.tool_node import ToolRuntime
|
||||||
from langgraph.types import interrupt
|
from langgraph.types import interrupt
|
||||||
from typing_extensions import NotRequired, TypedDict
|
from typing_extensions import NotRequired, TypedDict
|
||||||
|
|
||||||
@@ -13,8 +16,14 @@ from langchain.agents.middleware.types import (
|
|||||||
ContextT,
|
ContextT,
|
||||||
ResponseT,
|
ResponseT,
|
||||||
StateT,
|
StateT,
|
||||||
|
ToolCallRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
|
||||||
class Action(TypedDict):
|
class Action(TypedDict):
|
||||||
"""Represents an action with a name and args."""
|
"""Represents an action with a name and args."""
|
||||||
@@ -178,6 +187,27 @@ class InterruptOnConfig(TypedDict):
|
|||||||
args_schema: NotRequired[dict[str, Any]]
|
args_schema: NotRequired[dict[str, Any]]
|
||||||
"""JSON schema for the args associated with the action, if edits are allowed."""
|
"""JSON schema for the args associated with the action, if edits are allowed."""
|
||||||
|
|
||||||
|
when: NotRequired[Callable[[ToolCallRequest], bool]]
|
||||||
|
"""Optional predicate controlling whether to interrupt for a given tool call.
|
||||||
|
|
||||||
|
Receives a `ToolCallRequest` and returns `True` to interrupt or `False` to
|
||||||
|
auto-approve. Works in both `"batch"` and `"per_call"` modes.
|
||||||
|
|
||||||
|
In `"batch"` mode the request is constructed with `tool=None` and
|
||||||
|
`runtime` set to the node-level `Runtime` (not a `ToolRuntime`), so
|
||||||
|
`request.runtime.tool_call_id` and `request.runtime.tools` are not available.
|
||||||
|
In `"per_call"` mode the full `ToolCallRequest` from `wrap_tool_call` is passed.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# Only interrupt delete_file calls targeting /etc
|
||||||
|
config = InterruptOnConfig(
|
||||||
|
allowed_decisions=["approve", "reject"],
|
||||||
|
when=lambda req: req.tool_call["args"].get("path", "").startswith("/etc"),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
|
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
|
||||||
"""Human in the loop middleware."""
|
"""Human in the loop middleware."""
|
||||||
@@ -203,6 +233,9 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
|
|||||||
|
|
||||||
The `InterruptOnConfig` can include a `description` field (`str` or
|
The `InterruptOnConfig` can include a `description` field (`str` or
|
||||||
`Callable`) for custom formatting of the interrupt description.
|
`Callable`) for custom formatting of the interrupt description.
|
||||||
|
|
||||||
|
A `when` predicate can also be provided to dynamically control
|
||||||
|
whether a tool call triggers an interrupt.
|
||||||
description_prefix: The prefix to use when constructing action requests.
|
description_prefix: The prefix to use when constructing action requests.
|
||||||
|
|
||||||
This is used to provide context about the tool call and the action being
|
This is used to provide context about the tool call and the action being
|
||||||
@@ -310,6 +343,39 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
def _should_interrupt(
|
||||||
|
self,
|
||||||
|
tool_call: ToolCall,
|
||||||
|
config: InterruptOnConfig,
|
||||||
|
state: AgentState[Any],
|
||||||
|
runtime: Runtime[ContextT],
|
||||||
|
) -> bool:
|
||||||
|
"""Return False if the `when` predicate rejects this tool call, True otherwise."""
|
||||||
|
when = config.get("when")
|
||||||
|
if when is None:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
runnable_config = get_config()
|
||||||
|
except RuntimeError:
|
||||||
|
runnable_config = {}
|
||||||
|
tool_runtime = ToolRuntime(
|
||||||
|
state=state,
|
||||||
|
context=runtime.context,
|
||||||
|
config=runnable_config,
|
||||||
|
stream_writer=runtime.stream_writer,
|
||||||
|
tool_call_id=tool_call["id"],
|
||||||
|
store=runtime.store,
|
||||||
|
execution_info=runtime.execution_info,
|
||||||
|
server_info=runtime.server_info,
|
||||||
|
)
|
||||||
|
req = ToolCallRequest(
|
||||||
|
tool_call=tool_call,
|
||||||
|
tool=None,
|
||||||
|
state=state,
|
||||||
|
runtime=tool_runtime, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
return when(req)
|
||||||
|
|
||||||
def after_model(
|
def after_model(
|
||||||
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
@@ -341,6 +407,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
|
|||||||
|
|
||||||
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
|
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
|
||||||
if (config := self.interrupt_on.get(tool_call["name"])) is not None:
|
if (config := self.interrupt_on.get(tool_call["name"])) is not None:
|
||||||
|
if not self._should_interrupt(tool_call, config, state, runtime):
|
||||||
|
continue
|
||||||
action_request, review_config = self._create_action_and_config(
|
action_request, review_config = self._create_action_and_config(
|
||||||
tool_call, config, state, runtime
|
tool_call, config, state, runtime
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
|
||||||
|
from langgraph.prebuilt.tool_node import ToolRuntime
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from langchain.agents.middleware import InterruptOnConfig
|
from langchain.agents.middleware import InterruptOnConfig
|
||||||
@@ -11,7 +12,7 @@ from langchain.agents.middleware.human_in_the_loop import (
|
|||||||
Action,
|
Action,
|
||||||
HumanInTheLoopMiddleware,
|
HumanInTheLoopMiddleware,
|
||||||
)
|
)
|
||||||
from langchain.agents.middleware.types import AgentState
|
from langchain.agents.middleware.types import AgentState, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
def test_human_in_the_loop_middleware_initialization() -> None:
|
def test_human_in_the_loop_middleware_initialization() -> None:
|
||||||
@@ -883,3 +884,107 @@ def test_human_in_the_loop_middleware_preserves_order_with_rejections() -> None:
|
|||||||
assert isinstance(tool_message, ToolMessage)
|
assert isinstance(tool_message, ToolMessage)
|
||||||
assert tool_message.content == "Rejected tool B"
|
assert tool_message.content == "Rejected tool B"
|
||||||
assert tool_message.tool_call_id == "id_b"
|
assert tool_message.tool_call_id == "id_b"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# when predicate
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_when_predicate_batch_skips_interrupt_when_false() -> None:
|
||||||
|
"""`when` returning False prevents the tool call from joining the batch interrupt."""
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
interrupt_on={
|
||||||
|
"test_tool": InterruptOnConfig(
|
||||||
|
allowed_decisions=["approve"],
|
||||||
|
when=lambda req: req.tool_call["args"].get("risky", False),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="...",
|
||||||
|
tool_calls=[{"name": "test_tool", "args": {"risky": False}, "id": "1"}],
|
||||||
|
)
|
||||||
|
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}),
|
||||||
|
patch("langchain.agents.middleware.human_in_the_loop.interrupt") as mock_interrupt,
|
||||||
|
):
|
||||||
|
result = middleware.after_model(state, Runtime())
|
||||||
|
mock_interrupt.assert_not_called()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_when_predicate_batch_fires_interrupt_when_true() -> None:
|
||||||
|
"""`when` returning True allows the tool call to trigger the batch interrupt."""
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
interrupt_on={
|
||||||
|
"test_tool": InterruptOnConfig(
|
||||||
|
allowed_decisions=["approve"],
|
||||||
|
when=lambda req: req.tool_call["args"].get("risky", False),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="...",
|
||||||
|
tool_calls=[{"name": "test_tool", "args": {"risky": True}, "id": "1"}],
|
||||||
|
)
|
||||||
|
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}),
|
||||||
|
patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||||
|
return_value={"decisions": [{"type": "approve"}]},
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = middleware.after_model(state, Runtime())
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_when_predicate_receives_correct_args() -> None:
|
||||||
|
"""The when predicate receives a ToolCallRequest with correct values and a ToolRuntime."""
|
||||||
|
captured: list[Any] = []
|
||||||
|
|
||||||
|
def capture_when(req: ToolCallRequest) -> bool:
|
||||||
|
captured.append(req)
|
||||||
|
return True
|
||||||
|
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
interrupt_on={
|
||||||
|
"test_tool": InterruptOnConfig(
|
||||||
|
allowed_decisions=["approve"],
|
||||||
|
when=capture_when,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="...",
|
||||||
|
tool_calls=[{"name": "test_tool", "args": {"val": 42}, "id": "tc-1"}],
|
||||||
|
)
|
||||||
|
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])
|
||||||
|
runtime = Runtime()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("langchain.agents.middleware.human_in_the_loop.get_config", return_value={}),
|
||||||
|
patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||||
|
return_value={"decisions": [{"type": "approve"}]},
|
||||||
|
),
|
||||||
|
):
|
||||||
|
middleware.after_model(state, runtime)
|
||||||
|
|
||||||
|
assert len(captured) == 1
|
||||||
|
req = captured[0]
|
||||||
|
assert req.tool_call["name"] == "test_tool"
|
||||||
|
assert req.tool_call["args"] == {"val": 42}
|
||||||
|
assert req.tool is None
|
||||||
|
assert req.state is state
|
||||||
|
assert isinstance(req.runtime, ToolRuntime)
|
||||||
|
assert req.runtime.tool_call_id == "tc-1"
|
||||||
|
assert req.runtime.state is state
|
||||||
|
assert req.runtime.context is runtime.context
|
||||||
|
assert req.runtime.store is runtime.store
|
||||||
|
|||||||
Reference in New Issue
Block a user