test w/ new boolean logic

This commit is contained in:
Sydney Runkle
2025-09-16 10:36:08 -04:00
parent 5470d7621c
commit e8802c60de
2 changed files with 117 additions and 7 deletions

View File

@@ -111,19 +111,23 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
"""Initialize the human in the loop middleware.
Args:
tool_configs: Mapping of tool name to interrupt config.
tool_configs: Mapping of tool name to interrupt config (bool | HumanInterruptConfig).
* `True` indicates all interrupt config options are allowed.
* `False` indicates that the tool is auto-approved.
* `HumanInterruptConfig` indicates the specific interrupt config options to use.
message_prefix: The message prefix to use when constructing interrupt content.
"""
super().__init__()
resolved_tool_configs = {}
for tool_name, tool_config in tool_configs.items():
if isinstance(tool_config, bool):
resolved_tool_configs[tool_name] = HumanInterruptConfig(
allow_approve=tool_config,
allow_ignore=tool_config,
allow_response=tool_config,
allow_edit=tool_config,
)
if tool_config is True:
resolved_tool_configs[tool_name] = HumanInterruptConfig(
allow_approve=True,
allow_ignore=True,
allow_response=True,
allow_edit=True,
)
else:
resolved_tool_configs[tool_name] = tool_config
self.tool_configs = resolved_tool_configs

View File

@@ -906,6 +906,112 @@ def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
assert "Args: {'input': 'test', 'location': 'SF'}" in request["description"]
def test_human_in_the_loop_middleware_boolean_configs() -> None:
"""Test HITL middleware with boolean tool configs."""
middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True})
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
# Test approve
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[{"type": "approve", "tool_call_id": "1"}],
):
result = middleware.after_model(state)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls == ai_message.tool_calls
# Test edit
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[
{
"type": "edit",
"tool_call_id": "1",
"action": "test_tool",
"args": {"input": "edited"},
}
],
):
result = middleware.after_model(state)
assert result is not None
assert "messages" in result
assert len(result["messages"]) == 1
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
# Test ignore
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[{"type": "ignore", "tool_call_id": "1"}],
):
result = middleware.after_model(state)
assert result is not None
assert "jump_to" in result
assert result["jump_to"] == "model"
assert len(result["messages"]) == 1
assert isinstance(result["messages"][0], ToolMessage)
assert (
"User ignored the tool call for `test_tool` with id 1" in result["messages"][0].content
)
# Test response
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[{"type": "response", "tool_call_id": "1", "tool_message": "Custom response"}],
):
result = middleware.after_model(state)
assert result is not None
assert "jump_to" in result
assert result["jump_to"] == "model"
assert len(result["messages"]) == 1
assert isinstance(result["messages"][0], ToolMessage)
assert result["messages"][0].content == "Custom response"
assert result["messages"][0].status == "error"
middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": False})
result = middleware.after_model(state)
# No interruption should occur
assert result is None
def test_human_in_the_loop_middleware_missing_tool_call_id() -> None:
"""Test that missing tool call ID in resume raises an error."""
middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True})
ai_message = AIMessage(
content="I'll help you",
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
)
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[{"type": "approve"}], # Missing tool_call_id
):
with pytest.raises(
ValueError,
match=r"Unexpected human response: \{'type': 'approve'\}\. Expected one with `'tool_call_id'` in \['1'\]\.",
):
middleware.after_model(state)
with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
return_value=[{"type": "approve", "tool_call_id": "nonexistent"}],
):
with pytest.raises(
ValueError,
match=r"Unexpected human response: \{'type': 'approve', 'tool_call_id': 'nonexistent'\}\. Expected one with `'tool_call_id'` in \['1'\]\.",
):
middleware.after_model(state)
# Tests for AnthropicPromptCachingMiddleware
def test_anthropic_prompt_caching_middleware_initialization() -> None:
"""Test AnthropicPromptCachingMiddleware initialization."""