mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix: adding missing async hooks (#33957)
* filling in missing async gaps * using recommended tool runtime injection instead of injected state * updating tests to use helper function as well
This commit is contained in:
@@ -40,6 +40,18 @@ class ClaudeBashToolMiddleware(ShellToolMiddleware):
|
||||
request = request.override(tools=tools)
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
"""Async: ensure the Claude bash descriptor is available to the model."""
|
||||
tools = request.tools
|
||||
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools):
|
||||
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR]
|
||||
request = request.override(tools=tools)
|
||||
return await handler(request)
|
||||
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
|
||||
@@ -9,13 +9,13 @@ from __future__ import annotations
|
||||
import fnmatch
|
||||
import re
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import TYPE_CHECKING, Annotated, Literal, cast
|
||||
from typing import TYPE_CHECKING, Literal, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware
|
||||
from langchain_core.tools import InjectedToolArg, tool
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
|
||||
from langchain_anthropic.middleware.anthropic_tools import AnthropicToolsState
|
||||
|
||||
@@ -128,9 +128,9 @@ class StateFileSearchMiddleware(AgentMiddleware):
|
||||
# Create tool instances
|
||||
@tool
|
||||
def glob_search( # noqa: D417
|
||||
runtime: ToolRuntime[None, AnthropicToolsState],
|
||||
pattern: str,
|
||||
path: str = "/",
|
||||
state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment]
|
||||
) -> str:
|
||||
"""Fast file pattern matching tool that works with any codebase size.
|
||||
|
||||
@@ -147,56 +147,17 @@ class StateFileSearchMiddleware(AgentMiddleware):
|
||||
time (most recently modified first). Returns "No files found" if no
|
||||
matches.
|
||||
"""
|
||||
# Normalize base path
|
||||
base_path = path if path.startswith("/") else "/" + path
|
||||
|
||||
# Get files from state
|
||||
files = cast("dict[str, Any]", state.get(self.state_key, {}))
|
||||
|
||||
# Match files
|
||||
matches = []
|
||||
for file_path, file_data in files.items():
|
||||
if file_path.startswith(base_path):
|
||||
# Get relative path from base
|
||||
if base_path == "/":
|
||||
relative = file_path[1:] # Remove leading /
|
||||
elif file_path == base_path:
|
||||
relative = Path(file_path).name
|
||||
elif file_path.startswith(base_path + "/"):
|
||||
relative = file_path[len(base_path) + 1 :]
|
||||
else:
|
||||
continue
|
||||
|
||||
# Match against pattern
|
||||
# Handle ** pattern which requires special care
|
||||
# PurePosixPath.match doesn't match single-level paths
|
||||
# against **/pattern
|
||||
is_match = PurePosixPath(relative).match(pattern)
|
||||
if not is_match and pattern.startswith("**/"):
|
||||
# Also try matching without the **/ prefix for files in base dir
|
||||
is_match = PurePosixPath(relative).match(pattern[3:])
|
||||
|
||||
if is_match:
|
||||
matches.append((file_path, file_data["modified_at"]))
|
||||
|
||||
if not matches:
|
||||
return "No files found"
|
||||
|
||||
# Sort by modification time
|
||||
matches.sort(key=lambda x: x[1], reverse=True)
|
||||
file_paths = [path for path, _ in matches]
|
||||
|
||||
return "\n".join(file_paths)
|
||||
return self._handle_glob_search(pattern, path, runtime.state)
|
||||
|
||||
@tool
|
||||
def grep_search( # noqa: D417
|
||||
runtime: ToolRuntime[None, AnthropicToolsState],
|
||||
pattern: str,
|
||||
path: str = "/",
|
||||
include: str | None = None,
|
||||
output_mode: Literal[
|
||||
"files_with_matches", "content", "count"
|
||||
] = "files_with_matches",
|
||||
state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment]
|
||||
) -> str:
|
||||
"""Fast content search tool that works with any codebase size.
|
||||
|
||||
@@ -216,49 +177,133 @@ class StateFileSearchMiddleware(AgentMiddleware):
|
||||
Search results formatted according to output_mode. Returns "No matches
|
||||
found" if no results.
|
||||
"""
|
||||
# Normalize base path
|
||||
base_path = path if path.startswith("/") else "/" + path
|
||||
|
||||
# Compile regex pattern (for validation)
|
||||
try:
|
||||
regex = re.compile(pattern)
|
||||
except re.error as e:
|
||||
return f"Invalid regex pattern: {e}"
|
||||
|
||||
if include and not _is_valid_include_pattern(include):
|
||||
return "Invalid include pattern"
|
||||
|
||||
# Search files
|
||||
files = cast("dict[str, Any]", state.get(self.state_key, {}))
|
||||
results: dict[str, list[tuple[int, str]]] = {}
|
||||
|
||||
for file_path, file_data in files.items():
|
||||
if not file_path.startswith(base_path):
|
||||
continue
|
||||
|
||||
# Check include filter
|
||||
if include:
|
||||
basename = Path(file_path).name
|
||||
if not _match_include_pattern(basename, include):
|
||||
continue
|
||||
|
||||
# Search file content
|
||||
for line_num, line in enumerate(file_data["content"], 1):
|
||||
if regex.search(line):
|
||||
if file_path not in results:
|
||||
results[file_path] = []
|
||||
results[file_path].append((line_num, line))
|
||||
|
||||
if not results:
|
||||
return "No matches found"
|
||||
|
||||
# Format output based on mode
|
||||
return self._format_grep_results(results, output_mode)
|
||||
return self._handle_grep_search(
|
||||
pattern, path, include, output_mode, runtime.state
|
||||
)
|
||||
|
||||
self.glob_search = glob_search
|
||||
self.grep_search = grep_search
|
||||
self.tools = [glob_search, grep_search]
|
||||
|
||||
def _handle_glob_search(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str,
|
||||
state: AnthropicToolsState,
|
||||
) -> str:
|
||||
"""Handle glob search operation.
|
||||
|
||||
Args:
|
||||
pattern: The glob pattern to match files against.
|
||||
path: The directory to search in.
|
||||
state: The current agent state.
|
||||
|
||||
Returns:
|
||||
Newline-separated list of matching file paths, sorted by modification
|
||||
time (most recently modified first). Returns "No files found" if no
|
||||
matches.
|
||||
"""
|
||||
# Normalize base path
|
||||
base_path = path if path.startswith("/") else "/" + path
|
||||
|
||||
# Get files from state
|
||||
files = cast("dict[str, Any]", state.get(self.state_key, {}))
|
||||
|
||||
# Match files
|
||||
matches = []
|
||||
for file_path, file_data in files.items():
|
||||
if file_path.startswith(base_path):
|
||||
# Get relative path from base
|
||||
if base_path == "/":
|
||||
relative = file_path[1:] # Remove leading /
|
||||
elif file_path == base_path:
|
||||
relative = Path(file_path).name
|
||||
elif file_path.startswith(base_path + "/"):
|
||||
relative = file_path[len(base_path) + 1 :]
|
||||
else:
|
||||
continue
|
||||
|
||||
# Match against pattern
|
||||
# Handle ** pattern which requires special care
|
||||
# PurePosixPath.match doesn't match single-level paths
|
||||
# against **/pattern
|
||||
is_match = PurePosixPath(relative).match(pattern)
|
||||
if not is_match and pattern.startswith("**/"):
|
||||
# Also try matching without the **/ prefix for files in base dir
|
||||
is_match = PurePosixPath(relative).match(pattern[3:])
|
||||
|
||||
if is_match:
|
||||
matches.append((file_path, file_data["modified_at"]))
|
||||
|
||||
if not matches:
|
||||
return "No files found"
|
||||
|
||||
# Sort by modification time
|
||||
matches.sort(key=lambda x: x[1], reverse=True)
|
||||
file_paths = [path for path, _ in matches]
|
||||
|
||||
return "\n".join(file_paths)
|
||||
|
||||
def _handle_grep_search(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str,
|
||||
include: str | None,
|
||||
output_mode: str,
|
||||
state: AnthropicToolsState,
|
||||
) -> str:
|
||||
"""Handle grep search operation.
|
||||
|
||||
Args:
|
||||
pattern: The regular expression pattern to search for in file contents.
|
||||
path: The directory to search in.
|
||||
include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
|
||||
output_mode: Output format.
|
||||
state: The current agent state.
|
||||
|
||||
Returns:
|
||||
Search results formatted according to output_mode. Returns "No matches
|
||||
found" if no results.
|
||||
"""
|
||||
# Normalize base path
|
||||
base_path = path if path.startswith("/") else "/" + path
|
||||
|
||||
# Compile regex pattern (for validation)
|
||||
try:
|
||||
regex = re.compile(pattern)
|
||||
except re.error as e:
|
||||
return f"Invalid regex pattern: {e}"
|
||||
|
||||
if include and not _is_valid_include_pattern(include):
|
||||
return "Invalid include pattern"
|
||||
|
||||
# Search files
|
||||
files = cast("dict[str, Any]", state.get(self.state_key, {}))
|
||||
results: dict[str, list[tuple[int, str]]] = {}
|
||||
|
||||
for file_path, file_data in files.items():
|
||||
if not file_path.startswith(base_path):
|
||||
continue
|
||||
|
||||
# Check include filter
|
||||
if include:
|
||||
basename = Path(file_path).name
|
||||
if not _match_include_pattern(basename, include):
|
||||
continue
|
||||
|
||||
# Search file content
|
||||
for line_num, line in enumerate(file_data["content"], 1):
|
||||
if regex.search(line):
|
||||
if file_path not in results:
|
||||
results[file_path] = []
|
||||
results[file_path].append((line_num, line))
|
||||
|
||||
if not results:
|
||||
return "No matches found"
|
||||
|
||||
# Format output based on mode
|
||||
return self._format_grep_results(results, output_mode)
|
||||
|
||||
def _format_grep_results(
|
||||
self,
|
||||
results: dict[str, list[tuple[int, str]]],
|
||||
|
||||
@@ -49,8 +49,10 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
# Call tool function directly (state is injected in real usage)
|
||||
result = middleware.glob_search.func(pattern="*.py", state=test_state) # type: ignore[attr-defined]
|
||||
# Call internal handler method directly
|
||||
result = middleware._handle_glob_search(
|
||||
pattern="*.py", path="/", state=test_state
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -82,7 +84,9 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func(pattern="**/*.py", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(
|
||||
pattern="**/*.py", path="/", state=state
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
lines = result.split("\n")
|
||||
@@ -109,7 +113,7 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func( # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(
|
||||
pattern="**/*.py", path="/src", state=state
|
||||
)
|
||||
|
||||
@@ -132,7 +136,7 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func(pattern="*.ts", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(pattern="*.ts", path="/", state=state)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == "No files found"
|
||||
@@ -157,7 +161,7 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func(pattern="*.py", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(pattern="*.py", path="/", state=state)
|
||||
|
||||
lines = result.split("\n")
|
||||
# Most recent first
|
||||
@@ -193,7 +197,13 @@ class TestGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r"def \w+\(\):", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"def \w+\(\):",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -216,8 +226,12 @@ class TestGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern=r"def", include="*.{py", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"def",
|
||||
path="/",
|
||||
include="*.{py",
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert result == "Invalid include pattern"
|
||||
@@ -241,8 +255,12 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern=r"def \w+\(\):", output_mode="content", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"def \w+\(\):",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="content",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
@@ -271,8 +289,8 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern=r"TODO", output_mode="count", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"TODO", path="/", include=None, output_mode="count", state=state
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
@@ -300,8 +318,12 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern="import", include="*.py", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern="import",
|
||||
path="/",
|
||||
include="*.py",
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
@@ -333,8 +355,12 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern="const", include="*.{ts,tsx}", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern="const",
|
||||
path="/",
|
||||
include="*.{ts,tsx}",
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
@@ -362,7 +388,13 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern="import", path="/src", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern="import",
|
||||
path="/src",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -383,7 +415,13 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"TODO",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == "No matches found"
|
||||
@@ -397,7 +435,13 @@ class TestFilesystemGrepSearch:
|
||||
"text_editor_files": {},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r"[unclosed", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"[unclosed",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "Invalid regex pattern" in result
|
||||
@@ -428,7 +472,7 @@ class TestSearchWithDifferentBackends:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func(pattern="**/*", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(pattern="**/*", path="/", state=state)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -457,7 +501,13 @@ class TestSearchWithDifferentBackends:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"TODO",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -486,7 +536,13 @@ class TestSearchWithDifferentBackends:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r".*", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r".*",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
|
||||
Reference in New Issue
Block a user