Compare commits

...

2 Commits

Author SHA1 Message Date
Nuno Campos
6b6eeaebde Guard against malformed or malicious include patterns 2025-10-15 20:04:08 +01:00
Nuno Campos
672b8eceb8 feat(langchain_v1): Add Anthropic tools middleware with text editor, memory, and file
search

Middleware Classes

Text Editor Tools
- StateClaudeTextEditorMiddleware: In-memory text editor using agent state
- FilesystemClaudeTextEditorMiddleware: Text editor operating on the real filesystem

Implements Claude's text editor tool
(https://docs.claude.com/en/docs/agents-and-tools/tool-use/text-editor-tool).
Operations: view, create, str_replace, insert

Memory Tools
- StateClaudeMemoryMiddleware: Memory persistence in agent state
- FilesystemClaudeMemoryMiddleware: Memory persistence on the filesystem

Implements Claude's memory tool
(https://docs.claude.com/en/docs/agents-and-tools/tool-use/memory-tool).
Operations: same as the text editor, plus delete and rename

File Search Tools
- StateFileSearchMiddleware: Search state-based files
- FilesystemFileSearchMiddleware: Search real filesystem

Provides Glob and Grep tools with same schema as used by Claude Code (but compatible
with any model)
- Glob: Pattern matching (e.g., **/*.py, src/**/*.ts), sorted by modification time
- Grep: Regex content search with output modes (files_with_matches, content, count)

Usage

```python
from langchain.agents import create_agent
from langchain.agents.middleware import (
    StateClaudeTextEditorMiddleware,
    StateFileSearchMiddleware,
)

agent = create_agent(
    model=model,
    tools=[],
    middleware=[
        StateClaudeTextEditorMiddleware(),
        StateFileSearchMiddleware(),
    ],
)
```
2025-10-15 19:53:33 +01:00
5 changed files with 2439 additions and 0 deletions

View File

@@ -1,9 +1,16 @@
"""Middleware plugins for agents."""
from .anthropic_tools import (
FilesystemClaudeMemoryMiddleware,
FilesystemClaudeTextEditorMiddleware,
StateClaudeMemoryMiddleware,
StateClaudeTextEditorMiddleware,
)
from .context_editing import (
ClearToolUsesEdit,
ContextEditingMiddleware,
)
from .file_search import FilesystemFileSearchMiddleware, StateFileSearchMiddleware
from .human_in_the_loop import (
HumanInTheLoopMiddleware,
InterruptOnConfig,
@@ -36,6 +43,9 @@ __all__ = [
"AgentState",
"ClearToolUsesEdit",
"ContextEditingMiddleware",
"FilesystemClaudeMemoryMiddleware",
"FilesystemClaudeTextEditorMiddleware",
"FilesystemFileSearchMiddleware",
"HumanInTheLoopMiddleware",
"InterruptOnConfig",
"LLMToolEmulator",
@@ -46,6 +56,9 @@ __all__ = [
"ModelResponse",
"PIIDetectionError",
"PIIMiddleware",
"StateClaudeMemoryMiddleware",
"StateClaudeTextEditorMiddleware",
"StateFileSearchMiddleware",
"SummarizationMiddleware",
"TodoListMiddleware",
"ToolCallLimitMiddleware",

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,588 @@
"""File search middleware for Anthropic text editor and memory tools.
This module provides Glob and Grep search tools that operate on files stored
in state or filesystem.
"""
from __future__ import annotations
import fnmatch
import json
import re
import subprocess
from contextlib import suppress
from datetime import datetime, timezone
from pathlib import Path, PurePosixPath
from typing import Annotated, Any, Literal, cast
from langchain_core.tools import InjectedToolArg, tool
from langchain.agents.middleware.anthropic_tools import AnthropicToolsState
from langchain.agents.middleware.types import AgentMiddleware
def _expand_include_patterns(pattern: str) -> list[str] | None:
"""Expand brace patterns like ``*.{py,pyi}`` into a list of globs."""
if "}" in pattern and "{" not in pattern:
return None
expanded: list[str] = []
def _expand(current: str) -> None:
start = current.find("{")
if start == -1:
expanded.append(current)
return
end = current.find("}", start)
if end == -1:
raise ValueError
prefix = current[:start]
suffix = current[end + 1 :]
inner = current[start + 1 : end]
if not inner:
raise ValueError
for option in inner.split(","):
_expand(prefix + option + suffix)
try:
_expand(pattern)
except ValueError:
return None
return expanded
def _is_valid_include_pattern(pattern: str) -> bool:
    """Check that an include filter is a usable glob.

    A pattern is rejected when it is empty, contains a NUL or line-break
    character, cannot be brace-expanded, or when any expanded glob does not
    translate into a compilable regular expression.
    """
    forbidden = ("\x00", "\n", "\r")
    if not pattern or any(char in pattern for char in forbidden):
        return False

    candidates = _expand_include_patterns(pattern)
    if candidates is None:
        return False

    for glob in candidates:
        try:
            re.compile(fnmatch.translate(glob))
        except re.error:
            return False
    return True
def _match_include_pattern(basename: str, pattern: str) -> bool:
    """Tell whether ``basename`` matches any glob the include pattern expands to.

    Patterns that fail to expand (or expand to nothing) match no file.
    """
    globs = _expand_include_patterns(pattern)
    if not globs:
        return False
    for glob in globs:
        if fnmatch.fnmatch(basename, glob):
            return True
    return False
class StateFileSearchMiddleware(AgentMiddleware):
    """Provides Glob and Grep search over state-based files.

    This middleware adds two tools that search through virtual files in state:

    - Glob: Fast file pattern matching by file path
    - Grep: Fast content search using regular expressions

    Both tools read the mapping stored in state under ``state_key``. Each
    entry is read for its ``"content"`` (iterated line by line) and its
    ``"modified_at"`` value (used as the glob sort key).

    Example:
        ```python
        from langchain.agents import create_agent
        from langchain.agents.middleware import (
            StateClaudeTextEditorMiddleware,
            StateFileSearchMiddleware,
        )

        agent = create_agent(
            model=model,
            tools=[],
            middleware=[
                StateClaudeTextEditorMiddleware(),
                StateFileSearchMiddleware(),
            ],
        )
        ```
    """

    state_schema = AnthropicToolsState

    def __init__(
        self,
        *,
        state_key: str = "text_editor_files",
    ) -> None:
        """Initialize the search middleware.

        Args:
            state_key: State key to search (default: "text_editor_files").
                Use "memory_files" to search memory tool files.
        """
        self.state_key = state_key

        # The tools are closures so they can read ``self.state_key``.
        @tool
        def glob_search(  # noqa: D417
            pattern: str,
            path: str = "/",
            state: Annotated[AnthropicToolsState, InjectedToolArg] = None,  # type: ignore[assignment]
        ) -> str:
            """Fast file pattern matching tool that works with any codebase size.

            Supports glob patterns like **/*.js or src/**/*.ts.
            Returns matching file paths sorted by modification time.
            Use this tool when you need to find files by name patterns.

            Args:
                pattern: The glob pattern to match files against.
                path: The directory to search in. If not specified, searches from root.

            Returns:
                Newline-separated list of matching file paths, sorted by modification
                time (most recently modified first). Returns "No files found" if no
                matches.
            """
            base_path = self._normalize_base_path(path)
            files = cast("dict[str, Any]", state.get(self.state_key, {}))

            matches: list[tuple[str, Any]] = []
            for file_path, file_data in files.items():
                relative = self._relative_to_base(file_path, base_path)
                if relative is None:
                    continue
                candidate = PurePosixPath(relative)
                is_match = candidate.match(pattern)
                if not is_match and pattern.startswith("**/"):
                    # PurePosixPath.match does not match single-level paths
                    # against **/pattern, so retry without the prefix to
                    # cover files sitting directly in the base directory.
                    is_match = candidate.match(pattern[3:])
                if is_match:
                    matches.append((file_path, file_data["modified_at"]))

            if not matches:
                return "No files found"

            # Most recently modified first.
            matches.sort(key=lambda item: item[1], reverse=True)
            return "\n".join(file_path for file_path, _ in matches)

        @tool
        def grep_search(  # noqa: D417
            pattern: str,
            path: str = "/",
            include: str | None = None,
            output_mode: Literal["files_with_matches", "content", "count"] = "files_with_matches",
            state: Annotated[AnthropicToolsState, InjectedToolArg] = None,  # type: ignore[assignment]
        ) -> str:
            """Fast content search tool that works with any codebase size.

            Searches file contents using regular expressions. Supports full regex
            syntax and filters files by pattern with the include parameter.

            Args:
                pattern: The regular expression pattern to search for in file contents.
                path: The directory to search in. If not specified, searches from root.
                include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
                output_mode: Output format:
                    - "files_with_matches": Only file paths containing matches (default)
                    - "content": Matching lines with file:line:content format
                    - "count": Count of matches per file

            Returns:
                Search results formatted according to output_mode. Returns "No matches
                found" if no results.
            """
            base_path = self._normalize_base_path(path)

            try:
                regex = re.compile(pattern)
            except re.error as e:
                return f"Invalid regex pattern: {e}"

            if include and not _is_valid_include_pattern(include):
                return "Invalid include pattern"

            files = cast("dict[str, Any]", state.get(self.state_key, {}))
            results: dict[str, list[tuple[int, str]]] = {}
            for file_path, file_data in files.items():
                # Same directory-boundary rule as glob_search: "/src" must
                # not select "/srcfoo/...".
                if self._relative_to_base(file_path, base_path) is None:
                    continue
                if include and not _match_include_pattern(
                    PurePosixPath(file_path).name, include
                ):
                    continue
                hits = [
                    (line_num, line)
                    for line_num, line in enumerate(file_data["content"], 1)
                    if regex.search(line)
                ]
                if hits:
                    results[file_path] = hits

            if not results:
                return "No matches found"
            return self._format_grep_results(results, output_mode)

        self.glob_search = glob_search
        self.grep_search = grep_search
        self.tools = [glob_search, grep_search]

    @staticmethod
    def _normalize_base_path(path: str) -> str:
        """Return ``path`` as an absolute virtual path without a trailing slash."""
        return "/" + path.strip("/")

    @staticmethod
    def _relative_to_base(file_path: str, base_path: str) -> str | None:
        """Return ``file_path`` relative to ``base_path``, or ``None`` if outside it.

        ``base_path`` may name a directory (files below it are returned
        relative to it) or a single file (its basename is returned).
        """
        if base_path == "/":
            return file_path[1:] if file_path.startswith("/") else None
        if file_path == base_path:
            return PurePosixPath(file_path).name
        if file_path.startswith(base_path + "/"):
            return file_path[len(base_path) + 1 :]
        return None

    def _format_grep_results(
        self,
        results: dict[str, list[tuple[int, str]]],
        output_mode: str,
    ) -> str:
        """Format grep results based on output mode.

        Args:
            results: Mapping of file path to ``(line_number, line)`` matches.
            output_mode: ``"content"`` for ``file:line:content`` rows,
                ``"count"`` for ``file:count`` rows; any other value
                (including ``"files_with_matches"``) lists file paths only.

        Returns:
            Newline-joined rows, ordered by file path.
        """
        ordered_paths = sorted(results)
        if output_mode == "content":
            return "\n".join(
                f"{file_path}:{line_num}:{line}"
                for file_path in ordered_paths
                for line_num, line in results[file_path]
            )
        if output_mode == "count":
            return "\n".join(
                f"{file_path}:{len(results[file_path])}" for file_path in ordered_paths
            )
        # "files_with_matches" and any unrecognized mode.
        return "\n".join(ordered_paths)
class FilesystemFileSearchMiddleware(AgentMiddleware):
    """Provides Glob and Grep search over filesystem files.

    This middleware adds two tools that search through local filesystem:

    - Glob: Fast file pattern matching by file path
    - Grep: Fast content search using ripgrep or Python fallback

    Paths given to and returned by the tools are virtual: ``/`` denotes
    ``root_path``. Results that resolve outside ``root_path`` are dropped.

    Example:
        ```python
        from langchain.agents import create_agent
        from langchain.agents.middleware import (
            FilesystemClaudeTextEditorMiddleware,
            FilesystemFileSearchMiddleware,
        )

        agent = create_agent(
            model=model,
            tools=[],
            middleware=[
                FilesystemClaudeTextEditorMiddleware(root_path="/workspace"),
                FilesystemFileSearchMiddleware(root_path="/workspace"),
            ],
        )
        ```
    """

    def __init__(
        self,
        *,
        root_path: str,
        use_ripgrep: bool = True,
        max_file_size_mb: int = 10,
    ) -> None:
        """Initialize the search middleware.

        Args:
            root_path: Root directory to search.
            use_ripgrep: Whether to use ripgrep for search (default: True).
                Falls back to Python if ripgrep unavailable.
            max_file_size_mb: Maximum file size to search in MB (default: 10).
        """
        self.root_path = Path(root_path).resolve()
        self.use_ripgrep = use_ripgrep
        self.max_file_size_bytes = max_file_size_mb * 1024 * 1024

        # Create tool instances as closures that capture self
        @tool
        def glob_search(pattern: str, path: str = "/") -> str:
            """Fast file pattern matching tool that works with any codebase size.

            Supports glob patterns like **/*.js or src/**/*.ts.
            Returns matching file paths sorted by modification time.
            Use this tool when you need to find files by name patterns.

            Args:
                pattern: The glob pattern to match files against.
                path: The directory to search in. If not specified, searches from root.

            Returns:
                Newline-separated list of matching file paths, sorted by modification
                time (most recently modified first). Returns "No files found" if no
                matches.
            """
            try:
                base_full = self._validate_and_resolve_path(path)
            except ValueError:
                return "No files found"

            if not base_full.exists() or not base_full.is_dir():
                return "No files found"

            matching: list[tuple[datetime, str]] = []
            try:
                for match in base_full.glob(pattern):
                    if not match.is_file():
                        continue
                    # A pattern containing ".." can walk out of the root;
                    # such matches map to no virtual path and are dropped.
                    virtual_path = self._to_virtual_path(match)
                    if virtual_path is None:
                        continue
                    try:
                        mtime = match.stat().st_mtime
                    except OSError:
                        # File vanished or became unreadable mid-scan.
                        continue
                    modified_at = datetime.fromtimestamp(mtime, tz=timezone.utc)
                    matching.append((modified_at, virtual_path))
            except (NotImplementedError, ValueError):
                # pathlib rejects absolute and empty patterns.
                return "No files found"

            if not matching:
                return "No files found"

            # Most recently modified first, as promised to the model.
            matching.sort(key=lambda item: item[0], reverse=True)
            return "\n".join(virtual_path for _, virtual_path in matching)

        @tool
        def grep_search(
            pattern: str,
            path: str = "/",
            include: str | None = None,
            output_mode: Literal["files_with_matches", "content", "count"] = "files_with_matches",
        ) -> str:
            """Fast content search tool that works with any codebase size.

            Searches file contents using regular expressions. Supports full regex
            syntax and filters files by pattern with the include parameter.

            Args:
                pattern: The regular expression pattern to search for in file contents.
                path: The directory to search in. If not specified, searches from root.
                include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
                output_mode: Output format:
                    - "files_with_matches": Only file paths containing matches (default)
                    - "content": Matching lines with file:line:content format
                    - "count": Count of matches per file

            Returns:
                Search results formatted according to output_mode. Returns "No matches
                found" if no results.
            """
            # Compile regex pattern (for validation)
            try:
                re.compile(pattern)
            except re.error as e:
                return f"Invalid regex pattern: {e}"

            if include and not _is_valid_include_pattern(include):
                return "Invalid include pattern"

            # Try ripgrep first if enabled
            results = None
            if self.use_ripgrep:
                with suppress(
                    FileNotFoundError,
                    subprocess.CalledProcessError,
                    subprocess.TimeoutExpired,
                ):
                    results = self._ripgrep_search(pattern, path, include)

            # Python fallback if ripgrep failed or is disabled
            if results is None:
                results = self._python_search(pattern, path, include)

            if not results:
                return "No matches found"

            return self._format_grep_results(results, output_mode)

        self.glob_search = glob_search
        self.grep_search = grep_search
        self.tools = [glob_search, grep_search]

    def _validate_and_resolve_path(self, path: str) -> Path:
        """Validate and resolve a virtual path to filesystem path.

        Args:
            path: Virtual path; a leading ``/`` is added when missing.

        Returns:
            The resolved filesystem path, located inside ``root_path``.

        Raises:
            ValueError: If the path contains ``..`` or ``~``, or resolves
                outside ``root_path``.
        """
        if not path.startswith("/"):
            path = "/" + path

        if ".." in path or "~" in path:
            msg = "Path traversal not allowed"
            raise ValueError(msg)

        full_path = (self.root_path / path.lstrip("/")).resolve()

        # resolve() follows symlinks, so re-check containment afterwards.
        try:
            full_path.relative_to(self.root_path)
        except ValueError:
            msg = f"Path outside root directory: {path}"
            raise ValueError(msg) from None

        return full_path

    def _to_virtual_path(self, real_path: Path) -> str | None:
        """Map a filesystem path to its virtual path, or ``None`` if outside root.

        Containment is checked on the resolved path so ``..`` segments and
        symlinks cannot expose files outside ``root_path``; the returned
        virtual path is built from the unresolved path.
        """
        try:
            real_path.resolve().relative_to(self.root_path)
            relative = real_path.relative_to(self.root_path)
        except (OSError, RuntimeError, ValueError):
            return None
        return "/" + relative.as_posix()

    def _ripgrep_search(
        self, pattern: str, base_path: str, include: str | None
    ) -> dict[str, list[tuple[int, str]]]:
        """Search using ripgrep subprocess.

        Args:
            pattern: Regular expression handed to ``rg``.
            base_path: Virtual path to search under.
            include: Optional glob forwarded through ``--glob``.

        Returns:
            Mapping of virtual file path to ``(line_number, line)`` matches.
            Empty when ``base_path`` is invalid or does not exist. If ``rg``
            cannot be started or times out, the Python search result is
            returned instead.
        """
        try:
            base_full = self._validate_and_resolve_path(base_path)
        except ValueError:
            return {}

        if not base_full.exists():
            return {}

        cmd = ["rg", "--json"]
        if include:
            cmd.extend(["--glob", include])
        # "--" stops option parsing so a pattern like "--foo" is searched
        # for rather than interpreted as a flag.
        cmd.extend(["--", pattern, str(base_full)])

        try:
            result = subprocess.run(  # noqa: S603
                cmd,
                capture_output=True,
                text=True,
                timeout=30,
                check=False,
            )
        except (subprocess.TimeoutExpired, FileNotFoundError):
            # Fallback to Python search if ripgrep unavailable or times out
            return self._python_search(pattern, base_path, include)

        results: dict[str, list[tuple[int, str]]] = {}
        for line in result.stdout.splitlines():
            try:
                data = json.loads(line)
                if data["type"] != "match":
                    continue
                payload = data["data"]
                virtual_path = self._to_virtual_path(Path(payload["path"]["text"]))
                line_num = payload["line_number"]
                line_text = payload["lines"]["text"].rstrip("\n")
            except (json.JSONDecodeError, KeyError, TypeError):
                # Skip events that are not well-formed text matches.
                continue
            if virtual_path is None:
                continue
            results.setdefault(virtual_path, []).append((line_num, line_text))

        return results

    def _python_search(
        self, pattern: str, base_path: str, include: str | None
    ) -> dict[str, list[tuple[int, str]]]:
        """Search using Python regex (fallback).

        Args:
            pattern: Regular expression; must already be known to compile.
            base_path: Virtual path of the directory (or single file) to search.
            include: Optional include filter matched against file basenames.

        Returns:
            Mapping of virtual file path to ``(line_number, line)`` matches.
            Files that are too large, unreadable, not valid UTF-8, or outside
            the root are skipped.
        """
        try:
            base_full = self._validate_and_resolve_path(base_path)
        except ValueError:
            return {}

        if not base_full.exists():
            return {}

        regex = re.compile(pattern)
        results: dict[str, list[tuple[int, str]]] = {}

        # rglob on a regular file yields nothing, so search it directly.
        candidates = [base_full] if base_full.is_file() else base_full.rglob("*")
        for file_path in candidates:
            if not file_path.is_file():
                continue

            if include and not _match_include_pattern(file_path.name, include):
                continue

            virtual_path = self._to_virtual_path(file_path)
            if virtual_path is None:
                continue

            try:
                if file_path.stat().st_size > self.max_file_size_bytes:
                    continue
                content = file_path.read_text(encoding="utf-8")
            except (UnicodeDecodeError, OSError):
                # Binary, unreadable, or vanished file: best-effort skip.
                continue

            hits = [
                (line_num, line)
                for line_num, line in enumerate(content.splitlines(), 1)
                if regex.search(line)
            ]
            if hits:
                results[virtual_path] = hits

        return results

    def _format_grep_results(
        self,
        results: dict[str, list[tuple[int, str]]],
        output_mode: str,
    ) -> str:
        """Format grep results based on output mode.

        Args:
            results: Mapping of file path to ``(line_number, line)`` matches.
            output_mode: ``"content"`` for ``file:line:content`` rows,
                ``"count"`` for ``file:count`` rows; any other value
                (including ``"files_with_matches"``) lists file paths only.

        Returns:
            Newline-joined rows, ordered by file path.
        """
        ordered_paths = sorted(results)
        if output_mode == "content":
            return "\n".join(
                f"{file_path}:{line_num}:{line}"
                for file_path in ordered_paths
                for line_num, line in results[file_path]
            )
        if output_mode == "count":
            return "\n".join(
                f"{file_path}:{len(results[file_path])}" for file_path in ordered_paths
            )
        # "files_with_matches" and any unrecognized mode.
        return "\n".join(ordered_paths)
# Names exported by ``from <this module> import *``: the two search middlewares.
__all__ = [
"FilesystemFileSearchMiddleware",
"StateFileSearchMiddleware",
]

View File

@@ -0,0 +1,276 @@
"""Unit tests for Anthropic text editor and memory tool middleware."""
import pytest
from langchain.agents.middleware.anthropic_tools import (
AnthropicToolsState,
StateClaudeMemoryMiddleware,
StateClaudeTextEditorMiddleware,
_validate_path,
)
from langchain_core.messages import ToolMessage
from langgraph.types import Command
class TestPathValidation:
    """Checks for ``_validate_path`` normalization and its security guards."""

    def test_basic_path_normalization(self) -> None:
        """Paths gain a leading slash and lose duplicate or ``.`` segments."""
        expectations = [
            ("/foo/bar", "/foo/bar"),
            ("foo/bar", "/foo/bar"),
            ("/foo//bar", "/foo/bar"),
            ("/foo/./bar", "/foo/bar"),
        ]
        for raw, normalized in expectations:
            assert _validate_path(raw) == normalized

    def test_path_traversal_blocked(self) -> None:
        """Parent-directory and home-directory references are rejected."""
        for hostile in ("/foo/../etc/passwd", "../etc/passwd", "~/.ssh/id_rsa"):
            with pytest.raises(ValueError, match="Path traversal not allowed"):
                _validate_path(hostile)

    def test_allowed_prefixes(self) -> None:
        """Only paths under an allowed prefix are accepted."""
        # Accepted: the path lives under the allowed prefix.
        accepted = _validate_path("/workspace/file.txt", allowed_prefixes=["/workspace"])
        assert accepted == "/workspace/file.txt"

        # Rejected: unrelated path, and a sibling that merely shares the prefix text.
        rejected = [
            ("/etc/passwd", ["/workspace"]),
            ("/workspacemalicious/file.txt", ["/workspace/"]),
        ]
        for candidate, prefixes in rejected:
            with pytest.raises(ValueError, match="Path must start with"):
                _validate_path(candidate, allowed_prefixes=prefixes)

    def test_memories_prefix(self) -> None:
        """The ``/memories`` prefix used by memory tools is enforced."""
        prefixes = ["/memories"]
        accepted = _validate_path("/memories/notes.txt", allowed_prefixes=prefixes)
        assert accepted == "/memories/notes.txt"

        with pytest.raises(ValueError, match="Path must start with"):
            _validate_path("/other/notes.txt", allowed_prefixes=["/memories"])
class TestTextEditorMiddleware:
    """Checks for the state-backed text editor middleware."""

    def test_middleware_initialization(self) -> None:
        """Defaults are set on construction; path prefixes are stored when given."""
        default = StateClaudeTextEditorMiddleware()
        observed = (
            default.state_schema,
            default.tool_type,
            default.tool_name,
            default.state_key,
        )
        assert observed[0] == AnthropicToolsState
        assert observed[1] == "text_editor_20250728"
        assert observed[2] == "str_replace_based_edit_tool"
        assert observed[3] == "text_editor_files"

        # Constructed with a path restriction.
        restricted = StateClaudeTextEditorMiddleware(allowed_path_prefixes=["/workspace"])
        assert restricted.allowed_prefixes == ["/workspace"]
class TestMemoryMiddleware:
    """Checks for the state-backed memory middleware."""

    def test_middleware_initialization(self) -> None:
        """Defaults are set on construction, including a non-empty system prompt."""
        memory = StateClaudeMemoryMiddleware()
        assert memory.state_schema == AnthropicToolsState
        assert (memory.tool_type, memory.tool_name) == ("memory_20250818", "memory")
        assert memory.state_key == "memory_files"
        # A default prompt must be present.
        assert memory.system_prompt

    def test_custom_system_prompt(self) -> None:
        """A caller-supplied system prompt replaces the default."""
        replacement = "Custom memory instructions"
        memory = StateClaudeMemoryMiddleware(system_prompt=replacement)
        assert memory.system_prompt == replacement
class TestFileOperations:
    """Test file operation implementations via wrap_tool_call."""

    @staticmethod
    def _stored_file(*lines: str) -> dict[str, object]:
        """Build a state file entry holding ``lines`` with fixed timestamps."""
        return {
            "content": list(lines),
            "created_at": "2025-01-01T00:00:00",
            "modified_at": "2025-01-01T00:00:00",
        }

    def test_view_operation(self) -> None:
        """Test view command execution."""
        middleware = StateClaudeTextEditorMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/test.txt": self._stored_file("line1", "line2", "line3"),
            },
        }
        args = {"command": "view", "path": "/test.txt"}

        result = middleware._handle_view(args, state, "test_id")

        assert isinstance(result, Command)
        assert result.update is not None
        messages = result.update.get("messages", [])
        assert len(messages) == 1
        assert isinstance(messages[0], ToolMessage)
        assert messages[0].content == "1|line1\n2|line2\n3|line3"
        assert messages[0].tool_call_id == "test_id"

    def test_create_operation(self) -> None:
        """Test create command execution."""
        middleware = StateClaudeTextEditorMiddleware()
        state: AnthropicToolsState = {"messages": []}
        args = {"command": "create", "path": "/test.txt", "file_text": "line1\nline2"}

        result = middleware._handle_create(args, state, "test_id")

        assert isinstance(result, Command)
        assert result.update is not None
        files = result.update.get("text_editor_files", {})
        assert "/test.txt" in files
        assert files["/test.txt"]["content"] == ["line1", "line2"]

    def test_path_prefix_enforcement(self) -> None:
        """Test that path prefixes are enforced."""
        middleware = StateClaudeTextEditorMiddleware(allowed_path_prefixes=["/workspace"])
        state: AnthropicToolsState = {"messages": []}
        # /etc/passwd is outside the allowed prefix and must be rejected.
        args = {"command": "create", "path": "/etc/passwd", "file_text": "test"}

        # pytest.raises fails the test if nothing is raised; unlike
        # ``assert False`` it is not stripped when running under ``-O``.
        with pytest.raises(ValueError, match="Path must start with"):
            middleware._handle_create(args, state, "test_id")

    def test_memories_prefix_enforcement(self) -> None:
        """Test that /memories prefix is enforced for memory middleware."""
        middleware = StateClaudeMemoryMiddleware()
        state: AnthropicToolsState = {"messages": []}
        # /other/path.txt is outside /memories and must be rejected.
        args = {"command": "create", "path": "/other/path.txt", "file_text": "test"}

        with pytest.raises(ValueError, match="/memories"):
            middleware._handle_create(args, state, "test_id")

    def test_str_replace_operation(self) -> None:
        """Test str_replace command execution."""
        middleware = StateClaudeTextEditorMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/test.txt": self._stored_file("Hello world", "Goodbye world"),
            },
        }
        args = {
            "command": "str_replace",
            "path": "/test.txt",
            "old_str": "world",
            "new_str": "universe",
        }

        result = middleware._handle_str_replace(args, state, "test_id")

        assert isinstance(result, Command)
        files = result.update.get("text_editor_files", {})
        # Should only replace first occurrence
        assert files["/test.txt"]["content"] == ["Hello universe", "Goodbye world"]

    def test_insert_operation(self) -> None:
        """Test insert command execution."""
        middleware = StateClaudeTextEditorMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/test.txt": self._stored_file("line1", "line2"),
            },
        }
        args = {
            "command": "insert",
            "path": "/test.txt",
            "insert_line": 0,
            "new_str": "inserted",
        }

        result = middleware._handle_insert(args, state, "test_id")

        assert isinstance(result, Command)
        files = result.update.get("text_editor_files", {})
        assert files["/test.txt"]["content"] == ["inserted", "line1", "line2"]

    def test_delete_operation(self) -> None:
        """Test delete command execution (memory only)."""
        middleware = StateClaudeMemoryMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "memory_files": {
                "/memories/test.txt": self._stored_file("line1"),
            },
        }
        args = {"command": "delete", "path": "/memories/test.txt"}

        result = middleware._handle_delete(args, state, "test_id")

        assert isinstance(result, Command)
        files = result.update.get("memory_files", {})
        # Deleted files are marked as None in state
        assert files.get("/memories/test.txt") is None

    def test_rename_operation(self) -> None:
        """Test rename command execution (memory only)."""
        middleware = StateClaudeMemoryMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "memory_files": {
                "/memories/old.txt": self._stored_file("line1"),
            },
        }
        args = {
            "command": "rename",
            "old_path": "/memories/old.txt",
            "new_path": "/memories/new.txt",
        }

        result = middleware._handle_rename(args, state, "test_id")

        assert isinstance(result, Command)
        files = result.update.get("memory_files", {})
        # Old path is marked as None (deleted)
        assert files.get("/memories/old.txt") is None
        # New path has the file data
        assert files.get("/memories/new.txt") is not None
        assert files["/memories/new.txt"]["content"] == ["line1"]

View File

@@ -0,0 +1,530 @@
"""Unit tests for file search middleware."""
from pathlib import Path
from typing import Any
import pytest
from langchain.agents.middleware.anthropic_tools import AnthropicToolsState
from langchain.agents.middleware.file_search import (
FilesystemFileSearchMiddleware,
StateFileSearchMiddleware,
)
from langchain_core.messages import ToolMessage
class TestSearchMiddlewareInitialization:
    """Construction-time checks for the state search middleware."""

    def test_middleware_initialization(self) -> None:
        """A default instance uses the shared state schema and text editor key."""
        search = StateFileSearchMiddleware()
        assert search.state_schema == AnthropicToolsState
        assert search.state_key == "text_editor_files"

    def test_custom_state_key(self) -> None:
        """The state key passed to the constructor is stored unchanged."""
        chosen_key = "memory_files"
        search = StateFileSearchMiddleware(state_key=chosen_key)
        assert search.state_key == "memory_files"
class TestGlobSearch:
    """Checks for the state-backed Glob tool."""

    @staticmethod
    def _entry(*lines: str, modified_at: str = "2025-01-01T00:00:00") -> dict[str, Any]:
        """Build a state file entry with the given content lines."""
        return {
            "content": list(lines),
            "created_at": "2025-01-01T00:00:00",
            "modified_at": modified_at,
        }

    def test_glob_basic_pattern(self) -> None:
        """A plain ``*.py`` pattern selects Python files only."""
        middleware = StateFileSearchMiddleware()
        test_state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": self._entry("print('hello')"),
                "/src/utils.py": self._entry("def helper(): pass"),
                "/README.md": self._entry("# Project"),
            },
        }

        # The tool function is invoked directly; in real usage state is injected.
        result = middleware.glob_search.func(pattern="*.py", state=test_state)

        assert isinstance(result, str)
        for expected in ("/src/main.py", "/src/utils.py"):
            assert expected in result
        assert "/README.md" not in result

    def test_glob_recursive_pattern(self) -> None:
        """``**/*.py`` finds Python files at every depth."""
        middleware = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": self._entry(),
                "/src/utils/helper.py": self._entry(),
                "/tests/test_main.py": self._entry(),
            },
        }

        result = middleware.glob_search.func(pattern="**/*.py", state=state)

        assert isinstance(result, str)
        lines = result.split("\n")
        assert len(lines) == 3
        for line in lines:
            assert ".py" in line

    def test_glob_with_base_path(self) -> None:
        """Restricting ``path`` excludes files outside that directory."""
        middleware = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": self._entry(),
                "/tests/test.py": self._entry(),
            },
        }

        result = middleware.glob_search.func(pattern="**/*.py", path="/src", state=state)

        assert isinstance(result, str)
        assert "/src/main.py" in result
        assert "/tests/test.py" not in result

    def test_glob_no_matches(self) -> None:
        """A pattern matching nothing yields the sentinel message."""
        middleware = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {"/src/main.py": self._entry()},
        }

        result = middleware.glob_search.func(pattern="*.ts", state=state)

        assert isinstance(result, str)
        assert result == "No files found"

    def test_glob_sorts_by_modified_time(self) -> None:
        """Results are ordered newest first by ``modified_at``."""
        middleware = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/old.py": self._entry(modified_at="2025-01-01T00:00:00"),
                "/new.py": self._entry(modified_at="2025-01-02T00:00:00"),
            },
        }

        result = middleware.glob_search.func(pattern="*.py", state=state)

        lines = result.split("\n")
        # Newest file comes first.
        assert lines[0] == "/new.py"
        assert lines[1] == "/old.py"
class TestGrepSearch:
    """Checks for the state-backed Grep tool."""

    @staticmethod
    def _entry(*lines: str) -> dict[str, Any]:
        """Build a state file entry with the given content lines."""
        stamp = "2025-01-01T00:00:00"
        return {"content": list(lines), "created_at": stamp, "modified_at": stamp}

    def test_grep_files_with_matches_mode(self) -> None:
        """The default output mode lists the files that contain a match."""
        middleware = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": self._entry("def foo():", " pass"),
                "/src/utils.py": self._entry("def bar():", " return None"),
                "/README.md": self._entry("# Documentation", "No code here"),
            },
        }

        result = middleware.grep_search.func(pattern=r"def \w+\(\):", state=state)

        assert isinstance(result, str)
        for expected in ("/src/main.py", "/src/utils.py"):
            assert expected in result
        assert "/README.md" not in result
        # Should only have file paths, not line content

    def test_grep_invalid_include_pattern(self) -> None:
        """An unbalanced brace in ``include`` produces the error message."""
        middleware = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {"/src/main.py": self._entry("def foo():")},
        }

        result = middleware.grep_search.func(pattern=r"def", include="*.{py", state=state)

        assert result == "Invalid include pattern"
class TestFilesystemGrepSearch:
    """Grep search tests.

    The first two tests use ``FilesystemFileSearchMiddleware`` on a temporary
    directory; the remaining tests use ``StateFileSearchMiddleware`` to check
    output modes, filters, and error handling.
    """

    # Timestamp shared by every state-backed fixture file in this class.
    _STAMP = "2025-01-01T00:00:00"

    def test_grep_invalid_include_pattern(self, tmp_path: Path) -> None:
        """An unparseable include glob produces the fixed error string."""
        sample = tmp_path / "example.py"
        sample.write_text("print('hello')\n", encoding="utf-8")
        search = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)

        outcome = search.grep_search.func(pattern="print", include="*.{py")

        assert outcome == "Invalid include pattern"

    def test_ripgrep_command_uses_literal_pattern(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """The search pattern is placed right after ``--`` in the ripgrep argv.

        A pattern that looks like an option (``--pattern``) must therefore not
        be parsed as one.
        """
        (tmp_path / "example.py").write_text("print('hello')\n", encoding="utf-8")
        search = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=True)
        recorded: dict[str, list[str]] = {}

        class _StubCompleted:
            # Only ``stdout`` is provided on the stand-in result object.
            stdout = ""

        def _record_run(*args: Any, **kwargs: Any) -> _StubCompleted:
            # Capture the command line instead of launching a process.
            recorded["cmd"] = args[0]
            return _StubCompleted()

        run_target = "langchain.agents.middleware.file_search.subprocess.run"
        monkeypatch.setattr(run_target, _record_run)

        search._ripgrep_search("--pattern", "/", None)

        assert "cmd" in recorded
        argv = recorded["cmd"]
        assert argv[:2] == ["rg", "--json"]
        assert "--" in argv
        assert argv[argv.index("--") + 1] == "--pattern"

    # NOTE(review): the tests below exercise StateFileSearchMiddleware rather
    # than the filesystem backend -- confirm whether they were meant to live in
    # TestGrepSearch.

    def test_grep_content_mode(self) -> None:
        """Content mode emits one ``path:line_number:text`` row per match."""
        search = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": {
                    "content": ["def foo():", " pass", "def bar():"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
            },
        }

        output = search.grep_search.func(
            pattern=r"def \w+\(\):", output_mode="content", state=state
        )

        assert isinstance(output, str)
        # Exactly two rows: the matches on lines 1 and 3, in that order.
        assert output.split("\n") == [
            "/src/main.py:1:def foo():",
            "/src/main.py:3:def bar():",
        ]

    def test_grep_count_mode(self) -> None:
        """Count mode emits a ``path:count`` row for each matching file."""
        search = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": {
                    "content": ["TODO: fix this", "print('hello')", "TODO: add tests"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
                "/src/utils.py": {
                    "content": ["TODO: implement"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
            },
        }

        output = search.grep_search.func(pattern=r"TODO", output_mode="count", state=state)

        assert isinstance(output, str)
        rows = set(output.split("\n"))
        assert {"/src/main.py:2", "/src/utils.py:1"} <= rows

    def test_grep_with_include_filter(self) -> None:
        """An ``include`` glob limits the search to matching file names."""
        search = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": {
                    "content": ["import os"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
                "/src/main.ts": {
                    "content": ["import os from 'os'"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
            },
        }

        output = search.grep_search.func(pattern="import", include="*.py", state=state)

        assert isinstance(output, str)
        # Both files contain "import"; only the .py file passes the filter.
        assert "/src/main.py" in output
        assert "/src/main.ts" not in output

    def test_grep_with_brace_expansion_filter(self) -> None:
        """Brace alternatives in ``include`` select each listed extension."""
        search = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.ts": {
                    "content": ["const x = 1"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
                "/src/App.tsx": {
                    "content": ["const y = 2"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
                "/src/main.py": {
                    "content": ["z = 3"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
            },
        }

        output = search.grep_search.func(pattern="const", include="*.{ts,tsx}", state=state)

        assert isinstance(output, str)
        for expected_path in ("/src/main.ts", "/src/App.tsx"):
            assert expected_path in output
        assert "/src/main.py" not in output

    def test_grep_with_base_path(self) -> None:
        """A ``path`` argument restricts the search to files under it."""
        search = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": {
                    "content": ["import foo"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
                "/tests/test.py": {
                    "content": ["import foo"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
            },
        }

        output = search.grep_search.func(pattern="import", path="/src", state=state)

        assert isinstance(output, str)
        # Identical content in both files; only the one under /src is in scope.
        assert "/src/main.py" in output
        assert "/tests/test.py" not in output

    def test_grep_no_matches(self) -> None:
        """A pattern that matches nothing yields the fixed no-match message."""
        search = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": {
                    "content": ["print('hello')"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
            },
        }

        output = search.grep_search.func(pattern=r"TODO", state=state)

        assert isinstance(output, str)
        assert output == "No matches found"

    def test_grep_invalid_regex(self) -> None:
        """A malformed regex is reported as an error string."""
        search = StateFileSearchMiddleware()
        # No files are needed for this case.
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {},
        }

        output = search.grep_search.func(pattern=r"[unclosed", state=state)

        assert isinstance(output, str)
        assert "Invalid regex pattern" in output
class TestSearchWithDifferentBackends:
    """Check which state key the search tools read their files from."""

    # Timestamp shared by every fixture file in this class.
    _STAMP = "2025-01-01T00:00:00"

    def test_glob_default_backend(self) -> None:
        """Glob with a default-constructed middleware reads ``text_editor_files``."""
        search = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": {
                    "content": [],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
            },
            "memory_files": {
                "/memories/notes.txt": {
                    "content": [],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
            },
        }

        output = search.glob_search.func(pattern="**/*", state=state)

        assert isinstance(output, str)
        assert "/src/main.py" in output
        # Entries under memory_files must be ignored by the default backend.
        assert "/memories/notes.txt" not in output

    def test_grep_default_backend(self) -> None:
        """Grep with a default-constructed middleware reads ``text_editor_files``."""
        search = StateFileSearchMiddleware()
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": {
                    "content": ["TODO: implement"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
            },
            "memory_files": {
                "/memories/tasks.txt": {
                    "content": ["TODO: review"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
            },
        }

        output = search.grep_search.func(pattern=r"TODO", state=state)

        assert isinstance(output, str)
        assert "/src/main.py" in output
        # Both stores contain "TODO"; only text_editor_files may be reported.
        assert "/memories/tasks.txt" not in output

    def test_search_with_single_store(self) -> None:
        """An explicit ``state_key`` limits the search to that one store."""
        search = StateFileSearchMiddleware(state_key="text_editor_files")
        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": {
                    "content": ["code"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
            },
            "memory_files": {
                "/memories/notes.txt": {
                    "content": ["notes"],
                    "created_at": self._STAMP,
                    "modified_at": self._STAMP,
                },
            },
        }

        # ".*" matches every line, so any file that is searched would be reported.
        output = search.grep_search.func(pattern=r".*", state=state)

        assert isinstance(output, str)
        assert "/src/main.py" in output
        assert "/memories/notes.txt" not in output