mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-02 07:10:41 +00:00
test w/ new boolean logic
This commit is contained in:
@@ -111,19 +111,23 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
"""Initialize the human in the loop middleware.
|
||||
|
||||
Args:
|
||||
tool_configs: Mapping of tool name to interrupt config.
|
||||
tool_configs: Mapping of tool name to interrupt config (bool | HumanInterruptConfig).
|
||||
* `True` indicates all interrupt config options are allowed.
|
||||
* `False` indicates that the tool is auto-approved.
|
||||
* `HumanInterruptConfig` indicates the specific interrupt config options to use.
|
||||
message_prefix: The message prefix to use when constructing interrupt content.
|
||||
"""
|
||||
super().__init__()
|
||||
resolved_tool_configs = {}
|
||||
for tool_name, tool_config in tool_configs.items():
|
||||
if isinstance(tool_config, bool):
|
||||
resolved_tool_configs[tool_name] = HumanInterruptConfig(
|
||||
allow_approve=tool_config,
|
||||
allow_ignore=tool_config,
|
||||
allow_response=tool_config,
|
||||
allow_edit=tool_config,
|
||||
)
|
||||
if tool_config is True:
|
||||
resolved_tool_configs[tool_name] = HumanInterruptConfig(
|
||||
allow_approve=True,
|
||||
allow_ignore=True,
|
||||
allow_response=True,
|
||||
allow_edit=True,
|
||||
)
|
||||
else:
|
||||
resolved_tool_configs[tool_name] = tool_config
|
||||
self.tool_configs = resolved_tool_configs
|
||||
|
||||
@@ -906,6 +906,112 @@ def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
|
||||
assert "Args: {'input': 'test', 'location': 'SF'}" in request["description"]
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_boolean_configs() -> None:
|
||||
"""Test HITL middleware with boolean tool configs."""
|
||||
middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True})
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
# Test approve
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value=[{"type": "approve", "tool_call_id": "1"}],
|
||||
):
|
||||
result = middleware.after_model(state)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0].tool_calls == ai_message.tool_calls
|
||||
|
||||
# Test edit
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value=[
|
||||
{
|
||||
"type": "edit",
|
||||
"tool_call_id": "1",
|
||||
"action": "test_tool",
|
||||
"args": {"input": "edited"},
|
||||
}
|
||||
],
|
||||
):
|
||||
result = middleware.after_model(state)
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
|
||||
|
||||
# Test ignore
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value=[{"type": "ignore", "tool_call_id": "1"}],
|
||||
):
|
||||
result = middleware.after_model(state)
|
||||
assert result is not None
|
||||
assert "jump_to" in result
|
||||
assert result["jump_to"] == "model"
|
||||
assert len(result["messages"]) == 1
|
||||
assert isinstance(result["messages"][0], ToolMessage)
|
||||
assert (
|
||||
"User ignored the tool call for `test_tool` with id 1" in result["messages"][0].content
|
||||
)
|
||||
|
||||
# Test response
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value=[{"type": "response", "tool_call_id": "1", "tool_message": "Custom response"}],
|
||||
):
|
||||
result = middleware.after_model(state)
|
||||
assert result is not None
|
||||
assert "jump_to" in result
|
||||
assert result["jump_to"] == "model"
|
||||
assert len(result["messages"]) == 1
|
||||
assert isinstance(result["messages"][0], ToolMessage)
|
||||
assert result["messages"][0].content == "Custom response"
|
||||
assert result["messages"][0].status == "error"
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": False})
|
||||
|
||||
result = middleware.after_model(state)
|
||||
# No interruption should occur
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_missing_tool_call_id() -> None:
|
||||
"""Test that missing tool call ID in resume raises an error."""
|
||||
middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True})
|
||||
|
||||
ai_message = AIMessage(
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value=[{"type": "approve"}], # Missing tool_call_id
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Unexpected human response: \{'type': 'approve'\}\. Expected one with `'tool_call_id'` in \['1'\]\.",
|
||||
):
|
||||
middleware.after_model(state)
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value=[{"type": "approve", "tool_call_id": "nonexistent"}],
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"Unexpected human response: \{'type': 'approve', 'tool_call_id': 'nonexistent'\}\. Expected one with `'tool_call_id'` in \['1'\]\.",
|
||||
):
|
||||
middleware.after_model(state)
|
||||
|
||||
|
||||
# Tests for AnthropicPromptCachingMiddleware
|
||||
def test_anthropic_prompt_caching_middleware_initialization() -> None:
|
||||
"""Test AnthropicPromptCachingMiddleware initialization."""
|
||||
|
||||
Reference in New Issue
Block a user