feat(langchain): improved HITL patterns (#32996)

# Main changes / new features

## Better support for parallel tool calls

1. Support for multiple tool calls requiring human input
2. Support for a mix of tool calls requiring human input and tool calls that are auto-approved (see the sketch below)
3. Support for structured output alongside tool calls requiring human input
4. Support for structured output alongside standard tool calls
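
As a rough illustration of the parallel case, here is a minimal sketch of wiring the middleware into an agent where one tool requires approval and another is auto-approved. It assumes `create_agent(..., middleware=[...])` and a checkpointer-compiled graph so interrupts can pause and resume, and uses placeholder tools and model; treat the exact wiring as illustrative rather than the final API.

```py
# Sketch only: assumes `create_agent` accepts a `middleware=[...]` argument and that
# the graph is compiled with a checkpointer so interrupts can pause and resume.
from langchain.agents.middleware_agent import create_agent
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
from langchain.chat_models import init_chat_model
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver


@tool
def send_email(to: str, body: str) -> str:
    """Send an email."""
    return f"Sent to {to}"


@tool
def search_docs(query: str) -> str:
    """Search internal docs."""
    return f"Results for {query}"


hitl = HumanInTheLoopMiddleware(
    tool_configs={
        "send_email": True,  # requires approval; all actions allowed
        # "search_docs" has no entry, so it is auto-approved
    }
)

agent = create_agent(
    model=init_chat_model("openai:gpt-4o-mini"),  # placeholder model
    tools=[send_email, search_docs],
    middleware=[hitl],
).compile(checkpointer=InMemorySaver())

config = {"configurable": {"thread_id": "1"}}
# If the model calls send_email and search_docs in parallel, the run pauses on a single
# interrupt carrying one request per guarded tool call; all tool execution (including the
# auto-approved call) happens after the interrupt is resolved.
result = agent.invoke({"messages": [("user", "Email the team the latest doc summary")]}, config)
```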

## Shortcut for allowed actions

Adds a shortcut where a tool's config can be specified as a `bool`: `True` means all actions are allowed (accept, edit, and respond), while `False` means the tool is auto-approved.

```py
HumanInTheLoopMiddleware(tool_configs={"expensive_tool": True})
```
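
For reference, a minimal sketch of the equivalent explicit form using the `ToolConfig` shape from this diff; the `description` field is optional and, when set, replaces the text generated from `description_prefix` for that tool.

```py
from langchain.agents.middleware.human_in_the_loop import (
    HumanInTheLoopMiddleware,
    ToolConfig,
)

# Equivalent to tool_configs={"expensive_tool": True}, spelled out explicitly.
HumanInTheLoopMiddleware(
    tool_configs={
        "expensive_tool": ToolConfig(
            allow_accept=True,
            allow_edit=True,
            allow_respond=True,
            # Optional: overrides the auto-generated description for this tool.
            description="Please review this expensive operation before it runs.",
        )
    }
)
```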

## A few design decisions here
* We raise a single interrupt carrying all of the `HumanInterrupt`s, and no tools execute until every request is resolved. This isn't a major blocker, since we can't re-invoke the model until all tool calls have finished anyway. That said, if you have a long-running auto-approved tool, this could slow things down (see the resume sketch below).
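
To make the single-interrupt behavior concrete, here is a rough sketch of resuming a paused run, continuing the agent sketch above. It assumes the graph was compiled with a checkpointer and that the model made two guarded tool calls; responses must come back in the same order as the requests, and the middleware raises if the counts don't match. The state-inspection API shape may differ by `langgraph` version.

```py
from langgraph.types import Command

# Inspect the pending interrupt: it carries one request per guarded tool call.
# (Assumes the langgraph state API; the exact attribute path may vary by version.)
requests = agent.get_state(config).tasks[0].interrupts[0].value

# Resume with one response per request, in the same order as the requests.
agent.invoke(
    Command(
        resume=[
            {"type": "accept"},  # first guarded tool call runs as-is
            {"type": "response", "args": "Too risky, skip this one."},  # second is rejected with feedback
        ]
    ),
    config,
)
```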

## TODOs

* Ideally, we would rename `accept` -> `approve`
* Ideally, we would rename `respond` -> `reject`
* Docs update (@sydney-runkle to own)
* In another PR, I'd like to refactor the tests so each prebuilt middleware has its own file :)

Fast follow to https://github.com/langchain-ai/langchain/pull/32962, which was deemed too breaking.
Author: Sydney Runkle
Committed by: GitHub
Date: 2025-09-17 16:53:01 -04:00
Parent: 76d0758007
Commit: d5ba5d3511

5 changed files with 770 additions and 271 deletions


@@ -1,97 +0,0 @@
"""Interrupt types to use with agent inbox like setups."""
from typing import Literal
from typing_extensions import TypedDict
class HumanInterruptConfig(TypedDict):
"""Configuration that defines what actions are allowed for a human interrupt.
This controls the available interaction options when the graph is paused for human input.
Attributes:
allow_ignore: Whether the human can choose to ignore/skip the current step
allow_respond: Whether the human can provide a text response/feedback
allow_edit: Whether the human can edit the provided content/state
allow_accept: Whether the human can accept/approve the current state
"""
allow_ignore: bool
allow_respond: bool
allow_edit: bool
allow_accept: bool
class ActionRequest(TypedDict):
"""Represents a request for human action within the graph execution.
Contains the action type and any associated arguments needed for the action.
Attributes:
action: The type or name of action being requested (e.g., "Approve XYZ action")
args: Key-value pairs of arguments needed for the action
"""
action: str
args: dict
class HumanInterrupt(TypedDict):
"""Represents an interrupt triggered by the graph that requires human intervention.
This is passed to the `interrupt` function when execution is paused for human input.
Attributes:
action_request: The specific action being requested from the human
config: Configuration defining what actions are allowed
description: Optional detailed description of what input is needed
Example:
```python
# Extract a tool call from the state and create an interrupt request
request = HumanInterrupt(
action_request=ActionRequest(
action="run_command", # The action being requested
args={"command": "ls", "args": ["-l"]}, # Arguments for the action
),
config=HumanInterruptConfig(
allow_ignore=True, # Allow skipping this step
allow_respond=True, # Allow text feedback
allow_edit=False, # Don't allow editing
allow_accept=True, # Allow direct acceptance
),
description="Please review the command before execution",
)
# Send the interrupt request and get the response
response = interrupt([request])[0]
```
"""
action_request: ActionRequest
config: HumanInterruptConfig
description: str | None
class HumanResponse(TypedDict):
"""Human response.
The response provided by a human to an interrupt,
which is returned when graph execution resumes.
Attributes:
type: The type of response:
- "accept": Approves the current state without changes
- "ignore": Skips/ignores the current step
- "response": Provides text feedback or instructions
- "edit": Modifies the current state/content
args: The response payload:
- None: For ignore/accept actions
- str: For text responses
- ActionRequest: For edit actions with updated content
"""
type: Literal["accept", "ignore", "response", "edit"]
args: None | str | ActionRequest


@@ -1,11 +0,0 @@
"""Utility functions for middleware."""
from typing import Any
def _generate_correction_tool_messages(content: str, tool_calls: list) -> list[dict[str, Any]]:
"""Generate tool messages for model behavior correction."""
return [
{"role": "tool", "content": content, "tool_call_id": tool_call["id"]}
for tool_call in tool_calls
]


@@ -1,19 +1,110 @@
"""Human in the loop middleware.""" """Human in the loop middleware."""
from typing import Any from typing import Any, Literal
from langgraph.prebuilt.interrupt import ( from langchain_core.messages import AIMessage, ToolCall, ToolMessage
ActionRequest,
HumanInterrupt,
HumanInterruptConfig,
HumanResponse,
)
from langgraph.types import interrupt from langgraph.types import interrupt
from typing_extensions import NotRequired, TypedDict
from langchain.agents.middleware._utils import _generate_correction_tool_messages
from langchain.agents.middleware.types import AgentMiddleware, AgentState from langchain.agents.middleware.types import AgentMiddleware, AgentState
ToolInterruptConfig = dict[str, HumanInterruptConfig]
class HumanInTheLoopConfig(TypedDict):
"""Configuration that defines what actions are allowed for a human interrupt.
This controls the available interaction options when the graph is paused for human input.
"""
allow_accept: NotRequired[bool]
"""Whether the human can approve the current action without changes."""
allow_edit: NotRequired[bool]
"""Whether the human can approve the current action with edited content."""
allow_respond: NotRequired[bool]
"""Whether the human can reject the current action with feedback."""
class ActionRequest(TypedDict):
"""Represents a request with a name and arguments."""
action: str
"""The type or name of action being requested (e.g., "add_numbers")."""
args: dict
"""Key-value pairs of arguments needed for the action (e.g., {"a": 1, "b": 2})."""
class HumanInTheLoopRequest(TypedDict):
"""Represents an interrupt triggered by the graph that requires human intervention.
Example:
```python
# Extract a tool call from the state and create an interrupt request
request = HumanInterrupt(
action_request=ActionRequest(
action="run_command", # The action being requested
args={"command": "ls", "args": ["-l"]}, # Arguments for the action
),
config=HumanInTheLoopConfig(
allow_accept=True, # Allow approval
allow_respond=True, # Allow rejection with feedback
allow_edit=False, # Don't allow approval with edits
),
description="Please review the command before execution",
)
# Send the interrupt request and get the response
response = interrupt([request])[0]
```
"""
action_request: ActionRequest
"""The specific action being requested from the human."""
config: HumanInTheLoopConfig
"""Configuration defining what response types are allowed."""
description: str | None
"""Optional detailed description of what input is needed."""
class AcceptPayload(TypedDict):
"""Response when a human approves the action."""
type: Literal["accept"]
"""The type of response when a human approves the action."""
class ResponsePayload(TypedDict):
"""Response when a human rejects the action."""
type: Literal["response"]
"""The type of response when a human rejects the action."""
args: NotRequired[str]
"""The message to be sent to the model explaining why the action was rejected."""
class EditPayload(TypedDict):
"""Response when a human edits the action."""
type: Literal["edit"]
"""The type of response when a human edits the action."""
args: ActionRequest
"""The action request with the edited content."""
HumanInTheLoopResponse = AcceptPayload | ResponsePayload | EditPayload
"""Aggregated response type for all possible human in the loop responses."""
class ToolConfig(TypedDict):
"""Configuration for a tool requiring human in the loop."""
allow_accept: NotRequired[bool]
"""Whether the human can approve the current action without changes."""
allow_edit: NotRequired[bool]
"""Whether the human can approve the current action with edited content."""
allow_respond: NotRequired[bool]
"""Whether the human can reject the current action with feedback."""
description: NotRequired[str]
"""The description attached to the request for human input."""
class HumanInTheLoopMiddleware(AgentMiddleware):
@@ -21,18 +112,36 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
     def __init__(
         self,
-        tool_configs: ToolInterruptConfig,
-        message_prefix: str = "Tool execution requires approval",
+        tool_configs: dict[str, bool | ToolConfig],
+        *,
+        description_prefix: str = "Tool execution requires approval",
     ) -> None:
         """Initialize the human in the loop middleware.

         Args:
-            tool_configs: The tool interrupt configs to use for the middleware.
-            message_prefix: The message prefix to use when constructing interrupt content.
+            tool_configs: Mapping of tool name to allowed actions.
+                If a tool doesn't have an entry, it's auto-approved by default.
+                * `True` indicates all actions are allowed: accept, edit, and respond.
+                * `False` indicates that the tool is auto-approved.
+                * ToolConfig indicates the specific actions allowed for this tool.
+            description_prefix: The prefix to use when constructing action requests.
+                This is used to provide context about the tool call and the action being requested.
+                Not used if a tool has a description in its ToolConfig.
         """
         super().__init__()
-        self.tool_configs = tool_configs
-        self.message_prefix = message_prefix
+        resolved_tool_configs: dict[str, ToolConfig] = {}
+        for tool_name, tool_config in tool_configs.items():
+            if isinstance(tool_config, bool):
+                if tool_config is True:
+                    resolved_tool_configs[tool_name] = ToolConfig(
+                        allow_accept=True,
+                        allow_edit=True,
+                        allow_respond=True,
+                    )
+            else:
+                resolved_tool_configs[tool_name] = tool_config
+        self.tool_configs = resolved_tool_configs
+        self.description_prefix = description_prefix

     def after_model(self, state: AgentState) -> dict[str, Any] | None:
         """Trigger HITL flows for relevant tool calls after an AIMessage."""
@@ -40,96 +149,105 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         if not messages:
             return None

-        last_message = messages[-1]
-
-        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
+        last_ai_msg = next((msg for msg in messages if isinstance(msg, AIMessage)), None)
+        if not last_ai_msg or not last_ai_msg.tool_calls:
             return None

         # Separate tool calls that need interrupts from those that don't
-        interrupt_tool_calls = []
+        hitl_tool_calls: list[ToolCall] = []
         auto_approved_tool_calls = []
-        for tool_call in last_message.tool_calls:
-            tool_name = tool_call["name"]
-            if tool_name in self.tool_configs:
-                interrupt_tool_calls.append(tool_call)
-            else:
-                auto_approved_tool_calls.append(tool_call)
+        for tool_call in last_ai_msg.tool_calls:
+            hitl_tool_calls.append(tool_call) if tool_call[
+                "name"
+            ] in self.tool_configs else auto_approved_tool_calls.append(tool_call)

         # If no interrupts needed, return early
-        if not interrupt_tool_calls:
+        if not hitl_tool_calls:
             return None

-        approved_tool_calls = auto_approved_tool_calls.copy()
-
-        # Right now, we do not support multiple tool calls with interrupts
-        if len(interrupt_tool_calls) > 1:
-            tool_names = [t["name"] for t in interrupt_tool_calls]
-            msg = (
-                f"Called the following tools which require interrupts: {tool_names}\n\n"
-                "You may only call ONE tool that requires an interrupt at a time"
-            )
-            return {
-                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
-                "jump_to": "model",
-            }
-
-        # Right now, we do not support interrupting a tool call if other tool calls exist
-        if auto_approved_tool_calls:
-            tool_names = [t["name"] for t in interrupt_tool_calls]
-            msg = (
-                f"Called the following tools which require interrupts: {tool_names}. "
-                "You also called other tools that do not require interrupts. "
-                "If you call a tool that requires and interrupt, you may ONLY call that tool."
-            )
-            return {
-                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
-                "jump_to": "model",
-            }
-
-        # Only one tool call will need interrupts
-        tool_call = interrupt_tool_calls[0]
-        tool_name = tool_call["name"]
-        tool_args = tool_call["args"]
-        description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
-        tool_config = self.tool_configs[tool_name]
-
-        request: HumanInterrupt = {
-            "action_request": ActionRequest(
-                action=tool_name,
-                args=tool_args,
-            ),
-            "config": tool_config,
-            "description": description,
-        }
-
-        responses: list[HumanResponse] = interrupt([request])
-        response = responses[0]
-
-        if response["type"] == "accept":
-            approved_tool_calls.append(tool_call)
-        elif response["type"] == "edit":
-            edited: ActionRequest = response["args"]  # type: ignore[assignment]
-            new_tool_call = {
-                "type": "tool_call",
-                "name": tool_call["name"],
-                "args": edited["args"],
-                "id": tool_call["id"],
-            }
-            approved_tool_calls.append(new_tool_call)
-        elif response["type"] == "ignore":
-            return {"jump_to": "__end__"}
-        elif response["type"] == "response":
-            tool_message = {
-                "role": "tool",
-                "tool_call_id": tool_call["id"],
-                "content": response["args"],
-            }
-            return {"messages": [tool_message], "jump_to": "model"}
-        else:
-            msg = f"Unknown response type: {response['type']}"
-            raise ValueError(msg)
-
-        last_message.tool_calls = approved_tool_calls
-
-        return {"messages": [last_message]}
+        # Process all tool calls that require interrupts
+        approved_tool_calls: list[ToolCall] = auto_approved_tool_calls.copy()
+        artificial_tool_messages: list[ToolMessage] = []
+
+        # Create interrupt requests for all tools that need approval
+        hitl_requests: list[HumanInTheLoopRequest] = []
+        for tool_call in hitl_tool_calls:
+            tool_name = tool_call["name"]
+            tool_args = tool_call["args"]
+            config = self.tool_configs[tool_name]
+            description = (
+                config.get("description")
+                or f"{self.description_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
+            )
+            request: HumanInTheLoopRequest = {
+                "action_request": ActionRequest(
+                    action=tool_name,
+                    args=tool_args,
+                ),
+                "config": config,
+                "description": description,
+            }
+            hitl_requests.append(request)
+
+        responses: list[HumanInTheLoopResponse] = interrupt(hitl_requests)
+
+        # Validate that the number of responses matches the number of interrupt tool calls
+        if (responses_len := len(responses)) != (hitl_tool_calls_len := len(hitl_tool_calls)):
+            msg = (
+                f"Number of human responses ({responses_len}) does not match "
+                f"number of hanging tool calls ({hitl_tool_calls_len})."
+            )
+            raise ValueError(msg)
+
+        for i, response in enumerate(responses):
+            tool_call = hitl_tool_calls[i]
+            config = self.tool_configs[tool_call["name"]]
+
+            if response["type"] == "accept" and config.get("allow_accept"):
+                approved_tool_calls.append(tool_call)
+            elif response["type"] == "edit" and config.get("allow_edit"):
+                edited_action = response["args"]
+                approved_tool_calls.append(
+                    ToolCall(
+                        type="tool_call",
+                        name=edited_action["action"],
+                        args=edited_action["args"],
+                        id=tool_call["id"],
+                    )
+                )
+            elif response["type"] == "response" and config.get("allow_respond"):
+                # Create a tool message with the human's text response
+                content = response.get("args") or (
+                    f"User rejected the tool call for `{tool_call['name']}` "
+                    f"with id {tool_call['id']}"
+                )
+                tool_message = ToolMessage(
+                    content=content,
+                    name=tool_call["name"],
+                    tool_call_id=tool_call["id"],
+                    status="error",
+                )
+                artificial_tool_messages.append(tool_message)
+            else:
+                allowed_actions = [
+                    action
+                    for action in ["accept", "edit", "response"]
+                    if config.get(f"allow_{'respond' if action == 'response' else action}")
+                ]
+                msg = (
+                    f"Unexpected human response: {response}. "
+                    f"Response action '{response.get('type')}' "
+                    f"is not allowed for tool '{tool_call['name']}'. "
+                    f"Expected one of {allowed_actions} based on the tool's configuration."
+                )
+                raise ValueError(msg)

+        # Update the AI message to only include approved tool calls
+        last_ai_msg.tool_calls = approved_tool_calls
+
+        if len(approved_tool_calls) > 0:
+            return {"messages": [last_ai_msg, *artificial_tool_messages]}
+
+        return {"jump_to": "model", "messages": artificial_tool_messages}


@@ -2,14 +2,15 @@
 import itertools
 from collections.abc import Callable, Sequence
-from typing import Any
+from typing import Any, cast

 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
+from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langgraph.constants import END, START
 from langgraph.graph.state import StateGraph
+from langgraph.types import Send
 from langgraph.typing import ContextT
 from typing_extensions import TypedDict, TypeVar
@@ -211,15 +212,13 @@ def create_agent( # noqa: PLR0915
         context_schema=context_schema,
     )

-    def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
+    def _handle_model_output(output: AIMessage) -> dict[str, Any]:
         """Handle model output including structured responses."""
         # Handle structured output with native strategy
         if isinstance(response_format, ProviderStrategy):
             if not output.tool_calls and native_output_binding:
                 structured_response = native_output_binding.parse(output)
                 return {"messages": [output], "response": structured_response}
-            if state.get("response") is not None:
-                return {"messages": [output], "response": None}
             return {"messages": [output]}

         # Handle structured output with tools strategy
@@ -297,9 +296,6 @@ def create_agent( # noqa: PLR0915
             ],
         }

-        # Standard response handling
-        if state.get("response") is not None:
-            return {"messages": [output], "response": None}
         return {"messages": [output]}

     def _get_bound_model(request: ModelRequest) -> Runnable:
@@ -346,7 +342,7 @@ def create_agent( # noqa: PLR0915
             messages = [SystemMessage(request.system_prompt), *messages]

         output = model_.invoke(messages)
-        return _handle_model_output(state, output)
+        return _handle_model_output(output)

     async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
@@ -373,7 +369,7 @@ def create_agent( # noqa: PLR0915
             messages = [SystemMessage(request.system_prompt), *messages]

         output = await model_.ainvoke(messages)
-        return _handle_model_output(state, output)
+        return _handle_model_output(output)

     # Use sync or async based on model capabilities
     from langgraph._internal._runnable import RunnableCallable
@@ -417,12 +413,12 @@ def create_agent( # noqa: PLR0915
     if tool_node is not None:
         graph.add_conditional_edges(
             "tools",
-            _make_tools_to_model_edge(tool_node, first_node),
+            _make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
             [first_node, END],
         )
         graph.add_conditional_edges(
             last_node,
-            _make_model_to_tools_edge(first_node, structured_output_tools),
+            _make_model_to_tools_edge(first_node, structured_output_tools, tool_node),
             [first_node, "tools", END],
         )
     elif last_node == "model_request":
@@ -481,27 +477,48 @@ def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
         return None


+def _fetch_last_ai_and_tool_messages(
+    messages: list[AnyMessage],
+) -> tuple[AIMessage, list[ToolMessage]]:
+    last_ai_index: int
+    last_ai_message: AIMessage
+
+    for i in range(len(messages) - 1, -1, -1):
+        if isinstance(messages[i], AIMessage):
+            last_ai_index = i
+            last_ai_message = cast("AIMessage", messages[i])
+            break
+
+    tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
+    return last_ai_message, tool_messages
+
+
 def _make_model_to_tools_edge(
-    first_node: str, structured_output_tools: dict[str, OutputToolBinding]
-) -> Callable[[AgentState], str | None]:
-    def model_to_tools(state: AgentState) -> str | None:
+    first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
+) -> Callable[[AgentState], str | list[Send] | None]:
+    def model_to_tools(state: AgentState) -> str | list[Send] | None:
         if jump_to := state.get("jump_to"):
             return _resolve_jump(jump_to, first_node)

-        message = state["messages"][-1]
-
-        # Check if this is a ToolMessage from structured output - if so, end
-        if isinstance(message, ToolMessage) and message.name in structured_output_tools:
-            return END
-
-        # Check for tool calls
-        if isinstance(message, AIMessage) and message.tool_calls:
-            # If all tool calls are for structured output, don't go to tools
-            non_structured_calls = [
-                tc for tc in message.tool_calls if tc["name"] not in structured_output_tools
-            ]
-            if non_structured_calls:
-                return "tools"
+        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
+        tool_message_ids = [m.tool_call_id for m in tool_messages]
+
+        pending_tool_calls = [
+            c
+            for c in last_ai_message.tool_calls
+            if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
+        ]
+
+        if pending_tool_calls:
+            # imo we should not be injecting state, store here,
+            # this should be done by the tool node itself ideally but this is a consequence
+            # of using Send w/ tool calls directly which allows more intuitive interrupt behavior
+            # largely internal so can be fixed later
+            pending_tool_calls = [
+                tool_node.inject_tool_args(call, state, None)  # type: ignore[arg-type]
+                for call in pending_tool_calls
+            ]
+            return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]

         return END
@@ -509,17 +526,21 @@ def _make_model_to_tools_edge(
 def _make_tools_to_model_edge(
-    tool_node: ToolNode, next_node: str
+    tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
 ) -> Callable[[AgentState], str | None]:
     def tools_to_model(state: AgentState) -> str | None:
-        ai_message = [m for m in state["messages"] if isinstance(m, AIMessage)][-1]
+        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])

         if all(
             tool_node.tools_by_name[c["name"]].return_direct
-            for c in ai_message.tool_calls
+            for c in last_ai_message.tool_calls
             if c["name"] in tool_node.tools_by_name
         ):
             return END

+        if any(t.name in structured_output_tools for t in tool_messages):
+            return END
+
         return next_node

     return tools_to_model


@@ -4,6 +4,7 @@ from unittest.mock import patch
 from syrupy.assertion import SnapshotAssertion
+from pydantic import BaseModel, Field
 from langchain_core.language_models import BaseChatModel
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
@@ -14,19 +15,23 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.tools import tool
+from langgraph.types import Command
 from langchain.agents.middleware_agent import create_agent
-from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
+from langchain.agents.middleware.human_in_the_loop import (
+    HumanInTheLoopMiddleware,
+    HumanInTheLoopConfig,
+    ActionRequest,
+)
 from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
 from langchain.agents.middleware.summarization import SummarizationMiddleware
 from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState
-from langchain_core.tools import BaseTool
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.checkpoint.memory import InMemorySaver
 from langgraph.constants import END
 from langgraph.graph.message import REMOVE_ALL_MESSAGES
-from langgraph.prebuilt.interrupt import ActionRequest, HumanInterruptConfig
+from langchain.agents.structured_output import ToolStrategy
 from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
 from .model import FakeToolCallingModel
@@ -355,27 +360,28 @@ def test_create_agent_jump(
 # Tests for HumanInTheLoopMiddleware
 def test_human_in_the_loop_middleware_initialization() -> None:
     """Test HumanInTheLoopMiddleware initialization."""
-    tool_configs = {
-        "test_tool": HumanInterruptConfig(
-            allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True
-        )
-    }
-
-    middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs, message_prefix="Custom prefix")
-
-    assert middleware.tool_configs == tool_configs
-    assert middleware.message_prefix == "Custom prefix"
+    middleware = HumanInTheLoopMiddleware(
+        tool_configs={
+            "test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
+        },
+        description_prefix="Custom prefix",
+    )
+
+    assert middleware.tool_configs == {
+        "test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
+    }
+    assert middleware.description_prefix == "Custom prefix"


 def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
     """Test HumanInTheLoopMiddleware when no interrupts are needed."""
-    tool_configs = {
-        "test_tool": HumanInterruptConfig(
-            allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True
-        )
-    }
-    middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs)
+    middleware = HumanInTheLoopMiddleware(
+        tool_configs={
+            "test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
+        }
+    )

     # Test with no messages
     state: dict[str, Any] = {"messages": []}
@@ -397,15 +403,14 @@ def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
     assert result is None


-def test_human_in_the_loop_middleware_interrupt_responses() -> None:
-    """Test HumanInTheLoopMiddleware with different interrupt response types."""
-    tool_configs = {
-        "test_tool": HumanInterruptConfig(
-            allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True
-        )
-    }
-    middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs)
+def test_human_in_the_loop_middleware_single_tool_accept() -> None:
+    """Test HumanInTheLoopMiddleware with single tool accept response."""
+    middleware = HumanInTheLoopMiddleware(
+        tool_configs={
+            "test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
+        }
+    )

     ai_message = AIMessage(
         content="I'll help you",
@@ -413,55 +418,443 @@ def test_human_in_the_loop_middleware_interrupt_responses() -> None:
     )
     state = {"messages": [HumanMessage(content="Hello"), ai_message]}

-    # Test accept response
     def mock_accept(requests):
         return [{"type": "accept", "args": None}]

     with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
         result = middleware.after_model(state)
         assert result is not None
+        assert "messages" in result
+        assert len(result["messages"]) == 1
         assert result["messages"][0] == ai_message
         assert result["messages"][0].tool_calls == ai_message.tool_calls

-    # Test edit response
+
+def test_human_in_the_loop_middleware_single_tool_edit() -> None:
+    """Test HumanInTheLoopMiddleware with single tool edit response."""
+    middleware = HumanInTheLoopMiddleware(
+        tool_configs={
+            "test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
+        }
+    )
+
+    ai_message = AIMessage(
+        content="I'll help you",
+        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
+    )
+    state = {"messages": [HumanMessage(content="Hello"), ai_message]}
+
     def mock_edit(requests):
         return [
-            {"type": "edit", "args": ActionRequest(action="test_tool", args={"input": "edited"})}
+            {
+                "type": "edit",
+                "args": ActionRequest(
+                    action="test_tool",
+                    args={"input": "edited"},
+                ),
+            }
         ]

     with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit):
         result = middleware.after_model(state)
         assert result is not None
+        assert "messages" in result
+        assert len(result["messages"]) == 1
         assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
+        assert result["messages"][0].tool_calls[0]["id"] == "1"  # ID should be preserved

-    # Test ignore response
-    def mock_ignore(requests):
-        return [{"type": "ignore", "args": None}]
-
-    with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_ignore):
-        result = middleware.after_model(state)
-        assert result is not None
-        assert result["jump_to"] == "__end__"
+
+def test_human_in_the_loop_middleware_single_tool_response() -> None:
+    """Test HumanInTheLoopMiddleware with single tool response with custom message."""
+    middleware = HumanInTheLoopMiddleware(
+        tool_configs={
+            "test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
+        }
+    )
+
+    ai_message = AIMessage(
+        content="I'll help you",
+        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
+    )
+    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

-    # Test response type
     def mock_response(requests):
-        return [{"type": "response", "args": "Custom response"}]
+        return [{"type": "response", "args": "Custom response message"}]

     with patch(
         "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
     ):
         result = middleware.after_model(state)
         assert result is not None
-        assert result["jump_to"] == "model"
-        assert result["messages"][0]["role"] == "tool"
-        assert result["messages"][0]["content"] == "Custom response"
+        assert "messages" in result
+        assert len(result["messages"]) == 1  # Only tool message when no approved tool calls
+        assert isinstance(result["messages"][0], ToolMessage)
+        assert result["messages"][0].content == "Custom response message"
+        assert result["messages"][0].name == "test_tool"
+        assert result["messages"][0].tool_call_id == "1"
def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
"""Test HumanInTheLoopMiddleware with multiple tools and mixed response types."""
middleware = HumanInTheLoopMiddleware(
tool_configs={
"get_forecast": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
"get_temperature": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
}
)
ai_message = AIMessage(
content="I'll help you with weather",
tool_calls=[
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
def mock_mixed_responses(requests):
return [
{"type": "accept", "args": None},
{"type": "response", "args": "User rejected this tool call"},
]
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
):
result = middleware.after_model(state)
assert result is not None
assert "messages" in result
assert (
len(result["messages"]) == 2
) # AI message with accepted tool call + tool message for rejected
# First message should be the AI message with updated tool calls
updated_ai_message = result["messages"][0]
assert len(updated_ai_message.tool_calls) == 1 # Only accepted tool call
assert updated_ai_message.tool_calls[0]["name"] == "get_forecast" # Accepted
# Second message should be the tool message for the rejected tool call
tool_message = result["messages"][1]
assert isinstance(tool_message, ToolMessage)
assert tool_message.content == "User rejected this tool call"
assert tool_message.name == "get_temperature"
def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None:
"""Test HumanInTheLoopMiddleware with multiple tools and edit responses."""
middleware = HumanInTheLoopMiddleware(
tool_configs={
"get_forecast": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
"get_temperature": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
}
)
ai_message = AIMessage(
content="I'll help you with weather",
tool_calls=[
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
def mock_edit_responses(requests):
return [
{
"type": "edit",
"args": ActionRequest(
action="get_forecast",
args={"location": "New York"},
),
},
{
"type": "edit",
"args": ActionRequest(
action="get_temperature",
args={"location": "New York"},
),
},
]
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
):
result = middleware.after_model(state)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
updated_ai_message = result["messages"][0]
assert updated_ai_message.tool_calls[0]["args"] == {"location": "New York"}
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
assert updated_ai_message.tool_calls[1]["args"] == {"location": "New York"}
assert updated_ai_message.tool_calls[1]["id"] == "2" # ID preserved
def test_human_in_the_loop_middleware_edit_with_modified_args() -> None:
"""Test HumanInTheLoopMiddleware with edit action that includes modified args."""
middleware = HumanInTheLoopMiddleware(
tool_configs={
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_edit_with_args(requests):
return [
{
"type": "edit",
"args": ActionRequest(
action="test_tool",
args={"input": "modified"},
),
}
]
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
side_effect=mock_edit_with_args,
):
result = middleware.after_model(state)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
# Should have modified args
updated_ai_message = result["messages"][0]
assert updated_ai_message.tool_calls[0]["args"] == {"input": "modified"}
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
def test_human_in_the_loop_middleware_unknown_response_type() -> None:
"""Test HumanInTheLoopMiddleware with unknown response type."""
middleware = HumanInTheLoopMiddleware(
tool_configs={
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
# Test unknown response type
     def mock_unknown(requests):
         return [{"type": "unknown", "args": None}]

     with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown):
-        with pytest.raises(ValueError, match="Unknown response type: unknown"):
+        with pytest.raises(
+            ValueError,
+            match=r"Unexpected human response: {'type': 'unknown', 'args': None}. Response action 'unknown' is not allowed for tool 'test_tool'. Expected one of \['accept', 'edit', 'response'\] based on the tool's configuration.",
+        ):
             middleware.after_model(state)
def test_human_in_the_loop_middleware_disallowed_action() -> None:
"""Test HumanInTheLoopMiddleware with action not allowed by tool config."""
# edit is not allowed by tool config
middleware = HumanInTheLoopMiddleware(
tool_configs={
"test_tool": {"allow_respond": True, "allow_edit": False, "allow_accept": True}
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_disallowed_action(requests):
return [
{
"type": "edit",
"args": ActionRequest(
action="test_tool",
args={"input": "modified"},
),
}
]
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
side_effect=mock_disallowed_action,
):
with pytest.raises(
ValueError,
match=r"Unexpected human response: {'type': 'edit', 'args': {'action': 'test_tool', 'args': {'input': 'modified'}}}. Response action 'edit' is not allowed for tool 'test_tool'. Expected one of \['accept', 'response'\] based on the tool's configuration.",
):
middleware.after_model(state)
def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> None:
"""Test HumanInTheLoopMiddleware with mix of auto-approved and interrupt tools."""
middleware = HumanInTheLoopMiddleware(
tool_configs={
"interrupt_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
}
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[
{"name": "auto_tool", "args": {"input": "auto"}, "id": "1"},
{"name": "interrupt_tool", "args": {"input": "interrupt"}, "id": "2"},
],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
def mock_accept(requests):
return [{"type": "accept", "args": None}]
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
result = middleware.after_model(state)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
updated_ai_message = result["messages"][0]
# Should have both tools: auto-approved first, then interrupt tool
assert len(updated_ai_message.tool_calls) == 2
assert updated_ai_message.tool_calls[0]["name"] == "auto_tool"
assert updated_ai_message.tool_calls[1]["name"] == "interrupt_tool"
def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
"""Test that interrupt requests are structured correctly."""
middleware = HumanInTheLoopMiddleware(
tool_configs={
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
},
description_prefix="Custom prefix",
)
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test", "location": "SF"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
captured_requests = []
def mock_capture_requests(requests):
captured_requests.extend(requests)
return [{"type": "accept", "args": None}]
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
):
middleware.after_model(state)
assert len(captured_requests) == 1
request = captured_requests[0]
assert "action_request" in request
assert "config" in request
assert "description" in request
assert request["action_request"]["action"] == "test_tool"
assert request["action_request"]["args"] == {"input": "test", "location": "SF"}
expected_config = {"allow_accept": True, "allow_edit": True, "allow_respond": True}
assert request["config"] == expected_config
assert "Custom prefix" in request["description"]
assert "Tool: test_tool" in request["description"]
assert "Args: {'input': 'test', 'location': 'SF'}" in request["description"]
def test_human_in_the_loop_middleware_boolean_configs() -> None:
"""Test HITL middleware with boolean tool configs."""
middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True})
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
# Test accept
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[{"type": "accept", "args": None}],
):
result = middleware.after_model(state)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls == ai_message.tool_calls
# Test edit
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[
{
"type": "edit",
"args": ActionRequest(
action="test_tool",
args={"input": "edited"},
),
}
],
):
result = middleware.after_model(state)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": False})
result = middleware.after_model(state)
# No interruption should occur
assert result is None
def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
"""Test that sequence mismatch in resume raises an error."""
middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True})
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
# Test with too few responses
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[], # No responses for 1 tool call
):
with pytest.raises(
ValueError,
match=r"Number of human responses \(0\) does not match number of hanging tool calls \(1\)\.",
):
middleware.after_model(state)
# Test with too many responses
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[
{"type": "accept", "args": None},
{"type": "accept", "args": None},
], # 2 responses for 1 tool call
):
with pytest.raises(
ValueError,
match=r"Number of human responses \(2\) does not match number of hanging tool calls \(1\)\.",
):
            middleware.after_model(state)
@@ -733,3 +1126,78 @@ def test_modify_model_request() -> None:
    assert (
        result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!"
    )
def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
"""Test that when there are both structured and regular tool calls, we execute regular and jump to END."""
class WeatherResponse(BaseModel):
"""Weather response."""
temperature: float = Field(description="Temperature in fahrenheit")
condition: str = Field(description="Weather condition")
@tool
def regular_tool(query: str) -> str:
"""A regular tool that returns a string."""
return f"Regular tool result for: {query}"
# Create a fake model that returns both structured and regular tool calls
class FakeModelWithBothToolCalls(FakeToolCallingModel):
def __init__(self):
super().__init__()
self.tool_calls = [
[
ToolCall(
name="WeatherResponse",
args={"temperature": 72.0, "condition": "sunny"},
id="structured_call_1",
),
ToolCall(
name="regular_tool", args={"query": "test query"}, id="regular_call_1"
),
]
]
# Create agent with both structured output and regular tools
agent = create_agent(
model=FakeModelWithBothToolCalls(),
tools=[regular_tool],
response_format=ToolStrategy(schema=WeatherResponse),
)
# Compile and invoke the agent
compiled_agent = agent.compile()
result = compiled_agent.invoke(
{"messages": [HumanMessage("What's the weather and help me with a query?")]}
)
# Verify that we have the expected messages:
# 1. Human message
# 2. AI message with both tool calls
# 3. Tool message from structured tool call
# 4. Tool message from regular tool call
messages = result["messages"]
assert len(messages) >= 4
# Check that we have the AI message with both tool calls
ai_message = messages[1]
assert isinstance(ai_message, AIMessage)
assert len(ai_message.tool_calls) == 2
# Check that we have a tool message from the regular tool
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
assert len(tool_messages) >= 1
# The regular tool should have been executed
regular_tool_message = next((m for m in tool_messages if m.name == "regular_tool"), None)
assert regular_tool_message is not None
assert "Regular tool result for: test query" in regular_tool_message.content
# Verify that the structured response is available in the result
assert "response" in result
assert result["response"] is not None
assert hasattr(result["response"], "temperature")
assert result["response"].temperature == 72.0
assert result["response"].condition == "sunny"