fix: adding missing async hooks (#33957)

* filling in missing async gaps
* using recommended tool runtime injection instead of injected state
  * updating tests to use helper function as well
This commit is contained in:
Sydney Runkle
2025-11-14 09:13:39 -05:00
committed by GitHub
parent 26d39ffc4a
commit 83c078f363
8 changed files with 331 additions and 108 deletions

View File

@@ -353,3 +353,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
last_ai_msg.tool_calls = revised_tool_calls
return {"messages": [last_ai_msg, *artificial_tool_messages]}
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
return self.after_model(state, runtime)

View File

@@ -198,6 +198,29 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
return None
@hook_config(can_jump_to=["end"])
async def abefore_model(
self,
state: ModelCallLimitState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async check model call limits before making a model call.
Args:
state: The current agent state containing call counts.
runtime: The langgraph runtime.
Returns:
If limits are exceeded and exit_behavior is `'end'`, returns
a `Command` to jump to the end with a limit exceeded message. Otherwise
returns `None`.
Raises:
ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
"""
return self.before_model(state, runtime)
def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Increment model call counts after a model call.
@@ -212,3 +235,19 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
}
async def aafter_model(
self,
state: ModelCallLimitState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async increment model call counts after a model call.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
State updates with incremented call counts.
"""
return self.after_model(state, runtime)

View File

@@ -252,6 +252,27 @@ class PIIMiddleware(AgentMiddleware):
return None
@hook_config(can_jump_to=["end"])
async def abefore_model(
self,
state: AgentState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async check user messages and tool results for PII before model invocation.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or `None` if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
return self.before_model(state, runtime)
def after_model(
self,
state: AgentState,
@@ -311,6 +332,26 @@ class PIIMiddleware(AgentMiddleware):
return {"messages": new_messages}
async def aafter_model(
self,
state: AgentState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async check AI messages for PII after model invocation.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or None if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
return self.after_model(state, runtime)
__all__ = [
"PIIDetectionError",

View File

@@ -482,7 +482,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
return {"shell_session_resources": resources}
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
"""Async counterpart to `before_agent`."""
"""Async start the shell session and run startup commands."""
return self.before_agent(state, runtime)
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
@@ -494,7 +494,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
resources._finalizer()
async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
"""Async counterpart to `after_agent`."""
"""Async run shutdown commands and release resources when an agent completes."""
return self.after_agent(state, runtime)
def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
@@ -689,7 +689,8 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
request: ToolCallRequest,
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Async interception mirroring the synchronous tool handler."""
"""Async intercept local shell tool calls and execute them via the managed session."""
# The sync version already handles all the work, no need for async-specific logic
if isinstance(request.tool, _PersistentShellTool):
resources = self._ensure_resources(request.state)
return self._run_shell_tool(

View File

@@ -451,3 +451,28 @@ class ToolCallLimitMiddleware(
"run_tool_call_count": run_counts,
"messages": artificial_messages,
}
@hook_config(can_jump_to=["end"])
async def aafter_model(
self,
state: ToolCallLimitState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async increment tool call counts after a model call and check limits.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
State updates with incremented tool call counts. If limits are exceeded
and exit_behavior is `'end'`, also includes a jump to end with a
`ToolMessage` and AI message for the single exceeded tool call.
Raises:
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
and there are multiple tool calls.
"""
return self.after_model(state, runtime)

View File

@@ -40,6 +40,18 @@ class ClaudeBashToolMiddleware(ShellToolMiddleware):
request = request.override(tools=tools)
return handler(request)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse:
"""Async: ensure the Claude bash descriptor is available to the model."""
tools = request.tools
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools):
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR]
request = request.override(tools=tools)
return await handler(request)
def wrap_tool_call(
self,
request: ToolCallRequest,

View File

@@ -9,13 +9,13 @@ from __future__ import annotations
import fnmatch
import re
from pathlib import Path, PurePosixPath
from typing import TYPE_CHECKING, Annotated, Literal, cast
from typing import TYPE_CHECKING, Literal, cast
if TYPE_CHECKING:
from typing import Any
from langchain.agents.middleware.types import AgentMiddleware
from langchain_core.tools import InjectedToolArg, tool
from langchain.tools import ToolRuntime, tool
from langchain_anthropic.middleware.anthropic_tools import AnthropicToolsState
@@ -128,9 +128,9 @@ class StateFileSearchMiddleware(AgentMiddleware):
# Create tool instances
@tool
def glob_search( # noqa: D417
runtime: ToolRuntime[None, AnthropicToolsState],
pattern: str,
path: str = "/",
state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment]
) -> str:
"""Fast file pattern matching tool that works with any codebase size.
@@ -147,56 +147,17 @@ class StateFileSearchMiddleware(AgentMiddleware):
time (most recently modified first). Returns "No files found" if no
matches.
"""
# Normalize base path
base_path = path if path.startswith("/") else "/" + path
# Get files from state
files = cast("dict[str, Any]", state.get(self.state_key, {}))
# Match files
matches = []
for file_path, file_data in files.items():
if file_path.startswith(base_path):
# Get relative path from base
if base_path == "/":
relative = file_path[1:] # Remove leading /
elif file_path == base_path:
relative = Path(file_path).name
elif file_path.startswith(base_path + "/"):
relative = file_path[len(base_path) + 1 :]
else:
continue
# Match against pattern
# Handle ** pattern which requires special care
# PurePosixPath.match doesn't match single-level paths
# against **/pattern
is_match = PurePosixPath(relative).match(pattern)
if not is_match and pattern.startswith("**/"):
# Also try matching without the **/ prefix for files in base dir
is_match = PurePosixPath(relative).match(pattern[3:])
if is_match:
matches.append((file_path, file_data["modified_at"]))
if not matches:
return "No files found"
# Sort by modification time
matches.sort(key=lambda x: x[1], reverse=True)
file_paths = [path for path, _ in matches]
return "\n".join(file_paths)
return self._handle_glob_search(pattern, path, runtime.state)
@tool
def grep_search( # noqa: D417
runtime: ToolRuntime[None, AnthropicToolsState],
pattern: str,
path: str = "/",
include: str | None = None,
output_mode: Literal[
"files_with_matches", "content", "count"
] = "files_with_matches",
state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment]
) -> str:
"""Fast content search tool that works with any codebase size.
@@ -216,49 +177,133 @@ class StateFileSearchMiddleware(AgentMiddleware):
Search results formatted according to output_mode. Returns "No matches
found" if no results.
"""
# Normalize base path
base_path = path if path.startswith("/") else "/" + path
# Compile regex pattern (for validation)
try:
regex = re.compile(pattern)
except re.error as e:
return f"Invalid regex pattern: {e}"
if include and not _is_valid_include_pattern(include):
return "Invalid include pattern"
# Search files
files = cast("dict[str, Any]", state.get(self.state_key, {}))
results: dict[str, list[tuple[int, str]]] = {}
for file_path, file_data in files.items():
if not file_path.startswith(base_path):
continue
# Check include filter
if include:
basename = Path(file_path).name
if not _match_include_pattern(basename, include):
continue
# Search file content
for line_num, line in enumerate(file_data["content"], 1):
if regex.search(line):
if file_path not in results:
results[file_path] = []
results[file_path].append((line_num, line))
if not results:
return "No matches found"
# Format output based on mode
return self._format_grep_results(results, output_mode)
return self._handle_grep_search(
pattern, path, include, output_mode, runtime.state
)
self.glob_search = glob_search
self.grep_search = grep_search
self.tools = [glob_search, grep_search]
def _handle_glob_search(
self,
pattern: str,
path: str,
state: AnthropicToolsState,
) -> str:
"""Handle glob search operation.
Args:
pattern: The glob pattern to match files against.
path: The directory to search in.
state: The current agent state.
Returns:
Newline-separated list of matching file paths, sorted by modification
time (most recently modified first). Returns "No files found" if no
matches.
"""
# Normalize base path
base_path = path if path.startswith("/") else "/" + path
# Get files from state
files = cast("dict[str, Any]", state.get(self.state_key, {}))
# Match files
matches = []
for file_path, file_data in files.items():
if file_path.startswith(base_path):
# Get relative path from base
if base_path == "/":
relative = file_path[1:] # Remove leading /
elif file_path == base_path:
relative = Path(file_path).name
elif file_path.startswith(base_path + "/"):
relative = file_path[len(base_path) + 1 :]
else:
continue
# Match against pattern
# Handle ** pattern which requires special care
# PurePosixPath.match doesn't match single-level paths
# against **/pattern
is_match = PurePosixPath(relative).match(pattern)
if not is_match and pattern.startswith("**/"):
# Also try matching without the **/ prefix for files in base dir
is_match = PurePosixPath(relative).match(pattern[3:])
if is_match:
matches.append((file_path, file_data["modified_at"]))
if not matches:
return "No files found"
# Sort by modification time
matches.sort(key=lambda x: x[1], reverse=True)
file_paths = [path for path, _ in matches]
return "\n".join(file_paths)
def _handle_grep_search(
self,
pattern: str,
path: str,
include: str | None,
output_mode: str,
state: AnthropicToolsState,
) -> str:
"""Handle grep search operation.
Args:
pattern: The regular expression pattern to search for in file contents.
path: The directory to search in.
include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
output_mode: Output format.
state: The current agent state.
Returns:
Search results formatted according to output_mode. Returns "No matches
found" if no results.
"""
# Normalize base path
base_path = path if path.startswith("/") else "/" + path
# Compile regex pattern (for validation)
try:
regex = re.compile(pattern)
except re.error as e:
return f"Invalid regex pattern: {e}"
if include and not _is_valid_include_pattern(include):
return "Invalid include pattern"
# Search files
files = cast("dict[str, Any]", state.get(self.state_key, {}))
results: dict[str, list[tuple[int, str]]] = {}
for file_path, file_data in files.items():
if not file_path.startswith(base_path):
continue
# Check include filter
if include:
basename = Path(file_path).name
if not _match_include_pattern(basename, include):
continue
# Search file content
for line_num, line in enumerate(file_data["content"], 1):
if regex.search(line):
if file_path not in results:
results[file_path] = []
results[file_path].append((line_num, line))
if not results:
return "No matches found"
# Format output based on mode
return self._format_grep_results(results, output_mode)
def _format_grep_results(
self,
results: dict[str, list[tuple[int, str]]],

View File

@@ -49,8 +49,10 @@ class TestGlobSearch:
},
}
# Call tool function directly (state is injected in real usage)
result = middleware.glob_search.func(pattern="*.py", state=test_state) # type: ignore[attr-defined]
# Call internal handler method directly
result = middleware._handle_glob_search(
pattern="*.py", path="/", state=test_state
)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -82,7 +84,9 @@ class TestGlobSearch:
},
}
result = middleware.glob_search.func(pattern="**/*.py", state=state) # type: ignore[attr-defined]
result = middleware._handle_glob_search(
pattern="**/*.py", path="/", state=state
)
assert isinstance(result, str)
lines = result.split("\n")
@@ -109,7 +113,7 @@ class TestGlobSearch:
},
}
result = middleware.glob_search.func( # type: ignore[attr-defined]
result = middleware._handle_glob_search(
pattern="**/*.py", path="/src", state=state
)
@@ -132,7 +136,7 @@ class TestGlobSearch:
},
}
result = middleware.glob_search.func(pattern="*.ts", state=state) # type: ignore[attr-defined]
result = middleware._handle_glob_search(pattern="*.ts", path="/", state=state)
assert isinstance(result, str)
assert result == "No files found"
@@ -157,7 +161,7 @@ class TestGlobSearch:
},
}
result = middleware.glob_search.func(pattern="*.py", state=state) # type: ignore[attr-defined]
result = middleware._handle_glob_search(pattern="*.py", path="/", state=state)
lines = result.split("\n")
# Most recent first
@@ -193,7 +197,13 @@ class TestGrepSearch:
},
}
result = middleware.grep_search.func(pattern=r"def \w+\(\):", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r"def \w+\(\):",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -216,8 +226,12 @@ class TestGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern=r"def", include="*.{py", state=state
result = middleware._handle_grep_search(
pattern=r"def",
path="/",
include="*.{py",
output_mode="files_with_matches",
state=state,
)
assert result == "Invalid include pattern"
@@ -241,8 +255,12 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern=r"def \w+\(\):", output_mode="content", state=state
result = middleware._handle_grep_search(
pattern=r"def \w+\(\):",
path="/",
include=None,
output_mode="content",
state=state,
)
assert isinstance(result, str)
@@ -271,8 +289,8 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern=r"TODO", output_mode="count", state=state
result = middleware._handle_grep_search(
pattern=r"TODO", path="/", include=None, output_mode="count", state=state
)
assert isinstance(result, str)
@@ -300,8 +318,12 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern="import", include="*.py", state=state
result = middleware._handle_grep_search(
pattern="import",
path="/",
include="*.py",
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
@@ -333,8 +355,12 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern="const", include="*.{ts,tsx}", state=state
result = middleware._handle_grep_search(
pattern="const",
path="/",
include="*.{ts,tsx}",
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
@@ -362,7 +388,13 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func(pattern="import", path="/src", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern="import",
path="/src",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -383,7 +415,13 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r"TODO",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert result == "No matches found"
@@ -397,7 +435,13 @@ class TestFilesystemGrepSearch:
"text_editor_files": {},
}
result = middleware.grep_search.func(pattern=r"[unclosed", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r"[unclosed",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "Invalid regex pattern" in result
@@ -428,7 +472,7 @@ class TestSearchWithDifferentBackends:
},
}
result = middleware.glob_search.func(pattern="**/*", state=state) # type: ignore[attr-defined]
result = middleware._handle_glob_search(pattern="**/*", path="/", state=state)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -457,7 +501,13 @@ class TestSearchWithDifferentBackends:
},
}
result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r"TODO",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -486,7 +536,13 @@ class TestSearchWithDifferentBackends:
},
}
result = middleware.grep_search.func(pattern=r".*", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r".*",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "/src/main.py" in result