mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
chore(langchain): fix types in test_file_search and test_human_in_the_loop (#34602)
This commit is contained in:
committed by
GitHub
parent
be2c7f1aa8
commit
901690ceec
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
from langchain.agents.middleware.file_search import (
|
||||
FilesystemFileSearchMiddleware,
|
||||
@@ -22,6 +23,8 @@ class TestFilesystemGrepSearch:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
|
||||
|
||||
assert isinstance(middleware.grep_search, StructuredTool)
|
||||
assert middleware.grep_search.func is not None
|
||||
result = middleware.grep_search.func(pattern="print", include="*.{py")
|
||||
|
||||
assert result == "Invalid include pattern"
|
||||
@@ -63,6 +66,8 @@ class TestFilesystemGrepSearch:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
|
||||
|
||||
assert isinstance(middleware.grep_search, StructuredTool)
|
||||
assert middleware.grep_search.func is not None
|
||||
result = middleware.grep_search.func(pattern="hello")
|
||||
|
||||
assert "/file1.py" in result
|
||||
@@ -76,6 +81,8 @@ class TestFilesystemGrepSearch:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
|
||||
|
||||
assert isinstance(middleware.grep_search, StructuredTool)
|
||||
assert middleware.grep_search.func is not None
|
||||
result = middleware.grep_search.func(pattern="hello", include="*.py")
|
||||
|
||||
assert "/file1.py" in result
|
||||
@@ -87,6 +94,8 @@ class TestFilesystemGrepSearch:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
|
||||
|
||||
assert isinstance(middleware.grep_search, StructuredTool)
|
||||
assert middleware.grep_search.func is not None
|
||||
result = middleware.grep_search.func(pattern="hello", output_mode="content")
|
||||
|
||||
assert "/test.py:2:hello" in result
|
||||
@@ -97,6 +106,8 @@ class TestFilesystemGrepSearch:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
|
||||
|
||||
assert isinstance(middleware.grep_search, StructuredTool)
|
||||
assert middleware.grep_search.func is not None
|
||||
result = middleware.grep_search.func(pattern="hello", output_mode="count")
|
||||
|
||||
assert "/test.py:2" in result
|
||||
@@ -107,6 +118,8 @@ class TestFilesystemGrepSearch:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
|
||||
|
||||
assert isinstance(middleware.grep_search, StructuredTool)
|
||||
assert middleware.grep_search.func is not None
|
||||
result = middleware.grep_search.func(pattern="[invalid")
|
||||
|
||||
assert "Invalid regex pattern" in result
|
||||
@@ -117,6 +130,8 @@ class TestFilesystemGrepSearch:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
|
||||
|
||||
assert isinstance(middleware.grep_search, StructuredTool)
|
||||
assert middleware.grep_search.func is not None
|
||||
result = middleware.grep_search.func(pattern="notfound")
|
||||
|
||||
assert result == "No matches found"
|
||||
@@ -133,6 +148,8 @@ class TestFilesystemGlobSearch:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path))
|
||||
|
||||
assert isinstance(middleware.glob_search, StructuredTool)
|
||||
assert middleware.glob_search.func is not None
|
||||
result = middleware.glob_search.func(pattern="*.py")
|
||||
|
||||
assert "/file1.py" in result
|
||||
@@ -148,6 +165,8 @@ class TestFilesystemGlobSearch:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path))
|
||||
|
||||
assert isinstance(middleware.glob_search, StructuredTool)
|
||||
assert middleware.glob_search.func is not None
|
||||
result = middleware.glob_search.func(pattern="**/*.py")
|
||||
|
||||
assert "/src/test.py" in result
|
||||
@@ -162,6 +181,8 @@ class TestFilesystemGlobSearch:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path))
|
||||
|
||||
assert isinstance(middleware.glob_search, StructuredTool)
|
||||
assert middleware.glob_search.func is not None
|
||||
result = middleware.glob_search.func(pattern="*.py", path="/src")
|
||||
|
||||
assert "/src/file1.py" in result
|
||||
@@ -173,6 +194,8 @@ class TestFilesystemGlobSearch:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path))
|
||||
|
||||
assert isinstance(middleware.glob_search, StructuredTool)
|
||||
assert middleware.glob_search.func is not None
|
||||
result = middleware.glob_search.func(pattern="*.py")
|
||||
|
||||
assert result == "No files found"
|
||||
@@ -181,6 +204,8 @@ class TestFilesystemGlobSearch:
|
||||
"""Test glob search with non-existent path."""
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path))
|
||||
|
||||
assert isinstance(middleware.glob_search, StructuredTool)
|
||||
assert middleware.glob_search.func is not None
|
||||
result = middleware.glob_search.func(pattern="*.py", path="/nonexistent")
|
||||
|
||||
assert result == "No files found"
|
||||
@@ -201,6 +226,8 @@ class TestPathTraversalSecurity:
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path / "allowed"))
|
||||
|
||||
# Try to escape with ..
|
||||
assert isinstance(middleware.glob_search, StructuredTool)
|
||||
assert middleware.glob_search.func is not None
|
||||
result = middleware.glob_search.func(pattern="*.txt", path="/../")
|
||||
|
||||
assert result == "No files found"
|
||||
@@ -216,6 +243,8 @@ class TestPathTraversalSecurity:
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path / "allowed"))
|
||||
|
||||
# Try to access with absolute path
|
||||
assert isinstance(middleware.glob_search, StructuredTool)
|
||||
assert middleware.glob_search.func is not None
|
||||
result = middleware.glob_search.func(pattern="*.txt", path=str(tmp_path))
|
||||
|
||||
assert result == "No files found"
|
||||
@@ -234,6 +263,8 @@ class TestPathTraversalSecurity:
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path / "allowed"))
|
||||
|
||||
# Try to access via symlink
|
||||
assert isinstance(middleware.glob_search, StructuredTool)
|
||||
assert middleware.glob_search.func is not None
|
||||
result = middleware.glob_search.func(pattern="*.txt", path="/link")
|
||||
|
||||
assert result == "No files found"
|
||||
@@ -242,6 +273,8 @@ class TestPathTraversalSecurity:
|
||||
"""Test that tilde paths are handled safely."""
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path))
|
||||
|
||||
assert isinstance(middleware.glob_search, StructuredTool)
|
||||
assert middleware.glob_search.func is not None
|
||||
result = middleware.glob_search.func(pattern="*.txt", path="~/")
|
||||
|
||||
assert result == "No files found"
|
||||
@@ -256,6 +289,8 @@ class TestPathTraversalSecurity:
|
||||
)
|
||||
|
||||
# Try to search outside root
|
||||
assert isinstance(middleware.grep_search, StructuredTool)
|
||||
assert middleware.grep_search.func is not None
|
||||
result = middleware.grep_search.func(pattern="secret", path="/../")
|
||||
|
||||
assert result == "No matches found"
|
||||
@@ -331,6 +366,8 @@ class TestGrepEdgeCases:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
|
||||
|
||||
assert isinstance(middleware.grep_search, StructuredTool)
|
||||
assert middleware.grep_search.func is not None
|
||||
result = middleware.grep_search.func(pattern="def.*:")
|
||||
|
||||
assert "/test.py" in result
|
||||
@@ -341,6 +378,8 @@ class TestGrepEdgeCases:
|
||||
|
||||
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
|
||||
|
||||
assert isinstance(middleware.grep_search, StructuredTool)
|
||||
assert middleware.grep_search.func is not None
|
||||
result = middleware.grep_search.func(pattern="(?i)hello")
|
||||
|
||||
assert "/test.py" in result
|
||||
@@ -358,6 +397,8 @@ class TestGrepEdgeCases:
|
||||
max_file_size_mb=1, # 1MB limit
|
||||
)
|
||||
|
||||
assert isinstance(middleware.grep_search, StructuredTool)
|
||||
assert middleware.grep_search.func is not None
|
||||
result = middleware.grep_search.func(pattern="x")
|
||||
|
||||
# Large file should be skipped
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain.agents.middleware import InterruptOnConfig
|
||||
from langchain.agents.middleware.human_in_the_loop import (
|
||||
Action,
|
||||
HumanInTheLoopMiddleware,
|
||||
@@ -33,13 +34,14 @@ def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
|
||||
)
|
||||
|
||||
# Test with no messages
|
||||
state: dict[str, Any] = {"messages": []}
|
||||
result = middleware.after_model(state, None)
|
||||
state = AgentState[Any](messages=[])
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is None
|
||||
|
||||
# Test with message but no tool calls
|
||||
state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi there")]}
|
||||
result = middleware.after_model(state, None)
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), AIMessage(content="Hi there")])
|
||||
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is None
|
||||
|
||||
# Test with tool calls that don't require interrupts
|
||||
@@ -47,8 +49,8 @@ def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "other_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
result = middleware.after_model(state, None)
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is None
|
||||
|
||||
|
||||
@@ -62,13 +64,13 @@ def test_human_in_the_loop_middleware_single_tool_accept() -> None:
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
def mock_accept(requests):
|
||||
def mock_accept(_: Any) -> dict[str, Any]:
|
||||
return {"decisions": [{"type": "approve"}]}
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
@@ -80,7 +82,7 @@ def test_human_in_the_loop_middleware_single_tool_accept() -> None:
|
||||
)
|
||||
state["messages"].append(AIMessage(content="test_tool called with result: Tool message"))
|
||||
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
# No interrupts needed
|
||||
assert result is None
|
||||
|
||||
@@ -95,9 +97,9 @@ def test_human_in_the_loop_middleware_single_tool_edit() -> None:
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
def mock_edit(requests):
|
||||
def mock_edit(_: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"decisions": [
|
||||
{
|
||||
@@ -111,7 +113,7 @@ def test_human_in_the_loop_middleware_single_tool_edit() -> None:
|
||||
}
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit):
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
@@ -129,15 +131,15 @@ def test_human_in_the_loop_middleware_single_tool_response() -> None:
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
def mock_response(requests):
|
||||
def mock_response(_: Any) -> dict[str, Any]:
|
||||
return {"decisions": [{"type": "reject", "message": "Custom response message"}]}
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 2
|
||||
@@ -164,9 +166,9 @@ def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
|
||||
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="What's the weather?"), ai_message])
|
||||
|
||||
def mock_mixed_responses(requests):
|
||||
def mock_mixed_responses(_: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"decisions": [
|
||||
{"type": "approve"},
|
||||
@@ -177,7 +179,7 @@ def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert (
|
||||
@@ -213,9 +215,9 @@ def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None:
|
||||
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="What's the weather?"), ai_message])
|
||||
|
||||
def mock_edit_responses(requests):
|
||||
def mock_edit_responses(_: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"decisions": [
|
||||
{
|
||||
@@ -238,7 +240,7 @@ def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None:
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
@@ -260,9 +262,9 @@ def test_human_in_the_loop_middleware_edit_with_modified_args() -> None:
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
def mock_edit_with_args(requests):
|
||||
def mock_edit_with_args(_: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"decisions": [
|
||||
{
|
||||
@@ -279,7 +281,7 @@ def test_human_in_the_loop_middleware_edit_with_modified_args() -> None:
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
side_effect=mock_edit_with_args,
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
@@ -300,9 +302,9 @@ def test_human_in_the_loop_middleware_unknown_response_type() -> None:
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
def mock_unknown(requests):
|
||||
def mock_unknown(_: Any) -> dict[str, Any]:
|
||||
return {"decisions": [{"type": "unknown"}]}
|
||||
|
||||
with (
|
||||
@@ -317,7 +319,7 @@ def test_human_in_the_loop_middleware_unknown_response_type() -> None:
|
||||
),
|
||||
),
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
middleware.after_model(state, Runtime())
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_disallowed_action() -> None:
|
||||
@@ -331,9 +333,9 @@ def test_human_in_the_loop_middleware_disallowed_action() -> None:
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
def mock_disallowed_action(requests):
|
||||
def mock_disallowed_action(_: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"decisions": [
|
||||
{
|
||||
@@ -362,7 +364,7 @@ def test_human_in_the_loop_middleware_disallowed_action() -> None:
|
||||
),
|
||||
),
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
middleware.after_model(state, Runtime())
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> None:
|
||||
@@ -378,13 +380,13 @@ def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> Non
|
||||
{"name": "interrupt_tool", "args": {"input": "interrupt"}, "id": "2"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
def mock_accept(requests):
|
||||
def mock_accept(_: Any) -> dict[str, Any]:
|
||||
return {"decisions": [{"type": "approve"}]}
|
||||
|
||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
@@ -407,11 +409,11 @@ def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test", "location": "SF"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
captured_request = None
|
||||
|
||||
def mock_capture_requests(request):
|
||||
def mock_capture_requests(request: Any) -> dict[str, Any]:
|
||||
nonlocal captured_request
|
||||
captured_request = request
|
||||
return {"decisions": [{"type": "approve"}]}
|
||||
@@ -419,7 +421,7 @@ def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
middleware.after_model(state, Runtime())
|
||||
|
||||
assert captured_request is not None
|
||||
assert "action_requests" in captured_request
|
||||
@@ -447,14 +449,14 @@ def test_human_in_the_loop_middleware_boolean_configs() -> None:
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
# Test accept
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||
return_value={"decisions": [{"type": "approve"}]},
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
@@ -475,7 +477,7 @@ def test_human_in_the_loop_middleware_boolean_configs() -> None:
|
||||
]
|
||||
},
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
@@ -483,7 +485,7 @@ def test_human_in_the_loop_middleware_boolean_configs() -> None:
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={"test_tool": False})
|
||||
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
# No interruption should occur
|
||||
assert result is None
|
||||
|
||||
@@ -496,7 +498,7 @@ def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
|
||||
content="I'll help you",
|
||||
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
# Test with too few responses
|
||||
with (
|
||||
@@ -511,7 +513,7 @@ def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
|
||||
),
|
||||
),
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
middleware.after_model(state, Runtime())
|
||||
|
||||
# Test with too many responses
|
||||
with (
|
||||
@@ -531,26 +533,28 @@ def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
|
||||
),
|
||||
),
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
middleware.after_model(state, Runtime())
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_description_as_callable() -> None:
|
||||
"""Test that description field accepts both string and callable."""
|
||||
|
||||
def custom_description(tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
|
||||
def custom_description(
|
||||
tool_call: ToolCall, state: AgentState[Any], runtime: Runtime[None]
|
||||
) -> str:
|
||||
"""Generate a custom description."""
|
||||
return f"Custom: {tool_call['name']} with args {tool_call['args']}"
|
||||
|
||||
middleware = HumanInTheLoopMiddleware(
|
||||
interrupt_on={
|
||||
"tool_with_callable": {
|
||||
"allowed_decisions": ["approve"],
|
||||
"description": custom_description,
|
||||
},
|
||||
"tool_with_string": {
|
||||
"allowed_decisions": ["approve"],
|
||||
"description": "Static description",
|
||||
},
|
||||
"tool_with_callable": InterruptOnConfig(
|
||||
allowed_decisions=["approve"],
|
||||
description=custom_description,
|
||||
),
|
||||
"tool_with_string": InterruptOnConfig(
|
||||
allowed_decisions=["approve"],
|
||||
description="Static description",
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -561,11 +565,11 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
|
||||
{"name": "tool_with_string", "args": {"y": 2}, "id": "2"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
captured_request = None
|
||||
|
||||
def mock_capture_requests(request):
|
||||
def mock_capture_requests(request: Any) -> dict[str, Any]:
|
||||
nonlocal captured_request
|
||||
captured_request = request
|
||||
return {"decisions": [{"type": "approve"}, {"type": "approve"}]}
|
||||
@@ -573,7 +577,7 @@ def test_human_in_the_loop_middleware_description_as_callable() -> None:
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
|
||||
):
|
||||
middleware.after_model(state, None)
|
||||
middleware.after_model(state, Runtime())
|
||||
|
||||
assert captured_request is not None
|
||||
assert "action_requests" in captured_request
|
||||
@@ -614,16 +618,16 @@ def test_human_in_the_loop_middleware_preserves_tool_call_order() -> None:
|
||||
{"name": "tool_e", "args": {"val": 5}, "id": "id_e"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Test"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
def mock_approve_all(requests):
|
||||
def mock_approve_all(_: Any) -> dict[str, Any]:
|
||||
# Approve both interrupt tools (B and D)
|
||||
return {"decisions": [{"type": "approve"}, {"type": "approve"}]}
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_approve_all
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
|
||||
@@ -661,9 +665,9 @@ def test_human_in_the_loop_middleware_preserves_order_with_edits() -> None:
|
||||
{"name": "tool_d", "args": {"val": 4}, "id": "id_d"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Test"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
def mock_edit_responses(requests):
|
||||
def mock_edit_responses(_: Any) -> dict[str, Any]:
|
||||
# Edit tool_b, approve tool_d
|
||||
return {
|
||||
"decisions": [
|
||||
@@ -678,7 +682,7 @@ def test_human_in_the_loop_middleware_preserves_order_with_edits() -> None:
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is not None
|
||||
|
||||
updated_ai_message = result["messages"][0]
|
||||
@@ -715,9 +719,9 @@ def test_human_in_the_loop_middleware_preserves_order_with_rejections() -> None:
|
||||
{"name": "tool_e", "args": {"val": 5}, "id": "id_e"},
|
||||
],
|
||||
)
|
||||
state = {"messages": [HumanMessage(content="Test"), ai_message]}
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), ai_message])
|
||||
|
||||
def mock_mixed_responses(requests):
|
||||
def mock_mixed_responses(_: Any) -> dict[str, Any]:
|
||||
# Reject tool_b, approve tool_d
|
||||
return {
|
||||
"decisions": [
|
||||
@@ -729,7 +733,7 @@ def test_human_in_the_loop_middleware_preserves_order_with_rejections() -> None:
|
||||
with patch(
|
||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
|
||||
):
|
||||
result = middleware.after_model(state, None)
|
||||
result = middleware.after_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert len(result["messages"]) == 2 # AI message + tool message for rejection
|
||||
|
||||
|
||||
Reference in New Issue
Block a user