fix: adding missing async hooks (#33957)

* filling in missing async gaps
* using recommended tool runtime injection instead of injected state
  * updating tests to use helper function as well
This commit is contained in:
Sydney Runkle
2025-11-14 09:13:39 -05:00
committed by GitHub
parent 26d39ffc4a
commit 83c078f363
8 changed files with 331 additions and 108 deletions

View File

@@ -353,3 +353,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
last_ai_msg.tool_calls = revised_tool_calls
return {"messages": [last_ai_msg, *artificial_tool_messages]}
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
return self.after_model(state, runtime)

View File

@@ -198,6 +198,29 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
return None
@hook_config(can_jump_to=["end"])
async def abefore_model(
self,
state: ModelCallLimitState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async check model call limits before making a model call.
Args:
state: The current agent state containing call counts.
runtime: The langgraph runtime.
Returns:
If limits are exceeded and exit_behavior is `'end'`, returns
a `Command` to jump to the end with a limit exceeded message. Otherwise
returns `None`.
Raises:
ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
"""
return self.before_model(state, runtime)
def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Increment model call counts after a model call.
@@ -212,3 +235,19 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
}
async def aafter_model(
self,
state: ModelCallLimitState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async increment model call counts after a model call.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
State updates with incremented call counts.
"""
return self.after_model(state, runtime)

View File

@@ -252,6 +252,27 @@ class PIIMiddleware(AgentMiddleware):
return None
@hook_config(can_jump_to=["end"])
async def abefore_model(
self,
state: AgentState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async check user messages and tool results for PII before model invocation.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or `None` if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
return self.before_model(state, runtime)
def after_model(
self,
state: AgentState,
@@ -311,6 +332,26 @@ class PIIMiddleware(AgentMiddleware):
return {"messages": new_messages}
async def aafter_model(
self,
state: AgentState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async check AI messages for PII after model invocation.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or None if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
return self.after_model(state, runtime)
__all__ = [
"PIIDetectionError",

View File

@@ -482,7 +482,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
return {"shell_session_resources": resources}
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
"""Async counterpart to `before_agent`."""
"""Async start the shell session and run startup commands."""
return self.before_agent(state, runtime)
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
@@ -494,7 +494,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
resources._finalizer()
async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
"""Async counterpart to `after_agent`."""
"""Async run shutdown commands and release resources when an agent completes."""
return self.after_agent(state, runtime)
def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
@@ -689,7 +689,8 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
request: ToolCallRequest,
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Async interception mirroring the synchronous tool handler."""
"""Async intercept local shell tool calls and execute them via the managed session."""
# The sync version already handles all the work, no need for async-specific logic
if isinstance(request.tool, _PersistentShellTool):
resources = self._ensure_resources(request.state)
return self._run_shell_tool(

View File

@@ -451,3 +451,28 @@ class ToolCallLimitMiddleware(
"run_tool_call_count": run_counts,
"messages": artificial_messages,
}
@hook_config(can_jump_to=["end"])
async def aafter_model(
self,
state: ToolCallLimitState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async increment tool call counts after a model call and check limits.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
State updates with incremented tool call counts. If limits are exceeded
and exit_behavior is `'end'`, also includes a jump to end with a
`ToolMessage` and AI message for the single exceeded tool call.
Raises:
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
and there are multiple tool calls.
"""
return self.after_model(state, runtime)

View File

@@ -40,6 +40,18 @@ class ClaudeBashToolMiddleware(ShellToolMiddleware):
request = request.override(tools=tools)
return handler(request)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse:
"""Async: ensure the Claude bash descriptor is available to the model."""
tools = request.tools
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools):
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR]
request = request.override(tools=tools)
return await handler(request)
def wrap_tool_call(
self,
request: ToolCallRequest,

View File

@@ -9,13 +9,13 @@ from __future__ import annotations
import fnmatch
import re
from pathlib import Path, PurePosixPath
from typing import TYPE_CHECKING, Annotated, Literal, cast
from typing import TYPE_CHECKING, Literal, cast
if TYPE_CHECKING:
from typing import Any
from langchain.agents.middleware.types import AgentMiddleware
from langchain_core.tools import InjectedToolArg, tool
from langchain.tools import ToolRuntime, tool
from langchain_anthropic.middleware.anthropic_tools import AnthropicToolsState
@@ -128,9 +128,9 @@ class StateFileSearchMiddleware(AgentMiddleware):
# Create tool instances
@tool
def glob_search( # noqa: D417
runtime: ToolRuntime[None, AnthropicToolsState],
pattern: str,
path: str = "/",
state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment]
) -> str:
"""Fast file pattern matching tool that works with any codebase size.
@@ -142,6 +142,62 @@ class StateFileSearchMiddleware(AgentMiddleware):
pattern: The glob pattern to match files against.
path: The directory to search in. If not specified, searches from root.
Returns:
Newline-separated list of matching file paths, sorted by modification
time (most recently modified first). Returns "No files found" if no
matches.
"""
return self._handle_glob_search(pattern, path, runtime.state)
@tool
def grep_search( # noqa: D417
runtime: ToolRuntime[None, AnthropicToolsState],
pattern: str,
path: str = "/",
include: str | None = None,
output_mode: Literal[
"files_with_matches", "content", "count"
] = "files_with_matches",
) -> str:
"""Fast content search tool that works with any codebase size.
Searches file contents using regular expressions. Supports full regex
syntax and filters files by pattern with the include parameter.
Args:
pattern: The regular expression pattern to search for in file contents.
path: The directory to search in. If not specified, searches from root.
include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
output_mode: Output format:
- "files_with_matches": Only file paths containing matches (default)
- "content": Matching lines with file:line:content format
- "count": Count of matches per file
Returns:
Search results formatted according to output_mode. Returns "No matches
found" if no results.
"""
return self._handle_grep_search(
pattern, path, include, output_mode, runtime.state
)
self.glob_search = glob_search
self.grep_search = grep_search
self.tools = [glob_search, grep_search]
def _handle_glob_search(
self,
pattern: str,
path: str,
state: AnthropicToolsState,
) -> str:
"""Handle glob search operation.
Args:
pattern: The glob pattern to match files against.
path: The directory to search in.
state: The current agent state.
Returns:
Newline-separated list of matching file paths, sorted by modification
time (most recently modified first). Returns "No files found" if no
@@ -188,29 +244,22 @@ class StateFileSearchMiddleware(AgentMiddleware):
return "\n".join(file_paths)
@tool
def grep_search( # noqa: D417
def _handle_grep_search(
self,
pattern: str,
path: str = "/",
include: str | None = None,
output_mode: Literal[
"files_with_matches", "content", "count"
] = "files_with_matches",
state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment]
path: str,
include: str | None,
output_mode: str,
state: AnthropicToolsState,
) -> str:
"""Fast content search tool that works with any codebase size.
Searches file contents using regular expressions. Supports full regex
syntax and filters files by pattern with the include parameter.
"""Handle grep search operation.
Args:
pattern: The regular expression pattern to search for in file contents.
path: The directory to search in. If not specified, searches from root.
path: The directory to search in.
include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
output_mode: Output format:
- "files_with_matches": Only file paths containing matches (default)
- "content": Matching lines with file:line:content format
- "count": Count of matches per file
output_mode: Output format.
state: The current agent state.
Returns:
Search results formatted according to output_mode. Returns "No matches
@@ -255,10 +304,6 @@ class StateFileSearchMiddleware(AgentMiddleware):
# Format output based on mode
return self._format_grep_results(results, output_mode)
self.glob_search = glob_search
self.grep_search = grep_search
self.tools = [glob_search, grep_search]
def _format_grep_results(
self,
results: dict[str, list[tuple[int, str]]],

View File

@@ -49,8 +49,10 @@ class TestGlobSearch:
},
}
# Call tool function directly (state is injected in real usage)
result = middleware.glob_search.func(pattern="*.py", state=test_state) # type: ignore[attr-defined]
# Call internal handler method directly
result = middleware._handle_glob_search(
pattern="*.py", path="/", state=test_state
)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -82,7 +84,9 @@ class TestGlobSearch:
},
}
result = middleware.glob_search.func(pattern="**/*.py", state=state) # type: ignore[attr-defined]
result = middleware._handle_glob_search(
pattern="**/*.py", path="/", state=state
)
assert isinstance(result, str)
lines = result.split("\n")
@@ -109,7 +113,7 @@ class TestGlobSearch:
},
}
result = middleware.glob_search.func( # type: ignore[attr-defined]
result = middleware._handle_glob_search(
pattern="**/*.py", path="/src", state=state
)
@@ -132,7 +136,7 @@ class TestGlobSearch:
},
}
result = middleware.glob_search.func(pattern="*.ts", state=state) # type: ignore[attr-defined]
result = middleware._handle_glob_search(pattern="*.ts", path="/", state=state)
assert isinstance(result, str)
assert result == "No files found"
@@ -157,7 +161,7 @@ class TestGlobSearch:
},
}
result = middleware.glob_search.func(pattern="*.py", state=state) # type: ignore[attr-defined]
result = middleware._handle_glob_search(pattern="*.py", path="/", state=state)
lines = result.split("\n")
# Most recent first
@@ -193,7 +197,13 @@ class TestGrepSearch:
},
}
result = middleware.grep_search.func(pattern=r"def \w+\(\):", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r"def \w+\(\):",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -216,8 +226,12 @@ class TestGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern=r"def", include="*.{py", state=state
result = middleware._handle_grep_search(
pattern=r"def",
path="/",
include="*.{py",
output_mode="files_with_matches",
state=state,
)
assert result == "Invalid include pattern"
@@ -241,8 +255,12 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern=r"def \w+\(\):", output_mode="content", state=state
result = middleware._handle_grep_search(
pattern=r"def \w+\(\):",
path="/",
include=None,
output_mode="content",
state=state,
)
assert isinstance(result, str)
@@ -271,8 +289,8 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern=r"TODO", output_mode="count", state=state
result = middleware._handle_grep_search(
pattern=r"TODO", path="/", include=None, output_mode="count", state=state
)
assert isinstance(result, str)
@@ -300,8 +318,12 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern="import", include="*.py", state=state
result = middleware._handle_grep_search(
pattern="import",
path="/",
include="*.py",
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
@@ -333,8 +355,12 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern="const", include="*.{ts,tsx}", state=state
result = middleware._handle_grep_search(
pattern="const",
path="/",
include="*.{ts,tsx}",
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
@@ -362,7 +388,13 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func(pattern="import", path="/src", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern="import",
path="/src",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -383,7 +415,13 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r"TODO",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert result == "No matches found"
@@ -397,7 +435,13 @@ class TestFilesystemGrepSearch:
"text_editor_files": {},
}
result = middleware.grep_search.func(pattern=r"[unclosed", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r"[unclosed",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "Invalid regex pattern" in result
@@ -428,7 +472,7 @@ class TestSearchWithDifferentBackends:
},
}
result = middleware.glob_search.func(pattern="**/*", state=state) # type: ignore[attr-defined]
result = middleware._handle_glob_search(pattern="**/*", path="/", state=state)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -457,7 +501,13 @@ class TestSearchWithDifferentBackends:
},
}
result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r"TODO",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -486,7 +536,13 @@ class TestSearchWithDifferentBackends:
},
}
result = middleware.grep_search.func(pattern=r".*", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r".*",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "/src/main.py" in result