fix: don't reorder tool calls in HITL middleware (#34023)

This commit is contained in:
Sydney Runkle
2025-11-22 05:10:32 -05:00
committed by GitHub
parent 47b79c30c0
commit 68ab9a1e56
2 changed files with 199 additions and 41 deletions

View File

@@ -287,36 +287,23 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
if not last_ai_msg or not last_ai_msg.tool_calls:
return None
# Separate tool calls that need interrupts from those that don't
interrupt_tool_calls: list[ToolCall] = []
auto_approved_tool_calls = []
for tool_call in last_ai_msg.tool_calls:
interrupt_tool_calls.append(tool_call) if tool_call[
"name"
] in self.interrupt_on else auto_approved_tool_calls.append(tool_call)
# If no interrupts needed, return early
if not interrupt_tool_calls:
return None
# Process all tool calls that require interrupts
revised_tool_calls: list[ToolCall] = auto_approved_tool_calls.copy()
artificial_tool_messages: list[ToolMessage] = []
# Create action requests and review configs for all tools that need approval
# Create action requests and review configs for tools that need approval
action_requests: list[ActionRequest] = []
review_configs: list[ReviewConfig] = []
interrupt_indices: list[int] = []
for tool_call in interrupt_tool_calls:
config = self.interrupt_on[tool_call["name"]]
# Create ActionRequest and ReviewConfig using helper method
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
if (config := self.interrupt_on.get(tool_call["name"])) is not None:
action_request, review_config = self._create_action_and_config(
tool_call, config, state, runtime
)
action_requests.append(action_request)
review_configs.append(review_config)
interrupt_indices.append(idx)
# If no interrupts needed, return early
if not action_requests:
return None
# Create single HITLRequest with all actions and configs
hitl_request = HITLRequest(
@@ -325,29 +312,38 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
)
# Send interrupt and get response
hitl_response: HITLResponse = interrupt(hitl_request)
decisions = hitl_response["decisions"]
decisions = interrupt(hitl_request)["decisions"]
# Validate that the number of decisions matches the number of interrupt tool calls
if (decisions_len := len(decisions)) != (
interrupt_tool_calls_len := len(interrupt_tool_calls)
):
if (decisions_len := len(decisions)) != (interrupt_count := len(interrupt_indices)):
msg = (
f"Number of human decisions ({decisions_len}) does not match "
f"number of hanging tool calls ({interrupt_tool_calls_len})."
f"number of hanging tool calls ({interrupt_count})."
)
raise ValueError(msg)
# Process each decision using helper method
for i, decision in enumerate(decisions):
tool_call = interrupt_tool_calls[i]
config = self.interrupt_on[tool_call["name"]]
# Process decisions and rebuild tool calls in original order
revised_tool_calls: list[ToolCall] = []
artificial_tool_messages: list[ToolMessage] = []
decision_idx = 0
revised_tool_call, tool_message = self._process_decision(decision, tool_call, config)
if revised_tool_call:
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
if idx in interrupt_indices:
# This was an interrupt tool call - process the decision
config = self.interrupt_on[tool_call["name"]]
decision = decisions[decision_idx]
decision_idx += 1
revised_tool_call, tool_message = self._process_decision(
decision, tool_call, config
)
if revised_tool_call is not None:
revised_tool_calls.append(revised_tool_call)
if tool_message:
artificial_tool_messages.append(tool_message)
else:
# This was auto-approved - keep original
revised_tool_calls.append(tool_call)
# Update the AI message to only include approved tool calls
last_ai_msg.tool_calls = revised_tool_calls

View File

@@ -573,3 +573,165 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
# Check string description
assert captured_request["action_requests"][1]["description"] == "Static description"
def test_human_in_the_loop_middleware_preserves_tool_call_order() -> None:
    """Verify the middleware keeps tool calls in their original positions.

    Interleaves auto-approved tools (A, C, E) with interrupt tools (B, D)
    and checks that the updated AI message lists all five calls in the same
    order they appeared in the original message.
    """
    hitl = HumanInTheLoopMiddleware(
        interrupt_on={
            "tool_b": {"allowed_decisions": ["approve", "edit", "reject"]},
            "tool_d": {"allowed_decisions": ["approve", "edit", "reject"]},
        }
    )
    # Order: auto (A) -> interrupt (B) -> auto (C) -> interrupt (D) -> auto (E)
    original_calls = [
        {"name": "tool_a", "args": {"val": 1}, "id": "id_a"},
        {"name": "tool_b", "args": {"val": 2}, "id": "id_b"},
        {"name": "tool_c", "args": {"val": 3}, "id": "id_c"},
        {"name": "tool_d", "args": {"val": 4}, "id": "id_d"},
        {"name": "tool_e", "args": {"val": 5}, "id": "id_e"},
    ]
    ai_message = AIMessage(
        content="Processing multiple tools",
        tool_calls=original_calls,
    )
    state = {"messages": [HumanMessage(content="Test"), ai_message]}

    def approve_everything(requests):
        # Approve both interrupt tools (B and D)
        return {"decisions": [{"type": "approve"}, {"type": "approve"}]}

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        side_effect=approve_everything,
    ):
        outcome = hitl.after_model(state, None)

    assert outcome is not None
    assert "messages" in outcome
    revised = outcome["messages"][0]
    assert len(revised.tool_calls) == 5
    # Original order must survive: A -> B -> C -> D -> E
    for position, expected in enumerate(original_calls):
        assert revised.tool_calls[position]["name"] == expected["name"]
        assert revised.tool_calls[position]["id"] == expected["id"]
def test_human_in_the_loop_middleware_preserves_order_with_edits() -> None:
    """Verify original tool-call ordering survives when a reviewed call is edited."""
    hitl = HumanInTheLoopMiddleware(
        interrupt_on={
            "tool_b": {"allowed_decisions": ["approve", "edit", "reject"]},
            "tool_d": {"allowed_decisions": ["approve", "edit", "reject"]},
        }
    )
    ai_message = AIMessage(
        content="Processing multiple tools",
        tool_calls=[
            {"name": "tool_a", "args": {"val": 1}, "id": "id_a"},
            {"name": "tool_b", "args": {"val": 2}, "id": "id_b"},
            {"name": "tool_c", "args": {"val": 3}, "id": "id_c"},
            {"name": "tool_d", "args": {"val": 4}, "id": "id_d"},
        ],
    )
    state = {"messages": [HumanMessage(content="Test"), ai_message]}

    def edit_then_approve(requests):
        # Edit tool_b, approve tool_d
        return {
            "decisions": [
                {
                    "type": "edit",
                    "edited_action": Action(name="tool_b", args={"val": 200}),
                },
                {"type": "approve"},
            ]
        }

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        side_effect=edit_then_approve,
    ):
        outcome = hitl.after_model(state, None)

    assert outcome is not None
    revised = outcome["messages"][0]
    assert len(revised.tool_calls) == 4
    # Expected sequence: A (auto) -> B (edited) -> C (auto) -> D (approved)
    assert [call["name"] for call in revised.tool_calls] == [
        "tool_a",
        "tool_b",
        "tool_c",
        "tool_d",
    ]
    assert revised.tool_calls[0]["args"] == {"val": 1}
    assert revised.tool_calls[1]["args"] == {"val": 200}  # Edited
    assert revised.tool_calls[1]["id"] == "id_b"  # ID preserved
    assert revised.tool_calls[2]["args"] == {"val": 3}
    assert revised.tool_calls[3]["args"] == {"val": 4}
def test_human_in_the_loop_middleware_preserves_order_with_rejections() -> None:
    """Verify ordering stays intact when one reviewed tool call is rejected."""
    hitl = HumanInTheLoopMiddleware(
        interrupt_on={
            "tool_b": {"allowed_decisions": ["approve", "edit", "reject"]},
            "tool_d": {"allowed_decisions": ["approve", "edit", "reject"]},
        }
    )
    ai_message = AIMessage(
        content="Processing multiple tools",
        tool_calls=[
            {"name": "tool_a", "args": {"val": 1}, "id": "id_a"},
            {"name": "tool_b", "args": {"val": 2}, "id": "id_b"},
            {"name": "tool_c", "args": {"val": 3}, "id": "id_c"},
            {"name": "tool_d", "args": {"val": 4}, "id": "id_d"},
            {"name": "tool_e", "args": {"val": 5}, "id": "id_e"},
        ],
    )
    state = {"messages": [HumanMessage(content="Test"), ai_message]}

    def reject_then_approve(requests):
        # Reject tool_b, approve tool_d
        return {
            "decisions": [
                {"type": "reject", "message": "Rejected tool B"},
                {"type": "approve"},
            ]
        }

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        side_effect=reject_then_approve,
    ):
        outcome = hitl.after_model(state, None)

    assert outcome is not None
    assert len(outcome["messages"]) == 2  # AI message + tool message for rejection
    revised = outcome["messages"][0]
    # tool_b is still in the list (with rejection handled via tool message)
    assert len(revised.tool_calls) == 5
    # Order maintained: A (auto) -> B (rejected) -> C (auto) -> D (approved) -> E (auto)
    assert [call["name"] for call in revised.tool_calls] == [
        "tool_a",
        "tool_b",
        "tool_c",
        "tool_d",
        "tool_e",
    ]
    # The rejection surfaces as an artificial ToolMessage tied to tool_b
    rejection = outcome["messages"][1]
    assert isinstance(rejection, ToolMessage)
    assert rejection.content == "Rejected tool B"
    assert rejection.tool_call_id == "id_b"