update tests

This commit is contained in:
Sydney Runkle
2025-09-17 11:17:26 -04:00
parent ec3f0d1001
commit aad5086744
2 changed files with 17 additions and 16 deletions

View File

@@ -83,7 +83,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
self.tool_configs = resolved_tool_configs
self.action_request_prefix = action_request_prefix
def after_model(self, state: AgentState) -> dict[str, Any] | None: # noqa: PLR0915
def after_model(self, state: AgentState) -> dict[str, Any] | None:
"""Trigger HITL flows for relevant tool calls after an AIMessage."""
messages = state["messages"]
if not messages:
@@ -94,16 +94,13 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
return None
# Separate tool calls that need interrupts from those that don't
interrupt_tool_calls: dict[str, ToolCall] = {}
interrupt_tool_calls: list[ToolCall] = []
auto_approved_tool_calls = []
for tool_call in last_ai_msg.tool_calls:
tool_name = tool_call["name"]
if tool_name in self.tool_configs:
# fix: id should not be typed as Optional on `langchain_core.messages.tool.ToolCall`
interrupt_tool_calls[tool_call["id"]] = tool_call # type: ignore[index]
else:
auto_approved_tool_calls.append(tool_call)
interrupt_tool_calls.append(tool_call) if tool_call[
"name"
] in self.tool_configs else auto_approved_tool_calls.append(tool_call)
# If no interrupts needed, return early
if not interrupt_tool_calls:
@@ -116,8 +113,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
# Create interrupt requests for all tools that need approval
interrupt_requests: list[ActionRequest] = []
interrupt_tool_calls_list = list(interrupt_tool_calls.values())
for tool_call in interrupt_tool_calls_list:
for tool_call in interrupt_tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call["args"]
message = f"{self.action_request_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
@@ -125,7 +121,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
request: ActionRequest = {
"message": message,
"args": tool_args,
"args": {"tool_name": tool_name, "tool_args": tool_args},
"allowed_actions": allowed_actions,
"description": message,
}
@@ -134,15 +130,17 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
responses: list[ActionResponse] = interrupt(interrupt_requests)
# Validate that the number of responses matches the number of interrupt tool calls
if len(responses) != len(interrupt_tool_calls_list):
if (responses_len := len(responses)) != (
interrupt_tool_calls_len := len(interrupt_tool_calls)
):
msg = (
f"Number of human responses ({len(responses)}) does not match "
f"number of hanging tool calls ({len(interrupt_tool_calls_list)})."
f"Number of human responses ({responses_len}) does not match "
f"number of hanging tool calls ({interrupt_tool_calls_len})."
)
raise ValueError(msg)
for i, response in enumerate(responses):
tool_call = interrupt_tool_calls_list[i]
tool_call = interrupt_tool_calls[i]
allowed_actions = self.tool_configs[tool_call["name"]]
if response["action"] == "approve" and "approve" in allowed_actions:

View File

@@ -785,7 +785,10 @@ def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
assert "allowed_actions" in request
assert "description" in request
assert request["args"] == {"input": "test", "location": "SF"}
assert request["args"] == {
"tool_name": "test_tool",
"tool_args": {"input": "test", "location": "SF"},
}
assert request["allowed_actions"] == ["approve", "reject", "edit"]
assert "Custom prefix" in request["message"]
assert "Tool: test_tool" in request["message"]