mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-03 15:55:44 +00:00
Compare commits
2 Commits
mdrxy/mode
...
nc/9oct/fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b6eeaebde | ||
|
|
672b8eceb8 |
@@ -1,9 +1,16 @@
|
||||
"""Middleware plugins for agents."""
|
||||
|
||||
from .anthropic_tools import (
|
||||
FilesystemClaudeMemoryMiddleware,
|
||||
FilesystemClaudeTextEditorMiddleware,
|
||||
StateClaudeMemoryMiddleware,
|
||||
StateClaudeTextEditorMiddleware,
|
||||
)
|
||||
from .context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
ContextEditingMiddleware,
|
||||
)
|
||||
from .file_search import FilesystemFileSearchMiddleware, StateFileSearchMiddleware
|
||||
from .human_in_the_loop import (
|
||||
HumanInTheLoopMiddleware,
|
||||
InterruptOnConfig,
|
||||
@@ -36,6 +43,9 @@ __all__ = [
|
||||
"AgentState",
|
||||
"ClearToolUsesEdit",
|
||||
"ContextEditingMiddleware",
|
||||
"FilesystemClaudeMemoryMiddleware",
|
||||
"FilesystemClaudeTextEditorMiddleware",
|
||||
"FilesystemFileSearchMiddleware",
|
||||
"HumanInTheLoopMiddleware",
|
||||
"InterruptOnConfig",
|
||||
"LLMToolEmulator",
|
||||
@@ -46,6 +56,9 @@ __all__ = [
|
||||
"ModelResponse",
|
||||
"PIIDetectionError",
|
||||
"PIIMiddleware",
|
||||
"StateClaudeMemoryMiddleware",
|
||||
"StateClaudeTextEditorMiddleware",
|
||||
"StateFileSearchMiddleware",
|
||||
"SummarizationMiddleware",
|
||||
"TodoListMiddleware",
|
||||
"ToolCallLimitMiddleware",
|
||||
|
||||
1032
libs/langchain_v1/langchain/agents/middleware/anthropic_tools.py
Normal file
1032
libs/langchain_v1/langchain/agents/middleware/anthropic_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
588
libs/langchain_v1/langchain/agents/middleware/file_search.py
Normal file
588
libs/langchain_v1/langchain/agents/middleware/file_search.py
Normal file
@@ -0,0 +1,588 @@
|
||||
"""File search middleware for Anthropic text editor and memory tools.
|
||||
|
||||
This module provides Glob and Grep search tools that operate on files stored
|
||||
in state or filesystem.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import fnmatch
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
from contextlib import suppress
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Annotated, Any, Literal, cast
|
||||
|
||||
from langchain_core.tools import InjectedToolArg, tool
|
||||
|
||||
from langchain.agents.middleware.anthropic_tools import AnthropicToolsState
|
||||
from langchain.agents.middleware.types import AgentMiddleware
|
||||
|
||||
|
||||
def _expand_include_patterns(pattern: str) -> list[str] | None:
|
||||
"""Expand brace patterns like ``*.{py,pyi}`` into a list of globs."""
|
||||
if "}" in pattern and "{" not in pattern:
|
||||
return None
|
||||
|
||||
expanded: list[str] = []
|
||||
|
||||
def _expand(current: str) -> None:
|
||||
start = current.find("{")
|
||||
if start == -1:
|
||||
expanded.append(current)
|
||||
return
|
||||
|
||||
end = current.find("}", start)
|
||||
if end == -1:
|
||||
raise ValueError
|
||||
|
||||
prefix = current[:start]
|
||||
suffix = current[end + 1 :]
|
||||
inner = current[start + 1 : end]
|
||||
if not inner:
|
||||
raise ValueError
|
||||
|
||||
for option in inner.split(","):
|
||||
_expand(prefix + option + suffix)
|
||||
|
||||
try:
|
||||
_expand(pattern)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return expanded
|
||||
|
||||
|
||||
def _is_valid_include_pattern(pattern: str) -> bool:
    """Return ``True`` when *pattern* is a usable include filter.

    A pattern is rejected when it is empty, contains NUL/newline control
    characters, has malformed brace groups, or translates to an invalid
    regular expression.
    """
    if not pattern:
        return False

    # Control characters are never legitimate in a glob and could confuse
    # downstream tooling (e.g. when the pattern is passed to ripgrep).
    for forbidden in ("\x00", "\n", "\r"):
        if forbidden in pattern:
            return False

    candidates = _expand_include_patterns(pattern)
    if candidates is None:
        return False

    # Every expanded glob must compile once translated to a regex.
    try:
        for glob in candidates:
            re.compile(fnmatch.translate(glob))
    except re.error:
        return False

    return True
|
||||
|
||||
|
||||
def _match_include_pattern(basename: str, pattern: str) -> bool:
    """Check whether a file's basename satisfies an include filter.

    The filter may contain brace alternatives (``*.{ts,tsx}``); it is
    expanded first and the basename is tested against each alternative.
    """
    candidates = _expand_include_patterns(pattern)
    if not candidates:
        # A malformed pattern (or one expanding to nothing) matches nothing.
        return False
    for glob in candidates:
        if fnmatch.fnmatch(basename, glob):
            return True
    return False
|
||||
|
||||
|
||||
class StateFileSearchMiddleware(AgentMiddleware):
    """Provides Glob and Grep search over state-based files.

    This middleware adds two tools that search through virtual files in state:
    - Glob: Fast file pattern matching by file path
    - Grep: Fast content search using regular expressions

    Example:
        ```python
        from langchain.agents import create_agent
        from langchain.agents.middleware import (
            StateTextEditorToolMiddleware,
            StateFileSearchMiddleware,
        )

        agent = create_agent(
            model=model,
            tools=[],
            middleware=[
                StateTextEditorToolMiddleware(),
                StateFileSearchMiddleware(),
            ],
        )
        ```
    """

    # Virtual files live in agent state under this schema.
    state_schema = AnthropicToolsState

    def __init__(
        self,
        *,
        state_key: str = "text_editor_files",
    ) -> None:
        """Initialize the search middleware.

        Args:
            state_key: State key to search (default: "text_editor_files").
                Use "memory_files" to search memory tool files.
        """
        self.state_key = state_key

        # Tools are defined as closures so they can capture `self.state_key`
        # without exposing it in the tool signature shown to the model.
        # NOTE(review): `state` is annotated with InjectedToolArg — presumably
        # the framework injects the live agent state at call time; the `None`
        # default is never meaningful in real usage.
        @tool
        def glob_search(  # noqa: D417
            pattern: str,
            path: str = "/",
            state: Annotated[AnthropicToolsState, InjectedToolArg] = None,  # type: ignore[assignment]
        ) -> str:
            """Fast file pattern matching tool that works with any codebase size.

            Supports glob patterns like **/*.js or src/**/*.ts.
            Returns matching file paths sorted by modification time.
            Use this tool when you need to find files by name patterns.

            Args:
                pattern: The glob pattern to match files against.
                path: The directory to search in. If not specified, searches from root.

            Returns:
                Newline-separated list of matching file paths, sorted by modification
                time (most recently modified first). Returns "No files found" if no
                matches.
            """
            # Normalize base path
            base_path = path if path.startswith("/") else "/" + path

            # Get files from state
            files = cast("dict[str, Any]", state.get(self.state_key, {}))

            # Match files
            matches = []
            for file_path, file_data in files.items():
                if file_path.startswith(base_path):
                    # Get relative path from base
                    if base_path == "/":
                        relative = file_path[1:]  # Remove leading /
                    elif file_path == base_path:
                        relative = Path(file_path).name
                    elif file_path.startswith(base_path + "/"):
                        relative = file_path[len(base_path) + 1 :]
                    else:
                        # e.g. base "/src" vs file "/srcx/a.py": prefix matched
                        # as a string but not at a path-segment boundary.
                        continue

                    # Match against pattern
                    # Handle ** pattern which requires special care
                    # PurePosixPath.match doesn't match single-level paths against **/pattern
                    is_match = PurePosixPath(relative).match(pattern)
                    if not is_match and pattern.startswith("**/"):
                        # Also try matching without the **/ prefix for files in base dir
                        is_match = PurePosixPath(relative).match(pattern[3:])

                    if is_match:
                        # modified_at is an ISO-8601 string kept alongside the
                        # file content in state; used below for ordering.
                        matches.append((file_path, file_data["modified_at"]))

            if not matches:
                return "No files found"

            # Sort by modification time (ISO-8601 strings sort chronologically)
            matches.sort(key=lambda x: x[1], reverse=True)
            # NOTE(review): `path` here shadows the `path` parameter; harmless
            # because the parameter is no longer needed at this point.
            file_paths = [path for path, _ in matches]

            return "\n".join(file_paths)

        @tool
        def grep_search(  # noqa: D417
            pattern: str,
            path: str = "/",
            include: str | None = None,
            output_mode: Literal["files_with_matches", "content", "count"] = "files_with_matches",
            state: Annotated[AnthropicToolsState, InjectedToolArg] = None,  # type: ignore[assignment]
        ) -> str:
            """Fast content search tool that works with any codebase size.

            Searches file contents using regular expressions. Supports full regex
            syntax and filters files by pattern with the include parameter.

            Args:
                pattern: The regular expression pattern to search for in file contents.
                path: The directory to search in. If not specified, searches from root.
                include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
                output_mode: Output format:
                    - "files_with_matches": Only file paths containing matches (default)
                    - "content": Matching lines with file:line:content format
                    - "count": Count of matches per file

            Returns:
                Search results formatted according to output_mode. Returns "No matches
                found" if no results.
            """
            # Normalize base path
            base_path = path if path.startswith("/") else "/" + path

            # Compile regex pattern (for validation)
            try:
                regex = re.compile(pattern)
            except re.error as e:
                # Surface the error to the model as tool output rather than raising.
                return f"Invalid regex pattern: {e}"

            if include and not _is_valid_include_pattern(include):
                return "Invalid include pattern"

            # Search files
            files = cast("dict[str, Any]", state.get(self.state_key, {}))
            results: dict[str, list[tuple[int, str]]] = {}

            for file_path, file_data in files.items():
                if not file_path.startswith(base_path):
                    continue

                # Check include filter (applies to the basename only)
                if include:
                    basename = Path(file_path).name
                    if not _match_include_pattern(basename, include):
                        continue

                # Search file content; content is stored as a list of lines.
                for line_num, line in enumerate(file_data["content"], 1):
                    if regex.search(line):
                        if file_path not in results:
                            results[file_path] = []
                        results[file_path].append((line_num, line))

            if not results:
                return "No matches found"

            # Format output based on mode
            return self._format_grep_results(results, output_mode)

        self.glob_search = glob_search
        self.grep_search = grep_search
        self.tools = [glob_search, grep_search]

    def _format_grep_results(
        self,
        results: dict[str, list[tuple[int, str]]],
        output_mode: str,
    ) -> str:
        """Format grep results based on output mode.

        Args:
            results: Mapping of file path -> list of (line number, line text).
            output_mode: One of "files_with_matches", "content", "count".

        Returns:
            Newline-joined output; file paths are emitted in sorted order.
        """
        if output_mode == "files_with_matches":
            # Just return file paths
            return "\n".join(sorted(results.keys()))

        if output_mode == "content":
            # Return file:line:content format
            lines = []
            for file_path in sorted(results.keys()):
                for line_num, line in results[file_path]:
                    lines.append(f"{file_path}:{line_num}:{line}")
            return "\n".join(lines)

        if output_mode == "count":
            # Return file:count format
            lines = []
            for file_path in sorted(results.keys()):
                count = len(results[file_path])
                lines.append(f"{file_path}:{count}")
            return "\n".join(lines)

        # Default to files_with_matches
        return "\n".join(sorted(results.keys()))
|
||||
|
||||
|
||||
class FilesystemFileSearchMiddleware(AgentMiddleware):
    """Provides Glob and Grep search over filesystem files.

    This middleware adds two tools that search through local filesystem:
    - Glob: Fast file pattern matching by file path
    - Grep: Fast content search using ripgrep or Python fallback

    Example:
        ```python
        from langchain.agents import create_agent
        from langchain.agents.middleware import (
            FilesystemTextEditorToolMiddleware,
            FilesystemFileSearchMiddleware,
        )

        agent = create_agent(
            model=model,
            tools=[],
            middleware=[
                FilesystemTextEditorToolMiddleware(root_path="/workspace"),
                FilesystemFileSearchMiddleware(root_path="/workspace"),
            ],
        )
        ```
    """

    def __init__(
        self,
        *,
        root_path: str,
        use_ripgrep: bool = True,
        max_file_size_mb: int = 10,
    ) -> None:
        """Initialize the search middleware.

        Args:
            root_path: Root directory to search.
            use_ripgrep: Whether to use ripgrep for search (default: True).
                Falls back to Python if ripgrep unavailable.
            max_file_size_mb: Maximum file size to search in MB (default: 10).
        """
        self.root_path = Path(root_path).resolve()
        self.use_ripgrep = use_ripgrep
        # NOTE: the size cap is enforced only by the Python fallback search;
        # the ripgrep path relies on ripgrep's own handling of large files.
        self.max_file_size_bytes = max_file_size_mb * 1024 * 1024

        # Create tool instances as closures that capture self
        @tool
        def glob_search(pattern: str, path: str = "/") -> str:
            """Fast file pattern matching tool that works with any codebase size.

            Supports glob patterns like **/*.js or src/**/*.ts.
            Returns matching file paths sorted by modification time.
            Use this tool when you need to find files by name patterns.

            Args:
                pattern: The glob pattern to match files against.
                path: The directory to search in. If not specified, searches from root.

            Returns:
                Newline-separated list of matching file paths, sorted by modification
                time (most recently modified first). Returns "No files found" if no
                matches.
            """
            try:
                base_full = self._validate_and_resolve_path(path)
            except ValueError:
                # Invalid/escaping paths are reported as "no results" rather
                # than raised, keeping tool output model-friendly.
                return "No files found"

            if not base_full.exists() or not base_full.is_dir():
                return "No files found"

            # Use pathlib glob
            matching: list[tuple[str, str]] = []
            for match in base_full.glob(pattern):
                if match.is_file():
                    # Convert to virtual path (rooted at "/")
                    virtual_path = "/" + str(match.relative_to(self.root_path))
                    stat = match.stat()
                    modified_at = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat()
                    matching.append((virtual_path, modified_at))

            if not matching:
                return "No files found"

            # Sort by modification time, most recent first, to honor the
            # documented contract (previously the results were emitted in
            # filesystem glob order). ISO-8601 UTC timestamps compare
            # chronologically as strings.
            matching.sort(key=lambda item: item[1], reverse=True)
            file_paths = [p for p, _ in matching]
            return "\n".join(file_paths)

        @tool
        def grep_search(
            pattern: str,
            path: str = "/",
            include: str | None = None,
            output_mode: Literal["files_with_matches", "content", "count"] = "files_with_matches",
        ) -> str:
            """Fast content search tool that works with any codebase size.

            Searches file contents using regular expressions. Supports full regex
            syntax and filters files by pattern with the include parameter.

            Args:
                pattern: The regular expression pattern to search for in file contents.
                path: The directory to search in. If not specified, searches from root.
                include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
                output_mode: Output format:
                    - "files_with_matches": Only file paths containing matches (default)
                    - "content": Matching lines with file:line:content format
                    - "count": Count of matches per file

            Returns:
                Search results formatted according to output_mode. Returns "No matches
                found" if no results.
            """
            # Compile regex pattern (for validation only; the actual matching
            # may be delegated to ripgrep)
            try:
                re.compile(pattern)
            except re.error as e:
                return f"Invalid regex pattern: {e}"

            if include and not _is_valid_include_pattern(include):
                return "Invalid include pattern"

            # Try ripgrep first if enabled
            results = None
            if self.use_ripgrep:
                with suppress(
                    FileNotFoundError,
                    subprocess.CalledProcessError,
                    subprocess.TimeoutExpired,
                ):
                    results = self._ripgrep_search(pattern, path, include)

            # Python fallback if ripgrep failed or is disabled
            if results is None:
                results = self._python_search(pattern, path, include)

            if not results:
                return "No matches found"

            # Format output based on mode
            return self._format_grep_results(results, output_mode)

        self.glob_search = glob_search
        self.grep_search = grep_search
        self.tools = [glob_search, grep_search]

    def _validate_and_resolve_path(self, path: str) -> Path:
        """Validate and resolve a virtual path to filesystem path.

        Args:
            path: Virtual path rooted at "/" (relative paths are tolerated).

        Returns:
            Resolved absolute filesystem path under ``self.root_path``.

        Raises:
            ValueError: If the path contains traversal sequences or resolves
                outside the configured root directory.
        """
        # Normalize path
        if not path.startswith("/"):
            path = "/" + path

        # Check for path traversal before touching the filesystem
        if ".." in path or "~" in path:
            msg = "Path traversal not allowed"
            raise ValueError(msg)

        # Convert virtual path to filesystem path
        relative = path.lstrip("/")
        full_path = (self.root_path / relative).resolve()

        # Ensure the resolved path is still within root (defense in depth,
        # e.g. against symlinks escaping the root)
        try:
            full_path.relative_to(self.root_path)
        except ValueError:
            msg = f"Path outside root directory: {path}"
            raise ValueError(msg) from None

        return full_path

    def _ripgrep_search(
        self, pattern: str, base_path: str, include: str | None
    ) -> dict[str, list[tuple[int, str]]]:
        """Search using ripgrep subprocess.

        Returns a mapping of virtual file path -> list of (line number, line
        text). Falls back to :meth:`_python_search` when ripgrep is missing
        or times out; returns an empty dict for invalid/missing base paths.
        """
        try:
            base_full = self._validate_and_resolve_path(base_path)
        except ValueError:
            return {}

        if not base_full.exists():
            return {}

        # Build ripgrep command; --json gives structured per-match records.
        cmd = ["rg", "--json"]

        if include:
            # ripgrep's --glob accepts the same brace syntax we validate
            cmd.extend(["--glob", include])

        # "--" terminates options so a pattern starting with "-" is safe
        cmd.extend(["--", pattern, str(base_full)])

        try:
            result = subprocess.run(  # noqa: S603
                cmd,
                capture_output=True,
                text=True,
                timeout=30,
                check=False,
            )
        except (subprocess.TimeoutExpired, FileNotFoundError):
            # Fallback to Python search if ripgrep unavailable or times out
            return self._python_search(pattern, base_path, include)

        # Parse ripgrep JSON output: one JSON object per line; only "match"
        # records carry hits.
        results: dict[str, list[tuple[int, str]]] = {}
        for line in result.stdout.splitlines():
            try:
                data = json.loads(line)
                if data["type"] == "match":
                    path = data["data"]["path"]["text"]
                    # Convert to virtual path
                    virtual_path = "/" + str(Path(path).relative_to(self.root_path))
                    line_num = data["data"]["line_number"]
                    line_text = data["data"]["lines"]["text"].rstrip("\n")

                    if virtual_path not in results:
                        results[virtual_path] = []
                    results[virtual_path].append((line_num, line_text))
            except (json.JSONDecodeError, KeyError):
                # Skip malformed records (e.g. summary lines without matches)
                continue

        return results

    def _python_search(
        self, pattern: str, base_path: str, include: str | None
    ) -> dict[str, list[tuple[int, str]]]:
        """Search using Python regex (fallback).

        Walks the directory tree under *base_path*, skipping non-files,
        oversized files, and unreadable/binary files. Returns the same
        mapping shape as :meth:`_ripgrep_search`.
        """
        try:
            base_full = self._validate_and_resolve_path(base_path)
        except ValueError:
            return {}

        if not base_full.exists():
            return {}

        regex = re.compile(pattern)
        results: dict[str, list[tuple[int, str]]] = {}

        # Walk directory tree
        for file_path in base_full.rglob("*"):
            if not file_path.is_file():
                continue

            # Check include filter (basename only)
            if include and not _match_include_pattern(file_path.name, include):
                continue

            # Skip files that are too large
            if file_path.stat().st_size > self.max_file_size_bytes:
                continue

            try:
                content = file_path.read_text()
            except (UnicodeDecodeError, PermissionError):
                # Binary or unreadable file: skip silently
                continue

            # Search content line by line
            for line_num, line in enumerate(content.splitlines(), 1):
                if regex.search(line):
                    virtual_path = "/" + str(file_path.relative_to(self.root_path))
                    if virtual_path not in results:
                        results[virtual_path] = []
                    results[virtual_path].append((line_num, line))

        return results

    def _format_grep_results(
        self,
        results: dict[str, list[tuple[int, str]]],
        output_mode: str,
    ) -> str:
        """Format grep results based on output mode.

        Args:
            results: Mapping of file path -> list of (line number, line text).
            output_mode: One of "files_with_matches", "content", "count".

        Returns:
            Newline-joined output; file paths are emitted in sorted order.
        """
        if output_mode == "files_with_matches":
            # Just return file paths
            return "\n".join(sorted(results.keys()))

        if output_mode == "content":
            # Return file:line:content format
            lines = []
            for file_path in sorted(results.keys()):
                for line_num, line in results[file_path]:
                    lines.append(f"{file_path}:{line_num}:{line}")
            return "\n".join(lines)

        if output_mode == "count":
            # Return file:count format
            lines = []
            for file_path in sorted(results.keys()):
                count = len(results[file_path])
                lines.append(f"{file_path}:{count}")
            return "\n".join(lines)

        # Default to files_with_matches
        return "\n".join(sorted(results.keys()))
|
||||
|
||||
|
||||
# Public API of this module.
__all__ = [
    "FilesystemFileSearchMiddleware",
    "StateFileSearchMiddleware",
]
|
||||
@@ -0,0 +1,276 @@
|
||||
"""Unit tests for Anthropic text editor and memory tool middleware."""
|
||||
|
||||
import pytest
|
||||
from langchain.agents.middleware.anthropic_tools import (
|
||||
AnthropicToolsState,
|
||||
StateClaudeMemoryMiddleware,
|
||||
StateClaudeTextEditorMiddleware,
|
||||
_validate_path,
|
||||
)
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
|
||||
class TestPathValidation:
    """Validation and security behavior of ``_validate_path``."""

    def test_basic_path_normalization(self) -> None:
        """Paths are normalized to canonical absolute form."""
        cases = {
            "/foo/bar": "/foo/bar",
            "foo/bar": "/foo/bar",
            "/foo//bar": "/foo/bar",
            "/foo/./bar": "/foo/bar",
        }
        for raw, expected in cases.items():
            assert _validate_path(raw) == expected

    def test_path_traversal_blocked(self) -> None:
        """Traversal attempts via ``..`` or ``~`` are rejected."""
        for malicious in ("/foo/../etc/passwd", "../etc/passwd", "~/.ssh/id_rsa"):
            with pytest.raises(ValueError, match="Path traversal not allowed"):
                _validate_path(malicious)

    def test_allowed_prefixes(self) -> None:
        """Only paths under an allowed prefix are accepted."""
        # Accepted: inside the allowed subtree
        result = _validate_path("/workspace/file.txt", allowed_prefixes=["/workspace"])
        assert result == "/workspace/file.txt"

        # Rejected: entirely outside the subtree
        with pytest.raises(ValueError, match="Path must start with"):
            _validate_path("/etc/passwd", allowed_prefixes=["/workspace"])

        # Rejected: a raw string-prefix lookalike that is not the same directory
        with pytest.raises(ValueError, match="Path must start with"):
            _validate_path("/workspacemalicious/file.txt", allowed_prefixes=["/workspace/"])

    def test_memories_prefix(self) -> None:
        """Memory tools are confined to the ``/memories`` subtree."""
        ok = _validate_path("/memories/notes.txt", allowed_prefixes=["/memories"])
        assert ok == "/memories/notes.txt"

        with pytest.raises(ValueError, match="Path must start with"):
            _validate_path("/other/notes.txt", allowed_prefixes=["/memories"])
|
||||
|
||||
|
||||
class TestTextEditorMiddleware:
    """Construction behavior of the state-backed text editor middleware."""

    def test_middleware_initialization(self) -> None:
        """Defaults and path restrictions are wired up correctly."""
        default = StateClaudeTextEditorMiddleware()
        expected_defaults = {
            "state_schema": AnthropicToolsState,
            "tool_type": "text_editor_20250728",
            "tool_name": "str_replace_based_edit_tool",
            "state_key": "text_editor_files",
        }
        for attr, value in expected_defaults.items():
            assert getattr(default, attr) == value

        # Path restrictions are stored verbatim
        restricted = StateClaudeTextEditorMiddleware(allowed_path_prefixes=["/workspace"])
        assert restricted.allowed_prefixes == ["/workspace"]
|
||||
|
||||
|
||||
class TestMemoryMiddleware:
    """Construction behavior of the state-backed memory middleware."""

    def test_middleware_initialization(self) -> None:
        """Defaults are wired up correctly, including a non-empty prompt."""
        middleware = StateClaudeMemoryMiddleware()
        assert middleware.state_schema == AnthropicToolsState
        assert middleware.tool_type == "memory_20250818"
        assert middleware.tool_name == "memory"
        assert middleware.state_key == "memory_files"
        # A default system prompt must be present
        assert middleware.system_prompt

    def test_custom_system_prompt(self) -> None:
        """A caller-supplied prompt replaces the default."""
        prompt = "Custom memory instructions"
        middleware = StateClaudeMemoryMiddleware(system_prompt=prompt)
        assert middleware.system_prompt == prompt
|
||||
|
||||
|
||||
class TestFileOperations:
    """Test file operation implementations via wrap_tool_call."""

    def test_view_operation(self) -> None:
        """Test view command execution."""
        middleware = StateClaudeTextEditorMiddleware()

        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/test.txt": {
                    "content": ["line1", "line2", "line3"],
                    "created_at": "2025-01-01T00:00:00",
                    "modified_at": "2025-01-01T00:00:00",
                }
            },
        }

        args = {"command": "view", "path": "/test.txt"}
        result = middleware._handle_view(args, state, "test_id")

        assert isinstance(result, Command)
        assert result.update is not None
        messages = result.update.get("messages", [])
        assert len(messages) == 1
        assert isinstance(messages[0], ToolMessage)
        # View output is numbered with a "N|" prefix per line
        assert messages[0].content == "1|line1\n2|line2\n3|line3"
        assert messages[0].tool_call_id == "test_id"

    def test_create_operation(self) -> None:
        """Test create command execution."""
        middleware = StateClaudeTextEditorMiddleware()

        state: AnthropicToolsState = {"messages": []}

        args = {"command": "create", "path": "/test.txt", "file_text": "line1\nline2"}
        result = middleware._handle_create(args, state, "test_id")

        assert isinstance(result, Command)
        assert result.update is not None
        files = result.update.get("text_editor_files", {})
        assert "/test.txt" in files
        # File content is stored as a list of lines, not a single string
        assert files["/test.txt"]["content"] == ["line1", "line2"]

    def test_path_prefix_enforcement(self) -> None:
        """Test that path prefixes are enforced."""
        middleware = StateClaudeTextEditorMiddleware(allowed_path_prefixes=["/workspace"])

        state: AnthropicToolsState = {"messages": []}

        # Writing outside the allowed prefix must raise.
        # (pytest.raises instead of try/assert False: `assert False` is
        # stripped under `python -O`, which would make this test pass vacuously.)
        args = {"command": "create", "path": "/etc/passwd", "file_text": "test"}
        with pytest.raises(ValueError, match="Path must start with"):
            middleware._handle_create(args, state, "test_id")

    def test_memories_prefix_enforcement(self) -> None:
        """Test that /memories prefix is enforced for memory middleware."""
        middleware = StateClaudeMemoryMiddleware()

        state: AnthropicToolsState = {"messages": []}

        # Memory files must live under /memories
        args = {"command": "create", "path": "/other/path.txt", "file_text": "test"}
        with pytest.raises(ValueError, match="/memories"):
            middleware._handle_create(args, state, "test_id")

    def test_str_replace_operation(self) -> None:
        """Test str_replace command execution."""
        middleware = StateClaudeTextEditorMiddleware()

        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/test.txt": {
                    "content": ["Hello world", "Goodbye world"],
                    "created_at": "2025-01-01T00:00:00",
                    "modified_at": "2025-01-01T00:00:00",
                }
            },
        }

        args = {
            "command": "str_replace",
            "path": "/test.txt",
            "old_str": "world",
            "new_str": "universe",
        }
        result = middleware._handle_str_replace(args, state, "test_id")

        assert isinstance(result, Command)
        files = result.update.get("text_editor_files", {})
        # Should only replace first occurrence
        assert files["/test.txt"]["content"] == ["Hello universe", "Goodbye world"]

    def test_insert_operation(self) -> None:
        """Test insert command execution."""
        middleware = StateClaudeTextEditorMiddleware()

        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/test.txt": {
                    "content": ["line1", "line2"],
                    "created_at": "2025-01-01T00:00:00",
                    "modified_at": "2025-01-01T00:00:00",
                }
            },
        }

        # insert_line=0 inserts before the first line
        args = {
            "command": "insert",
            "path": "/test.txt",
            "insert_line": 0,
            "new_str": "inserted",
        }
        result = middleware._handle_insert(args, state, "test_id")

        assert isinstance(result, Command)
        files = result.update.get("text_editor_files", {})
        assert files["/test.txt"]["content"] == ["inserted", "line1", "line2"]

    def test_delete_operation(self) -> None:
        """Test delete command execution (memory only)."""
        middleware = StateClaudeMemoryMiddleware()

        state: AnthropicToolsState = {
            "messages": [],
            "memory_files": {
                "/memories/test.txt": {
                    "content": ["line1"],
                    "created_at": "2025-01-01T00:00:00",
                    "modified_at": "2025-01-01T00:00:00",
                }
            },
        }

        args = {"command": "delete", "path": "/memories/test.txt"}
        result = middleware._handle_delete(args, state, "test_id")

        assert isinstance(result, Command)
        files = result.update.get("memory_files", {})
        # Deleted files are marked as None in state
        assert files.get("/memories/test.txt") is None

    def test_rename_operation(self) -> None:
        """Test rename command execution (memory only)."""
        middleware = StateClaudeMemoryMiddleware()

        state: AnthropicToolsState = {
            "messages": [],
            "memory_files": {
                "/memories/old.txt": {
                    "content": ["line1"],
                    "created_at": "2025-01-01T00:00:00",
                    "modified_at": "2025-01-01T00:00:00",
                }
            },
        }

        args = {
            "command": "rename",
            "old_path": "/memories/old.txt",
            "new_path": "/memories/new.txt",
        }
        result = middleware._handle_rename(args, state, "test_id")

        assert isinstance(result, Command)
        files = result.update.get("memory_files", {})
        # Old path is marked as None (deleted)
        assert files.get("/memories/old.txt") is None
        # New path has the file data
        assert files.get("/memories/new.txt") is not None
        assert files["/memories/new.txt"]["content"] == ["line1"]
|
||||
@@ -0,0 +1,530 @@
|
||||
"""Unit tests for file search middleware."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain.agents.middleware.anthropic_tools import AnthropicToolsState
|
||||
from langchain.agents.middleware.file_search import (
|
||||
FilesystemFileSearchMiddleware,
|
||||
StateFileSearchMiddleware,
|
||||
)
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
|
||||
class TestSearchMiddlewareInitialization:
    """Construction-time defaults and overrides for the search middleware."""

    def test_middleware_initialization(self) -> None:
        """Default construction binds the Anthropic state schema and editor store."""
        mw = StateFileSearchMiddleware()

        assert mw.state_schema == AnthropicToolsState
        assert mw.state_key == "text_editor_files"

    def test_custom_state_key(self) -> None:
        """A caller-supplied state key replaces the default store name."""
        mw = StateFileSearchMiddleware(state_key="memory_files")

        assert mw.state_key == "memory_files"
|
||||
|
||||
|
||||
class TestGlobSearch:
    """Glob file-pattern matching over the state-backed file store."""

    @staticmethod
    def _record(lines: list[str], modified: str = "2025-01-01T00:00:00") -> dict:
        """Build one file record with a fixed creation timestamp."""
        return {
            "content": lines,
            "created_at": "2025-01-01T00:00:00",
            "modified_at": modified,
        }

    @staticmethod
    def _state(files: dict) -> AnthropicToolsState:
        """Wrap *files* in a minimal agent state using the default editor store."""
        return {"messages": [], "text_editor_files": files}

    def test_glob_basic_pattern(self) -> None:
        """A plain ``*.py`` glob matches Python files and skips others."""
        mw = StateFileSearchMiddleware()
        state = self._state(
            {
                "/src/main.py": self._record(["print('hello')"]),
                "/src/utils.py": self._record(["def helper(): pass"]),
                "/README.md": self._record(["# Project"]),
            }
        )

        # Call the tool function directly (state is injected in real usage).
        result = mw.glob_search.func(pattern="*.py", state=state)

        assert isinstance(result, str)
        assert "/src/main.py" in result
        assert "/src/utils.py" in result
        assert "/README.md" not in result

    def test_glob_recursive_pattern(self) -> None:
        """``**/*.py`` reaches matching files at any directory depth."""
        mw = StateFileSearchMiddleware()
        state = self._state(
            {
                "/src/main.py": self._record([]),
                "/src/utils/helper.py": self._record([]),
                "/tests/test_main.py": self._record([]),
            }
        )

        result = mw.glob_search.func(pattern="**/*.py", state=state)

        assert isinstance(result, str)
        matched = result.split("\n")
        assert len(matched) == 3
        assert all(".py" in entry for entry in matched)

    def test_glob_with_base_path(self) -> None:
        """A base ``path`` restricts matches to that subtree."""
        mw = StateFileSearchMiddleware()
        state = self._state(
            {
                "/src/main.py": self._record([]),
                "/tests/test.py": self._record([]),
            }
        )

        result = mw.glob_search.func(pattern="**/*.py", path="/src", state=state)

        assert isinstance(result, str)
        assert "/src/main.py" in result
        assert "/tests/test.py" not in result

    def test_glob_no_matches(self) -> None:
        """A pattern matching nothing yields the sentinel message."""
        mw = StateFileSearchMiddleware()
        state = self._state({"/src/main.py": self._record([])})

        result = mw.glob_search.func(pattern="*.ts", state=state)

        assert isinstance(result, str)
        assert result == "No files found"

    def test_glob_sorts_by_modified_time(self) -> None:
        """Results are ordered most-recently-modified first."""
        mw = StateFileSearchMiddleware()
        state = self._state(
            {
                "/old.py": self._record([]),
                "/new.py": self._record([], modified="2025-01-02T00:00:00"),
            }
        )

        result = mw.glob_search.func(pattern="*.py", state=state)

        ordered = result.split("\n")
        # Most recent first.
        assert ordered[0] == "/new.py"
        assert ordered[1] == "/old.py"
|
||||
|
||||
|
||||
class TestGrepSearch:
    """Grep content search over the state-backed file store."""

    def test_grep_files_with_matches_mode(self) -> None:
        """Default ``files_with_matches`` mode returns matching paths only.

        Fix: the original test ended with the comment "Should only have file
        paths, not line content" but never asserted it — the promised check
        is now made explicit by verifying no matched line text leaks out.
        """
        middleware = StateFileSearchMiddleware()

        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": {
                    "content": ["def foo():", " pass"],
                    "created_at": "2025-01-01T00:00:00",
                    "modified_at": "2025-01-01T00:00:00",
                },
                "/src/utils.py": {
                    "content": ["def bar():", " return None"],
                    "created_at": "2025-01-01T00:00:00",
                    "modified_at": "2025-01-01T00:00:00",
                },
                "/README.md": {
                    "content": ["# Documentation", "No code here"],
                    "created_at": "2025-01-01T00:00:00",
                    "modified_at": "2025-01-01T00:00:00",
                },
            },
        }

        result = middleware.grep_search.func(pattern=r"def \w+\(\):", state=state)

        assert isinstance(result, str)
        assert "/src/main.py" in result
        assert "/src/utils.py" in result
        assert "/README.md" not in result
        # Output must contain only file paths, never the matched line content.
        assert "def foo" not in result
        assert "def bar" not in result

    def test_grep_invalid_include_pattern(self) -> None:
        """An unparseable include glob (unclosed brace) yields an error string."""
        middleware = StateFileSearchMiddleware()

        state: AnthropicToolsState = {
            "messages": [],
            "text_editor_files": {
                "/src/main.py": {
                    "content": ["def foo():"],
                    "created_at": "2025-01-01T00:00:00",
                    "modified_at": "2025-01-01T00:00:00",
                }
            },
        }

        result = middleware.grep_search.func(pattern=r"def", include="*.{py", state=state)

        assert result == "Invalid include pattern"
|
||||
|
||||
|
||||
class TestFilesystemGrepSearch:
    """Grep search against the real filesystem, plus state-store output modes."""

    @staticmethod
    def _record(lines: list[str]) -> dict:
        """Build one state-store file record with fixed timestamps."""
        return {
            "content": lines,
            "created_at": "2025-01-01T00:00:00",
            "modified_at": "2025-01-01T00:00:00",
        }

    @staticmethod
    def _state(files: dict) -> AnthropicToolsState:
        """Wrap *files* in a minimal agent state using the default editor store."""
        return {"messages": [], "text_editor_files": files}

    def test_grep_invalid_include_pattern(self, tmp_path: Path) -> None:
        """An unparseable include glob yields an error string."""
        (tmp_path / "example.py").write_text("print('hello')\n", encoding="utf-8")
        search = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)

        outcome = search.grep_search.func(pattern="print", include="*.{py")

        assert outcome == "Invalid include pattern"

    def test_ripgrep_command_uses_literal_pattern(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """The pattern must come after ``--`` so rg never parses it as a flag."""
        (tmp_path / "example.py").write_text("print('hello')\n", encoding="utf-8")
        search = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=True)

        recorded: dict[str, list[str]] = {}

        class _FakeCompleted:
            stdout = ""

        def _spy_run(*args: Any, **kwargs: Any) -> _FakeCompleted:
            # Capture the argv handed to subprocess.run for inspection.
            recorded["cmd"] = args[0]
            return _FakeCompleted()

        monkeypatch.setattr("langchain.agents.middleware.file_search.subprocess.run", _spy_run)

        search._ripgrep_search("--pattern", "/", None)

        assert "cmd" in recorded
        argv = recorded["cmd"]
        assert argv[:2] == ["rg", "--json"]
        assert "--" in argv
        assert argv[argv.index("--") + 1] == "--pattern"

    def test_grep_content_mode(self) -> None:
        """``content`` mode emits ``path:line_number:line`` entries."""
        search = StateFileSearchMiddleware()
        state = self._state(
            {"/src/main.py": self._record(["def foo():", " pass", "def bar():"])}
        )

        outcome = search.grep_search.func(
            pattern=r"def \w+\(\):", output_mode="content", state=state
        )

        assert isinstance(outcome, str)
        entries = outcome.split("\n")
        assert len(entries) == 2
        assert entries[0] == "/src/main.py:1:def foo():"
        assert entries[1] == "/src/main.py:3:def bar():"

    def test_grep_count_mode(self) -> None:
        """``count`` mode emits per-file match counts as ``path:count``."""
        search = StateFileSearchMiddleware()
        state = self._state(
            {
                "/src/main.py": self._record(
                    ["TODO: fix this", "print('hello')", "TODO: add tests"]
                ),
                "/src/utils.py": self._record(["TODO: implement"]),
            }
        )

        outcome = search.grep_search.func(pattern=r"TODO", output_mode="count", state=state)

        assert isinstance(outcome, str)
        entries = outcome.split("\n")
        assert "/src/main.py:2" in entries
        assert "/src/utils.py:1" in entries

    def test_grep_with_include_filter(self) -> None:
        """An include glob restricts which files are grepped."""
        search = StateFileSearchMiddleware()
        state = self._state(
            {
                "/src/main.py": self._record(["import os"]),
                "/src/main.ts": self._record(["import os from 'os'"]),
            }
        )

        outcome = search.grep_search.func(pattern="import", include="*.py", state=state)

        assert isinstance(outcome, str)
        assert "/src/main.py" in outcome
        assert "/src/main.ts" not in outcome

    def test_grep_with_brace_expansion_filter(self) -> None:
        """Brace expansion in the include glob matches each alternative."""
        search = StateFileSearchMiddleware()
        state = self._state(
            {
                "/src/main.ts": self._record(["const x = 1"]),
                "/src/App.tsx": self._record(["const y = 2"]),
                "/src/main.py": self._record(["z = 3"]),
            }
        )

        outcome = search.grep_search.func(pattern="const", include="*.{ts,tsx}", state=state)

        assert isinstance(outcome, str)
        assert "/src/main.ts" in outcome
        assert "/src/App.tsx" in outcome
        assert "/src/main.py" not in outcome

    def test_grep_with_base_path(self) -> None:
        """A base ``path`` restricts grep to that subtree."""
        search = StateFileSearchMiddleware()
        state = self._state(
            {
                "/src/main.py": self._record(["import foo"]),
                "/tests/test.py": self._record(["import foo"]),
            }
        )

        outcome = search.grep_search.func(pattern="import", path="/src", state=state)

        assert isinstance(outcome, str)
        assert "/src/main.py" in outcome
        assert "/tests/test.py" not in outcome

    def test_grep_no_matches(self) -> None:
        """A pattern matching no content yields the sentinel message."""
        search = StateFileSearchMiddleware()
        state = self._state({"/src/main.py": self._record(["print('hello')"])})

        outcome = search.grep_search.func(pattern=r"TODO", state=state)

        assert isinstance(outcome, str)
        assert outcome == "No matches found"

    def test_grep_invalid_regex(self) -> None:
        """A malformed regex yields an explanatory error string."""
        search = StateFileSearchMiddleware()
        state: AnthropicToolsState = {"messages": [], "text_editor_files": {}}

        outcome = search.grep_search.func(pattern=r"[unclosed", state=state)

        assert isinstance(outcome, str)
        assert "Invalid regex pattern" in outcome
|
||||
|
||||
|
||||
class TestSearchWithDifferentBackends:
    """Searches honor the configured state key and ignore other stores."""

    @staticmethod
    def _record(lines: list[str]) -> dict:
        """Build one file record with fixed timestamps."""
        return {
            "content": lines,
            "created_at": "2025-01-01T00:00:00",
            "modified_at": "2025-01-01T00:00:00",
        }

    def _dual_store_state(self, editor: dict, memory: dict) -> AnthropicToolsState:
        """Build a state holding both an editor store and a memory store."""
        return {
            "messages": [],
            "text_editor_files": editor,
            "memory_files": memory,
        }

    def test_glob_default_backend(self) -> None:
        """Glob searches only text_editor_files under the default configuration."""
        mw = StateFileSearchMiddleware()
        state = self._dual_store_state(
            {"/src/main.py": self._record([])},
            {"/memories/notes.txt": self._record([])},
        )

        result = mw.glob_search.func(pattern="**/*", state=state)

        assert isinstance(result, str)
        assert "/src/main.py" in result
        # memory_files is a separate store, so its entries stay invisible.
        assert "/memories/notes.txt" not in result

    def test_grep_default_backend(self) -> None:
        """Grep searches only text_editor_files under the default configuration."""
        mw = StateFileSearchMiddleware()
        state = self._dual_store_state(
            {"/src/main.py": self._record(["TODO: implement"])},
            {"/memories/tasks.txt": self._record(["TODO: review"])},
        )

        result = mw.grep_search.func(pattern=r"TODO", state=state)

        assert isinstance(result, str)
        assert "/src/main.py" in result
        # memory_files is a separate store, so its entries stay invisible.
        assert "/memories/tasks.txt" not in result

    def test_search_with_single_store(self) -> None:
        """An explicit state key scopes the search to exactly that store."""
        mw = StateFileSearchMiddleware(state_key="text_editor_files")
        state = self._dual_store_state(
            {"/src/main.py": self._record(["code"])},
            {"/memories/notes.txt": self._record(["notes"])},
        )

        result = mw.grep_search.func(pattern=r".*", state=state)

        assert isinstance(result, str)
        assert "/src/main.py" in result
        assert "/memories/notes.txt" not in result
|
||||
Reference in New Issue
Block a user