diff --git a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py index 5a5a9287a7d..fd589d1c422 100644 --- a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py +++ b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py @@ -83,7 +83,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware): self.tool_configs = resolved_tool_configs self.action_request_prefix = action_request_prefix - def after_model(self, state: AgentState) -> dict[str, Any] | None: # noqa: PLR0915 + def after_model(self, state: AgentState) -> dict[str, Any] | None: """Trigger HITL flows for relevant tool calls after an AIMessage.""" messages = state["messages"] if not messages: @@ -94,16 +94,13 @@ class HumanInTheLoopMiddleware(AgentMiddleware): return None # Separate tool calls that need interrupts from those that don't - interrupt_tool_calls: dict[str, ToolCall] = {} + interrupt_tool_calls: list[ToolCall] = [] auto_approved_tool_calls = [] for tool_call in last_ai_msg.tool_calls: - tool_name = tool_call["name"] - if tool_name in self.tool_configs: - # fix: id should not be typed as Optional on `langchain_core.messages.tool.ToolCall` - interrupt_tool_calls[tool_call["id"]] = tool_call # type: ignore[index] - else: - auto_approved_tool_calls.append(tool_call) + interrupt_tool_calls.append(tool_call) if tool_call[ + "name" + ] in self.tool_configs else auto_approved_tool_calls.append(tool_call) # If no interrupts needed, return early if not interrupt_tool_calls: @@ -116,8 +113,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware): # Create interrupt requests for all tools that need approval interrupt_requests: list[ActionRequest] = [] - interrupt_tool_calls_list = list(interrupt_tool_calls.values()) - for tool_call in interrupt_tool_calls_list: + for tool_call in interrupt_tool_calls: tool_name = tool_call["name"] tool_args = tool_call["args"] message = f"{self.action_request_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}" @@ -125,7 +121,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware): request: ActionRequest = { "message": message, - "args": tool_args, + "args": {"tool_name": tool_name, "tool_args": tool_args}, "allowed_actions": allowed_actions, "description": message, } @@ -134,15 +130,17 @@ class HumanInTheLoopMiddleware(AgentMiddleware): responses: list[ActionResponse] = interrupt(interrupt_requests) # Validate that the number of responses matches the number of interrupt tool calls - if len(responses) != len(interrupt_tool_calls_list): + if (responses_len := len(responses)) != ( + interrupt_tool_calls_len := len(interrupt_tool_calls) + ): msg = ( - f"Number of human responses ({len(responses)}) does not match " - f"number of hanging tool calls ({len(interrupt_tool_calls_list)})." + f"Number of human responses ({responses_len}) does not match " + f"number of hanging tool calls ({interrupt_tool_calls_len})." ) raise ValueError(msg) for i, response in enumerate(responses): - tool_call = interrupt_tool_calls_list[i] + tool_call = interrupt_tool_calls[i] allowed_actions = self.tool_configs[tool_call["name"]] if response["action"] == "approve" and "approve" in allowed_actions: diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py index 32e2d3bd24d..42d712996db 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py @@ -785,7 +785,10 @@ def test_human_in_the_loop_middleware_interrupt_request_structure() -> None: assert "allowed_actions" in request assert "description" in request - assert request["args"] == {"input": "test", "location": "SF"} + assert request["args"] == { + "tool_name": "test_tool", + "tool_args": {"input": "test", "location": "SF"}, + } assert request["allowed_actions"] == ["approve", "reject", "edit"] assert "Custom prefix" in request["message"] assert "Tool: test_tool" in request["message"]