mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-03 07:41:03 +00:00
update tests
This commit is contained in:
@@ -83,7 +83,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
self.tool_configs = resolved_tool_configs
|
||||
self.action_request_prefix = action_request_prefix
|
||||
|
||||
def after_model(self, state: AgentState) -> dict[str, Any] | None: # noqa: PLR0915
|
||||
def after_model(self, state: AgentState) -> dict[str, Any] | None:
|
||||
"""Trigger HITL flows for relevant tool calls after an AIMessage."""
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
@@ -94,16 +94,13 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
return None
|
||||
|
||||
# Separate tool calls that need interrupts from those that don't
|
||||
interrupt_tool_calls: dict[str, ToolCall] = {}
|
||||
interrupt_tool_calls: list[ToolCall] = []
|
||||
auto_approved_tool_calls = []
|
||||
|
||||
for tool_call in last_ai_msg.tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
if tool_name in self.tool_configs:
|
||||
# fix: id should not be typed as Optional on `langchain_core.messages.tool.ToolCall`
|
||||
interrupt_tool_calls[tool_call["id"]] = tool_call # type: ignore[index]
|
||||
else:
|
||||
auto_approved_tool_calls.append(tool_call)
|
||||
interrupt_tool_calls.append(tool_call) if tool_call[
|
||||
"name"
|
||||
] in self.tool_configs else auto_approved_tool_calls.append(tool_call)
|
||||
|
||||
# If no interrupts needed, return early
|
||||
if not interrupt_tool_calls:
|
||||
@@ -116,8 +113,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
|
||||
# Create interrupt requests for all tools that need approval
|
||||
interrupt_requests: list[ActionRequest] = []
|
||||
interrupt_tool_calls_list = list(interrupt_tool_calls.values())
|
||||
for tool_call in interrupt_tool_calls_list:
|
||||
for tool_call in interrupt_tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["args"]
|
||||
message = f"{self.action_request_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
|
||||
@@ -125,7 +121,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
|
||||
request: ActionRequest = {
|
||||
"message": message,
|
||||
"args": tool_args,
|
||||
"args": {"tool_name": tool_name, "tool_args": tool_args},
|
||||
"allowed_actions": allowed_actions,
|
||||
"description": message,
|
||||
}
|
||||
@@ -134,15 +130,17 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
responses: list[ActionResponse] = interrupt(interrupt_requests)
|
||||
|
||||
# Validate that the number of responses matches the number of interrupt tool calls
|
||||
if len(responses) != len(interrupt_tool_calls_list):
|
||||
if (responses_len := len(responses)) != (
|
||||
interrupt_tool_calls_len := len(interrupt_tool_calls)
|
||||
):
|
||||
msg = (
|
||||
f"Number of human responses ({len(responses)}) does not match "
|
||||
f"number of hanging tool calls ({len(interrupt_tool_calls_list)})."
|
||||
f"Number of human responses ({responses_len}) does not match "
|
||||
f"number of hanging tool calls ({interrupt_tool_calls_len})."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
for i, response in enumerate(responses):
|
||||
tool_call = interrupt_tool_calls_list[i]
|
||||
tool_call = interrupt_tool_calls[i]
|
||||
allowed_actions = self.tool_configs[tool_call["name"]]
|
||||
|
||||
if response["action"] == "approve" and "approve" in allowed_actions:
|
||||
|
||||
@@ -785,7 +785,10 @@ def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
|
||||
assert "allowed_actions" in request
|
||||
assert "description" in request
|
||||
|
||||
assert request["args"] == {"input": "test", "location": "SF"}
|
||||
assert request["args"] == {
|
||||
"tool_name": "test_tool",
|
||||
"tool_args": {"input": "test", "location": "SF"},
|
||||
}
|
||||
assert request["allowed_actions"] == ["approve", "reject", "edit"]
|
||||
assert "Custom prefix" in request["message"]
|
||||
assert "Tool: test_tool" in request["message"]
|
||||
|
||||
Reference in New Issue
Block a user