Mirror of https://github.com/hwchase17/langchain.git (synced 2025-11-28 16:35:41 +00:00)
fix: don't reorder tool calls in HITL middleware (#34023)
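In outline, the fix stops partitioning tool calls into "interrupted" and "auto-approved" lists (which concatenated them back in the wrong order) and instead records the index of each call that needs review, so the reviewed list can be rebuilt in the original order. A minimal sketch of that idea follows; the standalone function and the key name "edited_call" are illustrative only, not the middleware's actual API.

# Illustrative sketch only - not the HumanInTheLoopMiddleware implementation.
def rebuild_in_order(tool_calls, interrupt_on, decisions):
    """Re-assemble tool calls in their original order after human review.

    tool_calls: list of dicts with a "name" key, as on an AIMessage.
    interrupt_on: set of tool names that require human review (illustrative).
    decisions: one decision dict per reviewed call, in review order.
    """
    # Record the positions that need review instead of partitioning the list.
    interrupt_indices = [i for i, tc in enumerate(tool_calls) if tc["name"] in interrupt_on]
    if len(decisions) != len(interrupt_indices):
        raise ValueError("one decision is required per reviewed tool call")

    revised = []
    decision_idx = 0
    for idx, tc in enumerate(tool_calls):
        if idx in interrupt_indices:
            decision = decisions[decision_idx]
            decision_idx += 1
            if decision["type"] == "edit":
                # Swap in the reviewer's edited call at the same position.
                # (The key name "edited_call" is made up for this sketch.)
                revised.append(decision["edited_call"])
            else:
                # Approvals (and rejections, which are answered with a synthetic
                # ToolMessage elsewhere) keep the original call in place.
                revised.append(tc)
        else:
            # Auto-approved calls are never moved.
            revised.append(tc)
    return revised

Given five calls A-E where only B and D need review, two approve decisions leave the list in A, B, C, D, E order, which is the behaviour the new tests at the bottom of the diff assert.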
@@ -287,36 +287,23 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         if not last_ai_msg or not last_ai_msg.tool_calls:
             return None
 
-        # Separate tool calls that need interrupts from those that don't
-        interrupt_tool_calls: list[ToolCall] = []
-        auto_approved_tool_calls = []
-
-        for tool_call in last_ai_msg.tool_calls:
-            interrupt_tool_calls.append(tool_call) if tool_call[
-                "name"
-            ] in self.interrupt_on else auto_approved_tool_calls.append(tool_call)
-
-        # If no interrupts needed, return early
-        if not interrupt_tool_calls:
-            return None
-
-        # Process all tool calls that require interrupts
-        revised_tool_calls: list[ToolCall] = auto_approved_tool_calls.copy()
-        artificial_tool_messages: list[ToolMessage] = []
-
-        # Create action requests and review configs for all tools that need approval
+        # Create action requests and review configs for tools that need approval
         action_requests: list[ActionRequest] = []
         review_configs: list[ReviewConfig] = []
+        interrupt_indices: list[int] = []
 
-        for tool_call in interrupt_tool_calls:
-            config = self.interrupt_on[tool_call["name"]]
-
-            # Create ActionRequest and ReviewConfig using helper method
+        for idx, tool_call in enumerate(last_ai_msg.tool_calls):
+            if (config := self.interrupt_on.get(tool_call["name"])) is not None:
                 action_request, review_config = self._create_action_and_config(
                     tool_call, config, state, runtime
                 )
                 action_requests.append(action_request)
                 review_configs.append(review_config)
+                interrupt_indices.append(idx)
+
+        # If no interrupts needed, return early
+        if not action_requests:
+            return None
 
         # Create single HITLRequest with all actions and configs
         hitl_request = HITLRequest(
@@ -325,29 +312,38 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         )
 
         # Send interrupt and get response
-        hitl_response: HITLResponse = interrupt(hitl_request)
-        decisions = hitl_response["decisions"]
+        decisions = interrupt(hitl_request)["decisions"]
 
         # Validate that the number of decisions matches the number of interrupt tool calls
-        if (decisions_len := len(decisions)) != (
-            interrupt_tool_calls_len := len(interrupt_tool_calls)
-        ):
+        if (decisions_len := len(decisions)) != (interrupt_count := len(interrupt_indices)):
             msg = (
                 f"Number of human decisions ({decisions_len}) does not match "
-                f"number of hanging tool calls ({interrupt_tool_calls_len})."
+                f"number of hanging tool calls ({interrupt_count})."
             )
             raise ValueError(msg)
 
-        # Process each decision using helper method
-        for i, decision in enumerate(decisions):
-            tool_call = interrupt_tool_calls[i]
-            config = self.interrupt_on[tool_call["name"]]
+        # Process decisions and rebuild tool calls in original order
+        revised_tool_calls: list[ToolCall] = []
+        artificial_tool_messages: list[ToolMessage] = []
+        decision_idx = 0
 
-            revised_tool_call, tool_message = self._process_decision(decision, tool_call, config)
-            if revised_tool_call:
+        for idx, tool_call in enumerate(last_ai_msg.tool_calls):
+            if idx in interrupt_indices:
+                # This was an interrupt tool call - process the decision
+                config = self.interrupt_on[tool_call["name"]]
+                decision = decisions[decision_idx]
+                decision_idx += 1
+
+                revised_tool_call, tool_message = self._process_decision(
+                    decision, tool_call, config
+                )
+                if revised_tool_call is not None:
                     revised_tool_calls.append(revised_tool_call)
                     if tool_message:
                         artificial_tool_messages.append(tool_message)
+            else:
+                # This was auto-approved - keep original
+                revised_tool_calls.append(tool_call)
 
         # Update the AI message to only include approved tool calls
         last_ai_msg.tool_calls = revised_tool_calls
@@ -573,3 +573,165 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
 
     # Check string description
     assert captured_request["action_requests"][1]["description"] == "Static description"
+
+
+def test_human_in_the_loop_middleware_preserves_tool_call_order() -> None:
+    """Test that middleware preserves the original order of tool calls.
+
+    This test verifies that when mixing auto-approved and interrupt tools,
+    the final tool call order matches the original order from the AI message.
+    """
+    middleware = HumanInTheLoopMiddleware(
+        interrupt_on={
+            "tool_b": {"allowed_decisions": ["approve", "edit", "reject"]},
+            "tool_d": {"allowed_decisions": ["approve", "edit", "reject"]},
+        }
+    )
+
+    # Create AI message with interleaved auto-approved and interrupt tools
+    # Order: auto (A) -> interrupt (B) -> auto (C) -> interrupt (D) -> auto (E)
+    ai_message = AIMessage(
+        content="Processing multiple tools",
+        tool_calls=[
+            {"name": "tool_a", "args": {"val": 1}, "id": "id_a"},
+            {"name": "tool_b", "args": {"val": 2}, "id": "id_b"},
+            {"name": "tool_c", "args": {"val": 3}, "id": "id_c"},
+            {"name": "tool_d", "args": {"val": 4}, "id": "id_d"},
+            {"name": "tool_e", "args": {"val": 5}, "id": "id_e"},
+        ],
+    )
+    state = {"messages": [HumanMessage(content="Test"), ai_message]}
+
+    def mock_approve_all(requests):
+        # Approve both interrupt tools (B and D)
+        return {"decisions": [{"type": "approve"}, {"type": "approve"}]}
+
+    with patch(
+        "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_approve_all
+    ):
+        result = middleware.after_model(state, None)
+        assert result is not None
+        assert "messages" in result
+
+        updated_ai_message = result["messages"][0]
+        assert len(updated_ai_message.tool_calls) == 5
+
+        # Verify original order is preserved: A -> B -> C -> D -> E
+        assert updated_ai_message.tool_calls[0]["name"] == "tool_a"
+        assert updated_ai_message.tool_calls[0]["id"] == "id_a"
+        assert updated_ai_message.tool_calls[1]["name"] == "tool_b"
+        assert updated_ai_message.tool_calls[1]["id"] == "id_b"
+        assert updated_ai_message.tool_calls[2]["name"] == "tool_c"
+        assert updated_ai_message.tool_calls[2]["id"] == "id_c"
+        assert updated_ai_message.tool_calls[3]["name"] == "tool_d"
+        assert updated_ai_message.tool_calls[3]["id"] == "id_d"
+        assert updated_ai_message.tool_calls[4]["name"] == "tool_e"
+        assert updated_ai_message.tool_calls[4]["id"] == "id_e"
+
+
+def test_human_in_the_loop_middleware_preserves_order_with_edits() -> None:
+    """Test that order is preserved when interrupt tools are edited."""
+    middleware = HumanInTheLoopMiddleware(
+        interrupt_on={
+            "tool_b": {"allowed_decisions": ["approve", "edit", "reject"]},
+            "tool_d": {"allowed_decisions": ["approve", "edit", "reject"]},
+        }
+    )
+
+    ai_message = AIMessage(
+        content="Processing multiple tools",
+        tool_calls=[
+            {"name": "tool_a", "args": {"val": 1}, "id": "id_a"},
+            {"name": "tool_b", "args": {"val": 2}, "id": "id_b"},
+            {"name": "tool_c", "args": {"val": 3}, "id": "id_c"},
+            {"name": "tool_d", "args": {"val": 4}, "id": "id_d"},
+        ],
+    )
+    state = {"messages": [HumanMessage(content="Test"), ai_message]}
+
+    def mock_edit_responses(requests):
+        # Edit tool_b, approve tool_d
+        return {
+            "decisions": [
+                {
+                    "type": "edit",
+                    "edited_action": Action(name="tool_b", args={"val": 200}),
+                },
+                {"type": "approve"},
+            ]
+        }
+
+    with patch(
+        "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
+    ):
+        result = middleware.after_model(state, None)
+        assert result is not None
+
+        updated_ai_message = result["messages"][0]
+        assert len(updated_ai_message.tool_calls) == 4
+
+        # Verify order: A (auto) -> B (edited) -> C (auto) -> D (approved)
+        assert updated_ai_message.tool_calls[0]["name"] == "tool_a"
+        assert updated_ai_message.tool_calls[0]["args"] == {"val": 1}
+        assert updated_ai_message.tool_calls[1]["name"] == "tool_b"
+        assert updated_ai_message.tool_calls[1]["args"] == {"val": 200}  # Edited
+        assert updated_ai_message.tool_calls[1]["id"] == "id_b"  # ID preserved
+        assert updated_ai_message.tool_calls[2]["name"] == "tool_c"
+        assert updated_ai_message.tool_calls[2]["args"] == {"val": 3}
+        assert updated_ai_message.tool_calls[3]["name"] == "tool_d"
+        assert updated_ai_message.tool_calls[3]["args"] == {"val": 4}
+
+
+def test_human_in_the_loop_middleware_preserves_order_with_rejections() -> None:
+    """Test that order is preserved when some interrupt tools are rejected."""
+    middleware = HumanInTheLoopMiddleware(
+        interrupt_on={
+            "tool_b": {"allowed_decisions": ["approve", "edit", "reject"]},
+            "tool_d": {"allowed_decisions": ["approve", "edit", "reject"]},
+        }
+    )
+
+    ai_message = AIMessage(
+        content="Processing multiple tools",
+        tool_calls=[
+            {"name": "tool_a", "args": {"val": 1}, "id": "id_a"},
+            {"name": "tool_b", "args": {"val": 2}, "id": "id_b"},
+            {"name": "tool_c", "args": {"val": 3}, "id": "id_c"},
+            {"name": "tool_d", "args": {"val": 4}, "id": "id_d"},
+            {"name": "tool_e", "args": {"val": 5}, "id": "id_e"},
+        ],
+    )
+    state = {"messages": [HumanMessage(content="Test"), ai_message]}
+
+    def mock_mixed_responses(requests):
+        # Reject tool_b, approve tool_d
+        return {
+            "decisions": [
+                {"type": "reject", "message": "Rejected tool B"},
+                {"type": "approve"},
+            ]
+        }
+
+    with patch(
+        "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
+    ):
+        result = middleware.after_model(state, None)
+        assert result is not None
+        assert len(result["messages"]) == 2  # AI message + tool message for rejection
+
+        updated_ai_message = result["messages"][0]
+        # tool_b is still in the list (with rejection handled via tool message)
+        assert len(updated_ai_message.tool_calls) == 5
+
+        # Verify order maintained: A (auto) -> B (rejected) -> C (auto) -> D (approved) -> E (auto)
+        assert updated_ai_message.tool_calls[0]["name"] == "tool_a"
+        assert updated_ai_message.tool_calls[1]["name"] == "tool_b"
+        assert updated_ai_message.tool_calls[2]["name"] == "tool_c"
+        assert updated_ai_message.tool_calls[3]["name"] == "tool_d"
+        assert updated_ai_message.tool_calls[4]["name"] == "tool_e"
+
+        # Check rejection tool message
+        tool_message = result["messages"][1]
+        assert isinstance(tool_message, ToolMessage)
+        assert tool_message.content == "Rejected tool B"
+        assert tool_message.tool_call_id == "id_b"