diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_file_search.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_file_search.py index 40aeedd71bd..bb3340d9818 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_file_search.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_file_search.py @@ -7,6 +7,9 @@ import pytest from langchain.agents.middleware.file_search import ( FilesystemFileSearchMiddleware, + _expand_include_patterns, + _is_valid_include_pattern, + _match_include_pattern, ) @@ -259,3 +262,105 @@ class TestPathTraversalSecurity: assert result == "No matches found" assert "secret" not in result + + +class TestExpandIncludePatterns: + """Tests for _expand_include_patterns helper function.""" + + def test_expand_patterns_basic_brace_expansion(self) -> None: + """Test basic brace expansion with multiple options.""" + result = _expand_include_patterns("*.{py,txt}") + assert result == ["*.py", "*.txt"] + + def test_expand_patterns_nested_braces(self) -> None: + """Test nested brace expansion.""" + result = _expand_include_patterns("test.{a,b}.{c,d}") + assert result is not None + assert len(result) == 4 + assert "test.a.c" in result + assert "test.b.d" in result + + @pytest.mark.parametrize( + "pattern", + [ + "*.py}", # closing brace without opening + "*.{}", # empty braces + "*.{py", # unclosed brace + ], + ) + def test_expand_patterns_invalid_braces(self, pattern: str) -> None: + """Test patterns with invalid brace syntax return None.""" + result = _expand_include_patterns(pattern) + assert result is None + + +class TestValidateIncludePattern: + """Tests for _is_valid_include_pattern helper function.""" + + @pytest.mark.parametrize( + "pattern", + [ + "", # empty pattern + "*.py\x00", # null byte + "*.py\n", # newline + ], + ) + def test_validate_invalid_patterns(self, pattern: str) -> None: + """Test that invalid patterns are rejected.""" + assert not _is_valid_include_pattern(pattern) + + +class TestMatchIncludePattern: + """Tests for _match_include_pattern helper function.""" + + def test_match_pattern_with_braces(self) -> None: + """Test matching with brace expansion.""" + assert _match_include_pattern("test.py", "*.{py,txt}") + assert _match_include_pattern("test.txt", "*.{py,txt}") + assert not _match_include_pattern("test.md", "*.{py,txt}") + + def test_match_pattern_invalid_expansion(self) -> None: + """Test matching with pattern that cannot be expanded returns False.""" + assert not _match_include_pattern("test.py", "*.{}") + + +class TestGrepEdgeCases: + """Tests for edge cases in grep search.""" + + def test_grep_with_special_chars_in_pattern(self, tmp_path: Path) -> None: + """Test grep with special characters in pattern.""" + (tmp_path / "test.py").write_text("def test():\n pass\n", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False) + + result = middleware.grep_search.func(pattern="def.*:") + + assert "/test.py" in result + + def test_grep_case_insensitive(self, tmp_path: Path) -> None: + """Test grep with case-insensitive search.""" + (tmp_path / "test.py").write_text("HELLO world\n", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False) + + result = middleware.grep_search.func(pattern="(?i)hello") + + assert "/test.py" in result + + def test_grep_with_large_file_skipping(self, tmp_path: Path) -> None: + """Test that grep 
skips files larger than max_file_size_mb.""" + # Create a file larger than 1MB + large_content = "x" * (2 * 1024 * 1024) # 2MB + (tmp_path / "large.txt").write_text(large_content, encoding="utf-8") + (tmp_path / "small.txt").write_text("x", encoding="utf-8") + + middleware = FilesystemFileSearchMiddleware( + root_path=str(tmp_path), + use_ripgrep=False, + max_file_size_mb=1, # 1MB limit + ) + + result = middleware.grep_search.func(pattern="x") + + # Large file should be skipped + assert "/small.txt" in result diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py index 37891df9412..e534d837a79 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py @@ -1,17 +1,21 @@ from __future__ import annotations -import time +import asyncio import gc import tempfile +import time from pathlib import Path import pytest +from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.tools.base import ToolException from langchain.agents.middleware.shell_tool import ( HostExecutionPolicy, + RedactionRule, ShellToolMiddleware, _SessionResources, - RedactionRule, + _ShellToolInput, ) from langchain.agents.middleware.types import AgentState @@ -173,3 +177,294 @@ def test_session_resources_finalizer_cleans_up(tmp_path: Path) -> None: assert not finalizer.alive assert session.stopped assert not tempdir_path.exists() + + +def test_shell_tool_input_validation() -> None: + """Test _ShellToolInput validation rules.""" + # Both command and restart not allowed + with pytest.raises(ValueError, match="only one"): + _ShellToolInput(command="ls", restart=True) + + # Neither command nor restart provided + with pytest.raises(ValueError, match="requires either"): + _ShellToolInput() + + # Valid: command only + valid_cmd = _ShellToolInput(command="ls") + assert valid_cmd.command == "ls" + assert not valid_cmd.restart + + # Valid: restart only + valid_restart = _ShellToolInput(restart=True) + assert valid_restart.restart is True + assert valid_restart.command is None + + +def test_normalize_shell_command_empty() -> None: + """Test that empty shell command raises an error.""" + with pytest.raises(ValueError, match="at least one argument"): + ShellToolMiddleware(shell_command=[]) + + +def test_normalize_env_non_string_keys() -> None: + """Test that non-string environment keys raise an error.""" + with pytest.raises(TypeError, match="must be strings"): + ShellToolMiddleware(env={123: "value"}) # type: ignore[dict-item] + + +def test_normalize_env_coercion(tmp_path: Path) -> None: + """Test that environment values are coerced to strings.""" + middleware = ShellToolMiddleware( + workspace_root=tmp_path / "workspace", env={"NUM": 42, "BOOL": True} + ) + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + resources = middleware._ensure_resources(state) # type: ignore[attr-defined] + result = middleware._run_shell_tool( + resources, {"command": "echo $NUM $BOOL"}, tool_call_id=None + ) + assert "42" in result + assert "True" in result + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + + +def test_shell_tool_missing_command_string(tmp_path: Path) -> None: + """Test that shell tool raises an error when command is not a 
string.""" + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace") + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + resources = middleware._ensure_resources(state) # type: ignore[attr-defined] + + with pytest.raises(ToolException, match="expects a 'command' string"): + middleware._run_shell_tool(resources, {"command": None}, tool_call_id=None) + + with pytest.raises(ToolException, match="expects a 'command' string"): + middleware._run_shell_tool( + resources, + {"command": 123}, # type: ignore[dict-item] + tool_call_id=None, + ) + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + + +def test_tool_message_formatting_with_id(tmp_path: Path) -> None: + """Test that tool messages are properly formatted with tool_call_id.""" + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace") + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + resources = middleware._ensure_resources(state) # type: ignore[attr-defined] + + result = middleware._run_shell_tool( + resources, {"command": "echo test"}, tool_call_id="test-id-123" + ) + + assert isinstance(result, ToolMessage) + assert result.tool_call_id == "test-id-123" + assert result.name == "shell" + assert result.status == "success" + assert "test" in result.content + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + + +def test_nonzero_exit_code_returns_error(tmp_path: Path) -> None: + """Test that non-zero exit codes are marked as errors.""" + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace") + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + resources = middleware._ensure_resources(state) # type: ignore[attr-defined] + + result = middleware._run_shell_tool( + resources, + {"command": "false"}, # Command that exits with 1 but doesn't kill shell + tool_call_id="test-id", + ) + + assert isinstance(result, ToolMessage) + assert result.status == "error" + assert "Exit code: 1" in result.content + assert result.artifact["exit_code"] == 1 # type: ignore[index] + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + + +def test_truncation_by_bytes(tmp_path: Path) -> None: + """Test that output is truncated by bytes when max_output_bytes is exceeded.""" + policy = HostExecutionPolicy(max_output_bytes=50, command_timeout=5.0) + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy) + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + resources = middleware._ensure_resources(state) # type: ignore[attr-defined] + + result = middleware._run_shell_tool( + resources, {"command": "python3 -c 'print(\"x\" * 100)'"}, tool_call_id=None + ) + + assert "truncated at 50 bytes" in result.lower() + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + + +def test_startup_command_failure(tmp_path: Path) -> None: + """Test that startup command failure raises an error.""" + policy = HostExecutionPolicy(startup_timeout=1.0) + middleware = ShellToolMiddleware( + workspace_root=tmp_path / "workspace", startup_commands=("exit 1",), execution_policy=policy + ) + state: AgentState = 
_empty_state() + with pytest.raises(RuntimeError, match="Startup command.*failed"): + middleware.before_agent(state, None) + + +def test_shutdown_command_failure_logged(tmp_path: Path) -> None: + """Test that shutdown command failures are logged but don't raise.""" + policy = HostExecutionPolicy(command_timeout=1.0) + middleware = ShellToolMiddleware( + workspace_root=tmp_path / "workspace", + shutdown_commands=("exit 1",), + execution_policy=policy, + ) + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + finally: + # Should not raise despite shutdown command failing + middleware.after_agent(state, None) + + +def test_shutdown_command_timeout_logged(tmp_path: Path) -> None: + """Test that shutdown command timeouts are logged but don't raise.""" + policy = HostExecutionPolicy(command_timeout=0.1) + middleware = ShellToolMiddleware( + workspace_root=tmp_path / "workspace", + execution_policy=policy, + shutdown_commands=("sleep 2",), + ) + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + finally: + # Should not raise despite shutdown command timing out + middleware.after_agent(state, None) + + +def test_ensure_resources_missing_state(tmp_path: Path) -> None: + """Test that _ensure_resources raises when resources are missing.""" + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace") + state: AgentState = _empty_state() + + with pytest.raises(ToolException, match="Shell session resources are unavailable"): + middleware._ensure_resources(state) # type: ignore[attr-defined] + + +def test_empty_output_replaced_with_no_output(tmp_path: Path) -> None: + """Test that empty command output is replaced with '<no output>'.""" + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace") + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + resources = middleware._ensure_resources(state) # type: ignore[attr-defined] + + result = middleware._run_shell_tool( + resources, + {"command": "true"}, # Command that produces no output + tool_call_id=None, + ) + + assert "<no output>" in result + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + + +def test_stderr_output_labeling(tmp_path: Path) -> None: + """Test that stderr output is properly labeled.""" + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace") + try: + state: AgentState = _empty_state() + updates = middleware.before_agent(state, None) + if updates: + state.update(updates) + resources = middleware._ensure_resources(state) # type: ignore[attr-defined] + + result = middleware._run_shell_tool( + resources, {"command": "echo error >&2"}, tool_call_id=None + ) + + assert "[stderr] error" in result + finally: + updates = middleware.after_agent(state, None) + if updates: + state.update(updates) + + +@pytest.mark.parametrize( + ("startup_commands", "expected"), + [ + ("echo test", ("echo test",)), # String + (["echo test", "pwd"], ("echo test", "pwd")), # List + (("echo test",), ("echo test",)), # Tuple + (None, ()), # None + ], +) +def test_normalize_commands_string_tuple_list( + tmp_path: Path, + startup_commands: str | list[str] | tuple[str, ...] 
| None, + expected: tuple[str, ...], +) -> None: + """Test various command normalization formats.""" + middleware = ShellToolMiddleware( + workspace_root=tmp_path / "workspace", startup_commands=startup_commands + ) + assert middleware._startup_commands == expected # type: ignore[attr-defined] + + +def test_async_methods_delegate_to_sync(tmp_path: Path) -> None: + """Test that async methods properly delegate to sync methods.""" + middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace") + try: + state: AgentState = _empty_state() + + # Test abefore_agent + updates = asyncio.run(middleware.abefore_agent(state, None)) + if updates: + state.update(updates) + + # Test aafter_agent + asyncio.run(middleware.aafter_agent(state, None)) + finally: + pass diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py index b7a3943f2e7..84c67b0402c 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py @@ -1,11 +1,11 @@ from typing import TYPE_CHECKING +from unittest.mock import patch import pytest from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage from langchain_core.outputs import ChatGeneration, ChatResult from langgraph.graph.message import REMOVE_ALL_MESSAGES -from unittest.mock import patch from langchain.agents.middleware.summarization import SummarizationMiddleware @@ -15,6 +15,35 @@ if TYPE_CHECKING: from langchain_model_profiles import ModelProfile +class MockChatModel(BaseChatModel): + """Mock chat model for testing.""" + + def invoke(self, prompt): # type: ignore[no-untyped-def] + return AIMessage(content="Generated summary") + + def _generate(self, messages, **kwargs): # type: ignore[no-untyped-def] + return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) + + @property + def _llm_type(self) -> str: + return "mock" + + +class ProfileChatModel(BaseChatModel): + """Mock chat model with profile for testing.""" + + def _generate(self, messages, **kwargs): # type: ignore[no-untyped-def] + return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) + + @property + def _llm_type(self) -> str: + return "mock" + + @property + def profile(self) -> "ModelProfile": + return {"max_input_tokens": 1000} + + def test_summarization_middleware_initialization() -> None: """Test SummarizationMiddleware initialization.""" model = FakeToolCallingModel() @@ -139,19 +168,7 @@ def test_summarization_middleware_tool_call_safety() -> None: def test_summarization_middleware_summary_creation() -> None: """Test SummarizationMiddleware summary creation.""" - - class MockModel(BaseChatModel): - def invoke(self, prompt): - return AIMessage(content="Generated summary") - - def _generate(self, messages, **kwargs): - return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) - - @property - def _llm_type(self): - return "mock" - - middleware = SummarizationMiddleware(model=MockModel(), trigger=("tokens", 1000)) + middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("tokens", 1000)) # Test normal summary creation messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")] @@ -180,25 +197,16 @@ def 
test_summarization_middleware_summary_creation() -> None: # Test we raise warning if max_tokens_before_summary or messages_to_keep is specified with pytest.warns(DeprecationWarning, match="max_tokens_before_summary is deprecated"): - SummarizationMiddleware(model=MockModel(), max_tokens_before_summary=500) + SummarizationMiddleware(model=MockChatModel(), max_tokens_before_summary=500) with pytest.warns(DeprecationWarning, match="messages_to_keep is deprecated"): - SummarizationMiddleware(model=MockModel(), messages_to_keep=5) + SummarizationMiddleware(model=MockChatModel(), messages_to_keep=5) def test_summarization_middleware_trim_limit_none_keeps_all_messages() -> None: """Verify disabling trim limit preserves full message sequence.""" - - class MockModel(BaseChatModel): - def _generate(self, messages, **kwargs): - return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) - - @property - def _llm_type(self): - return "mock" - messages = [HumanMessage(content=str(i)) for i in range(10)] middleware = SummarizationMiddleware( - model=MockModel(), + model=MockChatModel(), trim_tokens_to_summarize=None, ) middleware.token_counter = lambda msgs: len(msgs) @@ -209,23 +217,10 @@ def test_summarization_middleware_trim_limit_none_keeps_all_messages() -> None: def test_summarization_middleware_profile_inference_triggers_summary() -> None: """Ensure automatic profile inference triggers summarization when limits are exceeded.""" - - class ProfileModel(BaseChatModel): - def _generate(self, messages, **kwargs): - return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) - - @property - def _llm_type(self) -> str: - return "mock" - - @property - def profile(self) -> "ModelProfile": - return {"max_input_tokens": 1000} - token_counter = lambda messages: len(messages) * 200 middleware = SummarizationMiddleware( - model=ProfileModel(), + model=ProfileChatModel(), trigger=("fraction", 0.81), keep=("fraction", 0.5), token_counter=token_counter, @@ -250,7 +245,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None: # Engage summarization # 0.80 * 1000 == 800 <= 800 middleware = SummarizationMiddleware( - model=ProfileModel(), + model=ProfileChatModel(), trigger=("fraction", 0.80), keep=("fraction", 0.5), token_counter=token_counter, @@ -270,7 +265,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None: # With keep=("fraction", 0.6) the target token allowance becomes 600, # so the cutoff shifts to keep the last three messages instead of two. middleware = SummarizationMiddleware( - model=ProfileModel(), + model=ProfileChatModel(), trigger=("fraction", 0.80), keep=("fraction", 0.6), token_counter=token_counter, @@ -287,7 +282,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None: # context (target tokens = 800), so token-based retention keeps everything # and summarization is skipped entirely. 
middleware = SummarizationMiddleware( - model=ProfileModel(), + model=ProfileChatModel(), trigger=("fraction", 0.80), keep=("fraction", 0.8), token_counter=token_counter, @@ -296,7 +291,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None: # Test with tokens_to_keep as absolute int value middleware_int = SummarizationMiddleware( - model=ProfileModel(), + model=ProfileChatModel(), trigger=("fraction", 0.80), keep=("tokens", 400), # Keep exactly 400 tokens (2 messages) token_counter=token_counter, @@ -310,7 +305,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None: # Test with tokens_to_keep as larger int value middleware_int_large = SummarizationMiddleware( - model=ProfileModel(), + model=ProfileChatModel(), trigger=("fraction", 0.80), keep=("tokens", 600), # Keep 600 tokens (3 messages) token_counter=token_counter, @@ -327,23 +322,11 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None: def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None: """Ensure token retention keeps pairs together even if exceeding target tokens.""" - class ProfileModel(BaseChatModel): - def _generate(self, messages, **kwargs): - return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) - - @property - def _llm_type(self) -> str: - return "mock" - - @property - def profile(self) -> "ModelProfile": - return {"max_input_tokens": 1000} - - def token_counter(messages): + def token_counter(messages: list[AnyMessage]) -> int: return sum(len(getattr(message, "content", "")) for message in messages) middleware = SummarizationMiddleware( - model=ProfileModel(), + model=ProfileChatModel(), trigger=("fraction", 0.1), keep=("fraction", 0.5), ) @@ -400,22 +383,10 @@ def test_summarization_middleware_missing_profile() -> None: def test_summarization_middleware_full_workflow() -> None: """Test SummarizationMiddleware complete summarization workflow.""" - - class MockModel(BaseChatModel): - def invoke(self, prompt): - return AIMessage(content="Generated summary") - - def _generate(self, messages, **kwargs): - return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) - - @property - def _llm_type(self): - return "mock" - with pytest.warns(DeprecationWarning): # keep test for functionality middleware = SummarizationMiddleware( - model=MockModel(), max_tokens_before_summary=1000, messages_to_keep=2 + model=MockChatModel(), max_tokens_before_summary=1000, messages_to_keep=2 ) # Mock high token count to trigger summarization @@ -504,21 +475,9 @@ async def test_summarization_middleware_full_workflow_async() -> None: def test_summarization_middleware_keep_messages() -> None: """Test SummarizationMiddleware with keep parameter specifying messages.""" - - class MockModel(BaseChatModel): - def invoke(self, prompt): - return AIMessage(content="Generated summary") - - def _generate(self, messages, **kwargs): - return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) - - @property - def _llm_type(self): - return "mock" - # Test that summarization is triggered when message count reaches threshold middleware = SummarizationMiddleware( - model=MockModel(), trigger=("messages", 5), keep=("messages", 2) + model=MockChatModel(), trigger=("messages", 5), keep=("messages", 2) ) # Below threshold - no summarization @@ -561,6 +520,379 @@ def test_summarization_middleware_keep_messages() -> None: assert [message.content for message in 
result["messages"][2:]] == ["5", "6"] # Test with both parameters disabled - middleware_disabled = SummarizationMiddleware(model=MockModel(), trigger=None) + middleware_disabled = SummarizationMiddleware(model=MockChatModel(), trigger=None) result = middleware_disabled.before_model(state_above, None) assert result is None + + +@pytest.mark.parametrize( + ("param_name", "param_value", "expected_error"), + [ + ("trigger", ("fraction", 0.0), "Fractional trigger values must be between 0 and 1"), + ("trigger", ("fraction", 1.5), "Fractional trigger values must be between 0 and 1"), + ("keep", ("fraction", -0.1), "Fractional keep values must be between 0 and 1"), + ("trigger", ("tokens", 0), "trigger thresholds must be greater than 0"), + ("trigger", ("messages", -5), "trigger thresholds must be greater than 0"), + ("keep", ("tokens", 0), "keep thresholds must be greater than 0"), + ("trigger", ("invalid", 100), "Unsupported context size type"), + ("keep", ("invalid", 100), "Unsupported context size type"), + ], +) +def test_summarization_middleware_validation_edge_cases( + param_name: str, param_value: tuple[str, float | int], expected_error: str +) -> None: + """Test validation of context size parameters with edge cases.""" + model = FakeToolCallingModel() + with pytest.raises(ValueError, match=expected_error): + SummarizationMiddleware(model=model, **{param_name: param_value}) + + +def test_summarization_middleware_multiple_triggers() -> None: + """Test middleware with multiple trigger conditions.""" + # Test with multiple triggers - should activate when ANY condition is met + middleware = SummarizationMiddleware( + model=MockChatModel(), + trigger=[("messages", 10), ("tokens", 500)], + keep=("messages", 2), + ) + + # Mock token counter to return low count + def mock_low_tokens(messages): + return 100 + + middleware.token_counter = mock_low_tokens + + # Should not trigger - neither condition met + messages = [HumanMessage(content=str(i)) for i in range(5)] + state = {"messages": messages} + result = middleware.before_model(state, None) + assert result is None + + # Should trigger - message count threshold met + messages = [HumanMessage(content=str(i)) for i in range(10)] + state = {"messages": messages} + result = middleware.before_model(state, None) + assert result is not None + + # Test token trigger + def mock_high_tokens(messages): + return 600 + + middleware.token_counter = mock_high_tokens + messages = [HumanMessage(content=str(i)) for i in range(5)] + state = {"messages": messages} + result = middleware.before_model(state, None) + assert result is not None + + +def test_summarization_middleware_profile_edge_cases() -> None: + """Test profile retrieval with various edge cases.""" + + class NoProfileModel(BaseChatModel): + def _generate(self, messages, **kwargs): + return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) + + @property + def _llm_type(self): + return "mock" + + # Model without profile attribute + middleware = SummarizationMiddleware(model=NoProfileModel(), trigger=("messages", 5)) + assert middleware._get_profile_limits() is None + + class InvalidProfileModel(BaseChatModel): + def _generate(self, messages, **kwargs): + return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) + + @property + def _llm_type(self): + return "mock" + + @property + def profile(self): + return "invalid_profile_type" + + # Model with non-dict profile + middleware = SummarizationMiddleware(model=InvalidProfileModel(), trigger=("messages", 
5)) + assert middleware._get_profile_limits() is None + + class MissingTokensModel(BaseChatModel): + def _generate(self, messages, **kwargs): + return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) + + @property + def _llm_type(self): + return "mock" + + @property + def profile(self): + return {"other_field": 100} + + # Model with profile but no max_input_tokens + middleware = SummarizationMiddleware(model=MissingTokensModel(), trigger=("messages", 5)) + assert middleware._get_profile_limits() is None + + class InvalidTokenTypeModel(BaseChatModel): + def _generate(self, messages, **kwargs): + return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) + + @property + def _llm_type(self): + return "mock" + + @property + def profile(self): + return {"max_input_tokens": "not_an_int"} + + # Model with non-int max_input_tokens + middleware = SummarizationMiddleware(model=InvalidTokenTypeModel(), trigger=("messages", 5)) + assert middleware._get_profile_limits() is None + + +def test_summarization_middleware_trim_messages_error_fallback() -> None: + """Test that trim_messages_for_summary falls back gracefully on errors.""" + middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5)) + + # Create a mock token counter that raises an exception + def failing_token_counter(messages): + raise Exception("Token counting failed") + + middleware.token_counter = failing_token_counter + + # Should fall back to last 15 messages + messages = [HumanMessage(content=str(i)) for i in range(20)] + trimmed = middleware._trim_messages_for_summary(messages) + assert len(trimmed) == 15 + assert trimmed == messages[-15:] + + +def test_summarization_middleware_binary_search_edge_cases() -> None: + """Test binary search in _find_token_based_cutoff with edge cases.""" + middleware = SummarizationMiddleware( + model=MockChatModel(), trigger=("messages", 5), keep=("tokens", 100) + ) + + # Test with single message that's too large + def token_counter_single_large(messages): + return len(messages) * 200 + + middleware.token_counter = token_counter_single_large + + single_message = [HumanMessage(content="x" * 200)] + cutoff = middleware._find_token_based_cutoff(single_message) + assert cutoff == 0 + + # Test with empty messages + cutoff = middleware._find_token_based_cutoff([]) + assert cutoff == 0 + + # Test when all messages fit within token budget + def token_counter_small(messages): + return len(messages) * 10 + + middleware.token_counter = token_counter_small + messages = [HumanMessage(content=str(i)) for i in range(5)] + cutoff = middleware._find_token_based_cutoff(messages) + assert cutoff == 0 + + +def test_summarization_middleware_tool_call_extraction_edge_cases() -> None: + """Test tool call ID extraction with various message formats.""" + model = FakeToolCallingModel() + middleware = SummarizationMiddleware(model=model, trigger=("messages", 5)) + + # Test with dict-style tool calls + ai_message_dict = AIMessage( + content="test", tool_calls=[{"name": "tool1", "args": {}, "id": "id1"}] + ) + ids = middleware._extract_tool_call_ids(ai_message_dict) + assert ids == {"id1"} + + # Test with multiple tool calls + ai_message_multiple = AIMessage( + content="test", + tool_calls=[ + {"name": "tool1", "args": {}, "id": "id1"}, + {"name": "tool2", "args": {}, "id": "id2"}, + ], + ) + ids = middleware._extract_tool_call_ids(ai_message_multiple) + assert ids == {"id1", "id2"} + + # Test with empty tool calls list + ai_message_empty = 
AIMessage(content="test", tool_calls=[]) + ids = middleware._extract_tool_call_ids(ai_message_empty) + assert len(ids) == 0 + + +def test_summarization_middleware_complex_tool_pair_scenarios() -> None: + """Test complex tool call pairing scenarios.""" + model = FakeToolCallingModel() + middleware = SummarizationMiddleware(model=model, trigger=("messages", 5), keep=("messages", 3)) + + # Test with multiple AI messages with tool calls + messages = [ + HumanMessage(content="msg1"), + AIMessage(content="ai1", tool_calls=[{"name": "tool1", "args": {}, "id": "call1"}]), + ToolMessage(content="result1", tool_call_id="call1"), + HumanMessage(content="msg2"), + AIMessage(content="ai2", tool_calls=[{"name": "tool2", "args": {}, "id": "call2"}]), + ToolMessage(content="result2", tool_call_id="call2"), + HumanMessage(content="msg3"), + ] + + # Test cutoff at index 1 - unsafe (separates first AI/Tool pair) + assert not middleware._is_safe_cutoff_point(messages, 2) + + # Test cutoff at index 3 - safe (keeps first pair together) + assert middleware._is_safe_cutoff_point(messages, 3) + + # Test cutoff at index 5 - unsafe (separates second AI/Tool pair) + assert not middleware._is_safe_cutoff_point(messages, 5) + + # Test _cutoff_separates_tool_pair directly + assert middleware._cutoff_separates_tool_pair(messages, 1, 2, {"call1"}) + assert not middleware._cutoff_separates_tool_pair(messages, 1, 0, {"call1"}) + assert not middleware._cutoff_separates_tool_pair(messages, 1, 3, {"call1"}) + + +def test_summarization_middleware_tool_call_in_search_range() -> None: + """Test tool call safety with messages at edge of search range.""" + model = FakeToolCallingModel() + middleware = SummarizationMiddleware( + model=model, trigger=("messages", 10), keep=("messages", 2) + ) + + # Create messages with tool pair separated by some distance + # Search range is 5, so messages within 5 positions of cutoff are checked + messages = [ + HumanMessage(content="msg1"), + HumanMessage(content="msg2"), + AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]), + HumanMessage(content="msg3"), + HumanMessage(content="msg4"), + ToolMessage(content="result", tool_call_id="call1"), + HumanMessage(content="msg6"), + ] + + # Cutoff at index 3 would separate: [0,1,2] from [3,4,5,6] + # AI at index 2 is before cutoff, Tool at index 5 is after cutoff - unsafe + assert not middleware._is_safe_cutoff_point(messages, 3) + + # Cutoff at index 6 keeps AI and Tool both in summarized section + assert middleware._is_safe_cutoff_point(messages, 6) + + # Cutoff at index 0 or 1 also safe - both AI and Tool in preserved section + assert middleware._is_safe_cutoff_point(messages, 0) + assert middleware._is_safe_cutoff_point(messages, 1) + + +def test_summarization_middleware_zero_and_negative_target_tokens() -> None: + """Test handling of edge cases with target token calculations.""" + # Test with very small fraction that rounds to zero + middleware = SummarizationMiddleware( + model=ProfileChatModel(), trigger=("fraction", 0.0001), keep=("fraction", 0.0001) + ) + + # Should set threshold to 1 when calculated value is <= 0 + messages = [HumanMessage(content="test")] + state = {"messages": messages} + + # The trigger fraction calculation: int(1000 * 0.0001) = 0, but should be set to 1 + # Token count of 1 message should exceed threshold of 1 + def token_counter(msgs): + return 2 + + middleware.token_counter = token_counter + assert middleware._should_summarize(messages, 2) + + +async def 
test_summarization_middleware_async_error_handling() -> None: + """Test async summary creation with errors.""" + + class ErrorAsyncModel(BaseChatModel): + def _generate(self, messages, **kwargs): + return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) + + async def _agenerate(self, messages, **kwargs): + raise Exception("Async model error") + + @property + def _llm_type(self): + return "mock" + + middleware = SummarizationMiddleware(model=ErrorAsyncModel(), trigger=("messages", 5)) + messages = [HumanMessage(content="test")] + summary = await middleware._acreate_summary(messages) + assert "Error generating summary: Async model error" in summary + + +def test_summarization_middleware_cutoff_at_boundary() -> None: + """Test cutoff index determination at exact message boundaries.""" + middleware = SummarizationMiddleware( + model=MockChatModel(), trigger=("messages", 5), keep=("messages", 5) + ) + + # When we want to keep exactly as many messages as we have + messages = [HumanMessage(content=str(i)) for i in range(5)] + cutoff = middleware._find_safe_cutoff(messages, 5) + assert cutoff == 0 # Should not cut anything + + # When we want to keep more messages than we have + cutoff = middleware._find_safe_cutoff(messages, 10) + assert cutoff == 0 + + +def test_summarization_middleware_deprecated_parameters_with_defaults() -> None: + """Test that deprecated parameters work correctly with default values.""" + # Test that deprecated max_tokens_before_summary is ignored when trigger is set + with pytest.warns(DeprecationWarning): + middleware = SummarizationMiddleware( + model=MockChatModel(), trigger=("tokens", 2000), max_tokens_before_summary=1000 + ) + assert middleware.trigger == ("tokens", 2000) + + # Test that messages_to_keep is ignored when keep is not default + with pytest.warns(DeprecationWarning): + middleware = SummarizationMiddleware( + model=MockChatModel(), keep=("messages", 5), messages_to_keep=10 + ) + assert middleware.keep == ("messages", 5) + + +def test_summarization_middleware_fraction_trigger_with_no_profile() -> None: + """Test fractional trigger condition when profile data becomes unavailable.""" + middleware = SummarizationMiddleware( + model=ProfileChatModel(), + trigger=[("fraction", 0.5), ("messages", 100)], + keep=("messages", 5), + ) + + # Test that when fractional condition can't be evaluated, other triggers still work + messages = [HumanMessage(content=str(i)) for i in range(100)] + + # Mock _get_profile_limits to return None + original_method = middleware._get_profile_limits + middleware._get_profile_limits = lambda: None + + # Should still trigger based on message count + state = {"messages": messages} + result = middleware.before_model(state, None) + assert result is not None + + # Restore original method + middleware._get_profile_limits = original_method + + +def test_summarization_middleware_is_safe_cutoff_at_end() -> None: + """Test _is_safe_cutoff_point when cutoff is at or past the end.""" + model = FakeToolCallingModel() + middleware = SummarizationMiddleware(model=model, trigger=("messages", 5)) + + messages = [HumanMessage(content=str(i)) for i in range(5)] + + # Cutoff at exactly the length should be safe + assert middleware._is_safe_cutoff_point(messages, len(messages)) + + # Cutoff past the length should also be safe + assert middleware._is_safe_cutoff_point(messages, len(messages) + 5) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_tool_selection.py 
b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_tool_selection.py index 031bb60fe2d..1a8c6422453 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_tool_selection.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_tool_selection.py @@ -1,22 +1,22 @@ """Unit tests for LLM tool selection middleware.""" import typing -from typing import Union, Any, Literal - from itertools import cycle +from typing import Any, Literal, Union + +import pytest from pydantic import BaseModel from langchain.agents import create_agent -from langchain.agents.middleware import AgentState, ModelRequest, wrap_model_call -from langchain.agents.middleware import LLMToolSelectorMiddleware +from langchain.agents.middleware import LLMToolSelectorMiddleware, ModelRequest, wrap_model_call +from langchain.agents.middleware.tool_selection import _create_tool_selection_response +from langchain.agents.middleware.types import AgentState from langchain.messages import AIMessage from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.fake_chat_models import GenericFakeChatModel -from langchain_core.messages import BaseMessage -from langchain_core.messages import HumanMessage +from langchain_core.messages import BaseMessage, HumanMessage from langchain_core.runnables import Runnable -from langchain_core.tools import BaseTool -from langchain_core.tools import tool +from langchain_core.tools import BaseTool, tool @tool @@ -596,3 +596,12 @@ class TestDuplicateAndInvalidTools: assert len(tool_names) == 2 assert "get_weather" in tool_names assert "search_web" in tool_names + + +class TestEdgeCases: + """Test edge cases and error handling.""" + + def test_empty_tools_list_raises_error(self) -> None: + """Test that empty tools list raises an error in schema creation.""" + with pytest.raises(AssertionError, match="tools must be non-empty"): + _create_tool_selection_response([])
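Reviewer note (not part of the diff): the new TestExpandIncludePatterns and TestMatchIncludePattern cases pin down a small brace-expansion contract for include patterns. The sketch below is an illustrative approximation of that contract only; the `_sketch` helper names are hypothetical, and nothing is assumed about the real private helpers in langchain.agents.middleware.file_search beyond what the assertions in the diff state.

import fnmatch


def expand_include_patterns_sketch(pattern: str) -> list[str] | None:
    """Expand one `{a,b}` group per recursion level; return None on invalid braces."""
    if "}" in pattern and "{" not in pattern:
        return None  # closing brace without an opening one, e.g. "*.py}"
    start = pattern.find("{")
    if start == -1:
        return [pattern]  # nothing left to expand
    end = pattern.find("}", start)
    if end == -1:
        return None  # unclosed brace, e.g. "*.{py"
    options = [opt for opt in pattern[start + 1 : end].split(",") if opt]
    if not options:
        return None  # empty braces, e.g. "*.{}"
    expanded: list[str] = []
    for option in options:
        tail = expand_include_patterns_sketch(pattern[:start] + option + pattern[end + 1 :])
        if tail is None:
            return None
        expanded.extend(tail)
    return expanded


def match_include_pattern_sketch(name: str, pattern: str) -> bool:
    """A name matches when any expanded variant matches; invalid patterns never match."""
    variants = expand_include_patterns_sketch(pattern)
    if variants is None:
        return False
    return any(fnmatch.fnmatch(name, variant) for variant in variants)


# Mirrors the behavior asserted by the new tests.
assert expand_include_patterns_sketch("*.{py,txt}") == ["*.py", "*.txt"]
assert expand_include_patterns_sketch("*.{}") is None
assert match_include_pattern_sketch("test.py", "*.{py,txt}")
assert not match_include_pattern_sketch("test.md", "*.{py,txt}")

Mapping invalid brace syntax to None (rather than raising) matches the tested behavior: a pattern that cannot be expanded simply matches nothing.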