mirror of
https://github.com/hwchase17/langchain.git
synced 2025-11-29 00:53:39 +00:00
fix: don't reorder tool calls in HITL middleware (#34023)
This commit is contained in:
@@ -287,36 +287,23 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|||||||
if not last_ai_msg or not last_ai_msg.tool_calls:
|
if not last_ai_msg or not last_ai_msg.tool_calls:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Separate tool calls that need interrupts from those that don't
|
# Create action requests and review configs for tools that need approval
|
||||||
interrupt_tool_calls: list[ToolCall] = []
|
|
||||||
auto_approved_tool_calls = []
|
|
||||||
|
|
||||||
for tool_call in last_ai_msg.tool_calls:
|
|
||||||
interrupt_tool_calls.append(tool_call) if tool_call[
|
|
||||||
"name"
|
|
||||||
] in self.interrupt_on else auto_approved_tool_calls.append(tool_call)
|
|
||||||
|
|
||||||
# If no interrupts needed, return early
|
|
||||||
if not interrupt_tool_calls:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Process all tool calls that require interrupts
|
|
||||||
revised_tool_calls: list[ToolCall] = auto_approved_tool_calls.copy()
|
|
||||||
artificial_tool_messages: list[ToolMessage] = []
|
|
||||||
|
|
||||||
# Create action requests and review configs for all tools that need approval
|
|
||||||
action_requests: list[ActionRequest] = []
|
action_requests: list[ActionRequest] = []
|
||||||
review_configs: list[ReviewConfig] = []
|
review_configs: list[ReviewConfig] = []
|
||||||
|
interrupt_indices: list[int] = []
|
||||||
|
|
||||||
for tool_call in interrupt_tool_calls:
|
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
|
||||||
config = self.interrupt_on[tool_call["name"]]
|
if (config := self.interrupt_on.get(tool_call["name"])) is not None:
|
||||||
|
|
||||||
# Create ActionRequest and ReviewConfig using helper method
|
|
||||||
action_request, review_config = self._create_action_and_config(
|
action_request, review_config = self._create_action_and_config(
|
||||||
tool_call, config, state, runtime
|
tool_call, config, state, runtime
|
||||||
)
|
)
|
||||||
action_requests.append(action_request)
|
action_requests.append(action_request)
|
||||||
review_configs.append(review_config)
|
review_configs.append(review_config)
|
||||||
|
interrupt_indices.append(idx)
|
||||||
|
|
||||||
|
# If no interrupts needed, return early
|
||||||
|
if not action_requests:
|
||||||
|
return None
|
||||||
|
|
||||||
# Create single HITLRequest with all actions and configs
|
# Create single HITLRequest with all actions and configs
|
||||||
hitl_request = HITLRequest(
|
hitl_request = HITLRequest(
|
||||||
@@ -325,29 +312,38 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Send interrupt and get response
|
# Send interrupt and get response
|
||||||
hitl_response: HITLResponse = interrupt(hitl_request)
|
decisions = interrupt(hitl_request)["decisions"]
|
||||||
decisions = hitl_response["decisions"]
|
|
||||||
|
|
||||||
# Validate that the number of decisions matches the number of interrupt tool calls
|
# Validate that the number of decisions matches the number of interrupt tool calls
|
||||||
if (decisions_len := len(decisions)) != (
|
if (decisions_len := len(decisions)) != (interrupt_count := len(interrupt_indices)):
|
||||||
interrupt_tool_calls_len := len(interrupt_tool_calls)
|
|
||||||
):
|
|
||||||
msg = (
|
msg = (
|
||||||
f"Number of human decisions ({decisions_len}) does not match "
|
f"Number of human decisions ({decisions_len}) does not match "
|
||||||
f"number of hanging tool calls ({interrupt_tool_calls_len})."
|
f"number of hanging tool calls ({interrupt_count})."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
# Process each decision using helper method
|
# Process decisions and rebuild tool calls in original order
|
||||||
for i, decision in enumerate(decisions):
|
revised_tool_calls: list[ToolCall] = []
|
||||||
tool_call = interrupt_tool_calls[i]
|
artificial_tool_messages: list[ToolMessage] = []
|
||||||
config = self.interrupt_on[tool_call["name"]]
|
decision_idx = 0
|
||||||
|
|
||||||
revised_tool_call, tool_message = self._process_decision(decision, tool_call, config)
|
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
|
||||||
if revised_tool_call:
|
if idx in interrupt_indices:
|
||||||
|
# This was an interrupt tool call - process the decision
|
||||||
|
config = self.interrupt_on[tool_call["name"]]
|
||||||
|
decision = decisions[decision_idx]
|
||||||
|
decision_idx += 1
|
||||||
|
|
||||||
|
revised_tool_call, tool_message = self._process_decision(
|
||||||
|
decision, tool_call, config
|
||||||
|
)
|
||||||
|
if revised_tool_call is not None:
|
||||||
revised_tool_calls.append(revised_tool_call)
|
revised_tool_calls.append(revised_tool_call)
|
||||||
if tool_message:
|
if tool_message:
|
||||||
artificial_tool_messages.append(tool_message)
|
artificial_tool_messages.append(tool_message)
|
||||||
|
else:
|
||||||
|
# This was auto-approved - keep original
|
||||||
|
revised_tool_calls.append(tool_call)
|
||||||
|
|
||||||
# Update the AI message to only include approved tool calls
|
# Update the AI message to only include approved tool calls
|
||||||
last_ai_msg.tool_calls = revised_tool_calls
|
last_ai_msg.tool_calls = revised_tool_calls
|
||||||
|
|||||||
@@ -573,3 +573,165 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
|
|||||||
|
|
||||||
# Check string description
|
# Check string description
|
||||||
assert captured_request["action_requests"][1]["description"] == "Static description"
|
assert captured_request["action_requests"][1]["description"] == "Static description"
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_in_the_loop_middleware_preserves_tool_call_order() -> None:
|
||||||
|
"""Test that middleware preserves the original order of tool calls.
|
||||||
|
|
||||||
|
This test verifies that when mixing auto-approved and interrupt tools,
|
||||||
|
the final tool call order matches the original order from the AI message.
|
||||||
|
"""
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
interrupt_on={
|
||||||
|
"tool_b": {"allowed_decisions": ["approve", "edit", "reject"]},
|
||||||
|
"tool_d": {"allowed_decisions": ["approve", "edit", "reject"]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create AI message with interleaved auto-approved and interrupt tools
|
||||||
|
# Order: auto (A) -> interrupt (B) -> auto (C) -> interrupt (D) -> auto (E)
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="Processing multiple tools",
|
||||||
|
tool_calls=[
|
||||||
|
{"name": "tool_a", "args": {"val": 1}, "id": "id_a"},
|
||||||
|
{"name": "tool_b", "args": {"val": 2}, "id": "id_b"},
|
||||||
|
{"name": "tool_c", "args": {"val": 3}, "id": "id_c"},
|
||||||
|
{"name": "tool_d", "args": {"val": 4}, "id": "id_d"},
|
||||||
|
{"name": "tool_e", "args": {"val": 5}, "id": "id_e"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="Test"), ai_message]}
|
||||||
|
|
||||||
|
def mock_approve_all(requests):
|
||||||
|
# Approve both interrupt tools (B and D)
|
||||||
|
return {"decisions": [{"type": "approve"}, {"type": "approve"}]}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_approve_all
|
||||||
|
):
|
||||||
|
result = middleware.after_model(state, None)
|
||||||
|
assert result is not None
|
||||||
|
assert "messages" in result
|
||||||
|
|
||||||
|
updated_ai_message = result["messages"][0]
|
||||||
|
assert len(updated_ai_message.tool_calls) == 5
|
||||||
|
|
||||||
|
# Verify original order is preserved: A -> B -> C -> D -> E
|
||||||
|
assert updated_ai_message.tool_calls[0]["name"] == "tool_a"
|
||||||
|
assert updated_ai_message.tool_calls[0]["id"] == "id_a"
|
||||||
|
assert updated_ai_message.tool_calls[1]["name"] == "tool_b"
|
||||||
|
assert updated_ai_message.tool_calls[1]["id"] == "id_b"
|
||||||
|
assert updated_ai_message.tool_calls[2]["name"] == "tool_c"
|
||||||
|
assert updated_ai_message.tool_calls[2]["id"] == "id_c"
|
||||||
|
assert updated_ai_message.tool_calls[3]["name"] == "tool_d"
|
||||||
|
assert updated_ai_message.tool_calls[3]["id"] == "id_d"
|
||||||
|
assert updated_ai_message.tool_calls[4]["name"] == "tool_e"
|
||||||
|
assert updated_ai_message.tool_calls[4]["id"] == "id_e"
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_in_the_loop_middleware_preserves_order_with_edits() -> None:
|
||||||
|
"""Test that order is preserved when interrupt tools are edited."""
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
interrupt_on={
|
||||||
|
"tool_b": {"allowed_decisions": ["approve", "edit", "reject"]},
|
||||||
|
"tool_d": {"allowed_decisions": ["approve", "edit", "reject"]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="Processing multiple tools",
|
||||||
|
tool_calls=[
|
||||||
|
{"name": "tool_a", "args": {"val": 1}, "id": "id_a"},
|
||||||
|
{"name": "tool_b", "args": {"val": 2}, "id": "id_b"},
|
||||||
|
{"name": "tool_c", "args": {"val": 3}, "id": "id_c"},
|
||||||
|
{"name": "tool_d", "args": {"val": 4}, "id": "id_d"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="Test"), ai_message]}
|
||||||
|
|
||||||
|
def mock_edit_responses(requests):
|
||||||
|
# Edit tool_b, approve tool_d
|
||||||
|
return {
|
||||||
|
"decisions": [
|
||||||
|
{
|
||||||
|
"type": "edit",
|
||||||
|
"edited_action": Action(name="tool_b", args={"val": 200}),
|
||||||
|
},
|
||||||
|
{"type": "approve"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
|
||||||
|
):
|
||||||
|
result = middleware.after_model(state, None)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
updated_ai_message = result["messages"][0]
|
||||||
|
assert len(updated_ai_message.tool_calls) == 4
|
||||||
|
|
||||||
|
# Verify order: A (auto) -> B (edited) -> C (auto) -> D (approved)
|
||||||
|
assert updated_ai_message.tool_calls[0]["name"] == "tool_a"
|
||||||
|
assert updated_ai_message.tool_calls[0]["args"] == {"val": 1}
|
||||||
|
assert updated_ai_message.tool_calls[1]["name"] == "tool_b"
|
||||||
|
assert updated_ai_message.tool_calls[1]["args"] == {"val": 200} # Edited
|
||||||
|
assert updated_ai_message.tool_calls[1]["id"] == "id_b" # ID preserved
|
||||||
|
assert updated_ai_message.tool_calls[2]["name"] == "tool_c"
|
||||||
|
assert updated_ai_message.tool_calls[2]["args"] == {"val": 3}
|
||||||
|
assert updated_ai_message.tool_calls[3]["name"] == "tool_d"
|
||||||
|
assert updated_ai_message.tool_calls[3]["args"] == {"val": 4}
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_in_the_loop_middleware_preserves_order_with_rejections() -> None:
|
||||||
|
"""Test that order is preserved when some interrupt tools are rejected."""
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
interrupt_on={
|
||||||
|
"tool_b": {"allowed_decisions": ["approve", "edit", "reject"]},
|
||||||
|
"tool_d": {"allowed_decisions": ["approve", "edit", "reject"]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="Processing multiple tools",
|
||||||
|
tool_calls=[
|
||||||
|
{"name": "tool_a", "args": {"val": 1}, "id": "id_a"},
|
||||||
|
{"name": "tool_b", "args": {"val": 2}, "id": "id_b"},
|
||||||
|
{"name": "tool_c", "args": {"val": 3}, "id": "id_c"},
|
||||||
|
{"name": "tool_d", "args": {"val": 4}, "id": "id_d"},
|
||||||
|
{"name": "tool_e", "args": {"val": 5}, "id": "id_e"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="Test"), ai_message]}
|
||||||
|
|
||||||
|
def mock_mixed_responses(requests):
|
||||||
|
# Reject tool_b, approve tool_d
|
||||||
|
return {
|
||||||
|
"decisions": [
|
||||||
|
{"type": "reject", "message": "Rejected tool B"},
|
||||||
|
{"type": "approve"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
|
||||||
|
):
|
||||||
|
result = middleware.after_model(state, None)
|
||||||
|
assert result is not None
|
||||||
|
assert len(result["messages"]) == 2 # AI message + tool message for rejection
|
||||||
|
|
||||||
|
updated_ai_message = result["messages"][0]
|
||||||
|
# tool_b is still in the list (with rejection handled via tool message)
|
||||||
|
assert len(updated_ai_message.tool_calls) == 5
|
||||||
|
|
||||||
|
# Verify order maintained: A (auto) -> B (rejected) -> C (auto) -> D (approved) -> E (auto)
|
||||||
|
assert updated_ai_message.tool_calls[0]["name"] == "tool_a"
|
||||||
|
assert updated_ai_message.tool_calls[1]["name"] == "tool_b"
|
||||||
|
assert updated_ai_message.tool_calls[2]["name"] == "tool_c"
|
||||||
|
assert updated_ai_message.tool_calls[3]["name"] == "tool_d"
|
||||||
|
assert updated_ai_message.tool_calls[4]["name"] == "tool_e"
|
||||||
|
|
||||||
|
# Check rejection tool message
|
||||||
|
tool_message = result["messages"][1]
|
||||||
|
assert isinstance(tool_message, ToolMessage)
|
||||||
|
assert tool_message.content == "Rejected tool B"
|
||||||
|
assert tool_message.tool_call_id == "id_b"
|
||||||
|
|||||||
Reference in New Issue
Block a user