From 83c078f363f020fba6abce8a3f031376a7fdd27d Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Fri, 14 Nov 2025 09:13:39 -0500 Subject: [PATCH] fix: adding missing async hooks (#33957) * filling in missing async gaps * using recommended tool runtime injection instead of injected state * updating tests to use helper function as well --- .../agents/middleware/human_in_the_loop.py | 4 + .../agents/middleware/model_call_limit.py | 39 ++++ .../langchain/agents/middleware/pii.py | 41 ++++ .../langchain/agents/middleware/shell_tool.py | 7 +- .../agents/middleware/tool_call_limit.py | 25 +++ .../langchain_anthropic/middleware/bash.py | 12 + .../middleware/file_search.py | 209 +++++++++++------- .../unit_tests/middleware/test_file_search.py | 102 +++++++-- 8 files changed, 331 insertions(+), 108 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py index a7c21b0c8e2..355de48c73f 100644 --- a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py +++ b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py @@ -353,3 +353,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware): last_ai_msg.tool_calls = revised_tool_calls return {"messages": [last_ai_msg, *artificial_tool_messages]} + + async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + """Async trigger interrupt flows for relevant tool calls after an `AIMessage`.""" + return self.after_model(state, runtime) diff --git a/libs/langchain_v1/langchain/agents/middleware/model_call_limit.py b/libs/langchain_v1/langchain/agents/middleware/model_call_limit.py index d44c26bad3d..87afda778ee 100644 --- a/libs/langchain_v1/langchain/agents/middleware/model_call_limit.py +++ b/libs/langchain_v1/langchain/agents/middleware/model_call_limit.py @@ -198,6 +198,29 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]): return None + @hook_config(can_jump_to=["end"]) + async def abefore_model( + self, + state: ModelCallLimitState, + runtime: Runtime, + ) -> dict[str, Any] | None: + """Async check model call limits before making a model call. + + Args: + state: The current agent state containing call counts. + runtime: The langgraph runtime. + + Returns: + If limits are exceeded and exit_behavior is `'end'`, returns + a `Command` to jump to the end with a limit exceeded message. Otherwise + returns `None`. + + Raises: + ModelCallLimitExceededError: If limits are exceeded and `exit_behavior` + is `'error'`. + """ + return self.before_model(state, runtime) + def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002 """Increment model call counts after a model call. @@ -212,3 +235,19 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]): "thread_model_call_count": state.get("thread_model_call_count", 0) + 1, "run_model_call_count": state.get("run_model_call_count", 0) + 1, } + + async def aafter_model( + self, + state: ModelCallLimitState, + runtime: Runtime, + ) -> dict[str, Any] | None: + """Async increment model call counts after a model call. + + Args: + state: The current agent state. + runtime: The langgraph runtime. + + Returns: + State updates with incremented call counts. + """ + return self.after_model(state, runtime) diff --git a/libs/langchain_v1/langchain/agents/middleware/pii.py b/libs/langchain_v1/langchain/agents/middleware/pii.py index 5a7eb581a6c..4f11e709fbd 100644 --- a/libs/langchain_v1/langchain/agents/middleware/pii.py +++ b/libs/langchain_v1/langchain/agents/middleware/pii.py @@ -252,6 +252,27 @@ class PIIMiddleware(AgentMiddleware): return None + @hook_config(can_jump_to=["end"]) + async def abefore_model( + self, + state: AgentState, + runtime: Runtime, + ) -> dict[str, Any] | None: + """Async check user messages and tool results for PII before model invocation. + + Args: + state: The current agent state. + runtime: The langgraph runtime. + + Returns: + Updated state with PII handled according to strategy, or `None` if no PII + detected. + + Raises: + PIIDetectionError: If PII is detected and strategy is `'block'`. + """ + return self.before_model(state, runtime) + def after_model( self, state: AgentState, @@ -311,6 +332,26 @@ class PIIMiddleware(AgentMiddleware): return {"messages": new_messages} + async def aafter_model( + self, + state: AgentState, + runtime: Runtime, + ) -> dict[str, Any] | None: + """Async check AI messages for PII after model invocation. + + Args: + state: The current agent state. + runtime: The langgraph runtime. + + Returns: + Updated state with PII handled according to strategy, or None if no PII + detected. + + Raises: + PIIDetectionError: If PII is detected and strategy is `'block'`. + """ + return self.after_model(state, runtime) + __all__ = [ "PIIDetectionError", diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py index 8c03cfe28a8..e8a6fd2c524 100644 --- a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py +++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py @@ -482,7 +482,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): return {"shell_session_resources": resources} async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: - """Async counterpart to `before_agent`.""" + """Async start the shell session and run startup commands.""" return self.before_agent(state, runtime) def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002 @@ -494,7 +494,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): resources._finalizer() async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None: - """Async counterpart to `after_agent`.""" + """Async run shutdown commands and release resources when an agent completes.""" return self.after_agent(state, runtime) def _ensure_resources(self, state: ShellToolState) -> _SessionResources: @@ -689,7 +689,8 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): request: ToolCallRequest, handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]], ) -> ToolMessage | Command: - """Async interception mirroring the synchronous tool handler.""" + """Async intercept local shell tool calls and execute them via the managed session.""" + # The sync version already handles all the work, no need for async-specific logic if isinstance(request.tool, _PersistentShellTool): resources = self._ensure_resources(request.state) return self._run_shell_tool( diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_call_limit.py b/libs/langchain_v1/langchain/agents/middleware/tool_call_limit.py index c01cf536315..9aa2d116218 100644 --- a/libs/langchain_v1/langchain/agents/middleware/tool_call_limit.py +++ b/libs/langchain_v1/langchain/agents/middleware/tool_call_limit.py @@ -451,3 +451,28 @@ class ToolCallLimitMiddleware( "run_tool_call_count": run_counts, "messages": artificial_messages, } + + @hook_config(can_jump_to=["end"]) + async def aafter_model( + self, + state: ToolCallLimitState[ResponseT], + runtime: Runtime[ContextT], + ) -> dict[str, Any] | None: + """Async increment tool call counts after a model call and check limits. + + Args: + state: The current agent state. + runtime: The langgraph runtime. + + Returns: + State updates with incremented tool call counts. If limits are exceeded + and exit_behavior is `'end'`, also includes a jump to end with a + `ToolMessage` and AI message for the single exceeded tool call. + + Raises: + ToolCallLimitExceededError: If limits are exceeded and `exit_behavior` + is `'error'`. + NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`, + and there are multiple tool calls. + """ + return self.after_model(state, runtime) diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/bash.py b/libs/partners/anthropic/langchain_anthropic/middleware/bash.py index 61184b1a037..2f8ef0c3135 100644 --- a/libs/partners/anthropic/langchain_anthropic/middleware/bash.py +++ b/libs/partners/anthropic/langchain_anthropic/middleware/bash.py @@ -40,6 +40,18 @@ class ClaudeBashToolMiddleware(ShellToolMiddleware): request = request.override(tools=tools) return handler(request) + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelResponse: + """Async: ensure the Claude bash descriptor is available to the model.""" + tools = request.tools + if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools): + tools = [*tools, _CLAUDE_BASH_DESCRIPTOR] + request = request.override(tools=tools) + return await handler(request) + def wrap_tool_call( self, request: ToolCallRequest, diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/file_search.py b/libs/partners/anthropic/langchain_anthropic/middleware/file_search.py index 977a19727be..a60122e40be 100644 --- a/libs/partners/anthropic/langchain_anthropic/middleware/file_search.py +++ b/libs/partners/anthropic/langchain_anthropic/middleware/file_search.py @@ -9,13 +9,13 @@ from __future__ import annotations import fnmatch import re from pathlib import Path, PurePosixPath -from typing import TYPE_CHECKING, Annotated, Literal, cast +from typing import TYPE_CHECKING, Literal, cast if TYPE_CHECKING: from typing import Any from langchain.agents.middleware.types import AgentMiddleware -from langchain_core.tools import InjectedToolArg, tool +from langchain.tools import ToolRuntime, tool from langchain_anthropic.middleware.anthropic_tools import AnthropicToolsState @@ -128,9 +128,9 @@ class StateFileSearchMiddleware(AgentMiddleware): # Create tool instances @tool def glob_search( # noqa: D417 + runtime: ToolRuntime[None, AnthropicToolsState], pattern: str, path: str = "/", - state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment] ) -> str: """Fast file pattern matching tool that works with any codebase size. @@ -147,56 +147,17 @@ class StateFileSearchMiddleware(AgentMiddleware): time (most recently modified first). Returns "No files found" if no matches. """ - # Normalize base path - base_path = path if path.startswith("/") else "/" + path - - # Get files from state - files = cast("dict[str, Any]", state.get(self.state_key, {})) - - # Match files - matches = [] - for file_path, file_data in files.items(): - if file_path.startswith(base_path): - # Get relative path from base - if base_path == "/": - relative = file_path[1:] # Remove leading / - elif file_path == base_path: - relative = Path(file_path).name - elif file_path.startswith(base_path + "/"): - relative = file_path[len(base_path) + 1 :] - else: - continue - - # Match against pattern - # Handle ** pattern which requires special care - # PurePosixPath.match doesn't match single-level paths - # against **/pattern - is_match = PurePosixPath(relative).match(pattern) - if not is_match and pattern.startswith("**/"): - # Also try matching without the **/ prefix for files in base dir - is_match = PurePosixPath(relative).match(pattern[3:]) - - if is_match: - matches.append((file_path, file_data["modified_at"])) - - if not matches: - return "No files found" - - # Sort by modification time - matches.sort(key=lambda x: x[1], reverse=True) - file_paths = [path for path, _ in matches] - - return "\n".join(file_paths) + return self._handle_glob_search(pattern, path, runtime.state) @tool def grep_search( # noqa: D417 + runtime: ToolRuntime[None, AnthropicToolsState], pattern: str, path: str = "/", include: str | None = None, output_mode: Literal[ "files_with_matches", "content", "count" ] = "files_with_matches", - state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment] ) -> str: """Fast content search tool that works with any codebase size. @@ -216,49 +177,133 @@ class StateFileSearchMiddleware(AgentMiddleware): Search results formatted according to output_mode. Returns "No matches found" if no results. """ - # Normalize base path - base_path = path if path.startswith("/") else "/" + path - - # Compile regex pattern (for validation) - try: - regex = re.compile(pattern) - except re.error as e: - return f"Invalid regex pattern: {e}" - - if include and not _is_valid_include_pattern(include): - return "Invalid include pattern" - - # Search files - files = cast("dict[str, Any]", state.get(self.state_key, {})) - results: dict[str, list[tuple[int, str]]] = {} - - for file_path, file_data in files.items(): - if not file_path.startswith(base_path): - continue - - # Check include filter - if include: - basename = Path(file_path).name - if not _match_include_pattern(basename, include): - continue - - # Search file content - for line_num, line in enumerate(file_data["content"], 1): - if regex.search(line): - if file_path not in results: - results[file_path] = [] - results[file_path].append((line_num, line)) - - if not results: - return "No matches found" - - # Format output based on mode - return self._format_grep_results(results, output_mode) + return self._handle_grep_search( + pattern, path, include, output_mode, runtime.state + ) self.glob_search = glob_search self.grep_search = grep_search self.tools = [glob_search, grep_search] + def _handle_glob_search( + self, + pattern: str, + path: str, + state: AnthropicToolsState, + ) -> str: + """Handle glob search operation. + + Args: + pattern: The glob pattern to match files against. + path: The directory to search in. + state: The current agent state. + + Returns: + Newline-separated list of matching file paths, sorted by modification + time (most recently modified first). Returns "No files found" if no + matches. + """ + # Normalize base path + base_path = path if path.startswith("/") else "/" + path + + # Get files from state + files = cast("dict[str, Any]", state.get(self.state_key, {})) + + # Match files + matches = [] + for file_path, file_data in files.items(): + if file_path.startswith(base_path): + # Get relative path from base + if base_path == "/": + relative = file_path[1:] # Remove leading / + elif file_path == base_path: + relative = Path(file_path).name + elif file_path.startswith(base_path + "/"): + relative = file_path[len(base_path) + 1 :] + else: + continue + + # Match against pattern + # Handle ** pattern which requires special care + # PurePosixPath.match doesn't match single-level paths + # against **/pattern + is_match = PurePosixPath(relative).match(pattern) + if not is_match and pattern.startswith("**/"): + # Also try matching without the **/ prefix for files in base dir + is_match = PurePosixPath(relative).match(pattern[3:]) + + if is_match: + matches.append((file_path, file_data["modified_at"])) + + if not matches: + return "No files found" + + # Sort by modification time + matches.sort(key=lambda x: x[1], reverse=True) + file_paths = [path for path, _ in matches] + + return "\n".join(file_paths) + + def _handle_grep_search( + self, + pattern: str, + path: str, + include: str | None, + output_mode: str, + state: AnthropicToolsState, + ) -> str: + """Handle grep search operation. + + Args: + pattern: The regular expression pattern to search for in file contents. + path: The directory to search in. + include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}"). + output_mode: Output format. + state: The current agent state. + + Returns: + Search results formatted according to output_mode. Returns "No matches + found" if no results. + """ + # Normalize base path + base_path = path if path.startswith("/") else "/" + path + + # Compile regex pattern (for validation) + try: + regex = re.compile(pattern) + except re.error as e: + return f"Invalid regex pattern: {e}" + + if include and not _is_valid_include_pattern(include): + return "Invalid include pattern" + + # Search files + files = cast("dict[str, Any]", state.get(self.state_key, {})) + results: dict[str, list[tuple[int, str]]] = {} + + for file_path, file_data in files.items(): + if not file_path.startswith(base_path): + continue + + # Check include filter + if include: + basename = Path(file_path).name + if not _match_include_pattern(basename, include): + continue + + # Search file content + for line_num, line in enumerate(file_data["content"], 1): + if regex.search(line): + if file_path not in results: + results[file_path] = [] + results[file_path].append((line_num, line)) + + if not results: + return "No matches found" + + # Format output based on mode + return self._format_grep_results(results, output_mode) + def _format_grep_results( self, results: dict[str, list[tuple[int, str]]], diff --git a/libs/partners/anthropic/tests/unit_tests/middleware/test_file_search.py b/libs/partners/anthropic/tests/unit_tests/middleware/test_file_search.py index 395a368034c..1c1319961d7 100644 --- a/libs/partners/anthropic/tests/unit_tests/middleware/test_file_search.py +++ b/libs/partners/anthropic/tests/unit_tests/middleware/test_file_search.py @@ -49,8 +49,10 @@ class TestGlobSearch: }, } - # Call tool function directly (state is injected in real usage) - result = middleware.glob_search.func(pattern="*.py", state=test_state) # type: ignore[attr-defined] + # Call internal handler method directly + result = middleware._handle_glob_search( + pattern="*.py", path="/", state=test_state + ) assert isinstance(result, str) assert "/src/main.py" in result @@ -82,7 +84,9 @@ class TestGlobSearch: }, } - result = middleware.glob_search.func(pattern="**/*.py", state=state) # type: ignore[attr-defined] + result = middleware._handle_glob_search( + pattern="**/*.py", path="/", state=state + ) assert isinstance(result, str) lines = result.split("\n") @@ -109,7 +113,7 @@ class TestGlobSearch: }, } - result = middleware.glob_search.func( # type: ignore[attr-defined] + result = middleware._handle_glob_search( pattern="**/*.py", path="/src", state=state ) @@ -132,7 +136,7 @@ class TestGlobSearch: }, } - result = middleware.glob_search.func(pattern="*.ts", state=state) # type: ignore[attr-defined] + result = middleware._handle_glob_search(pattern="*.ts", path="/", state=state) assert isinstance(result, str) assert result == "No files found" @@ -157,7 +161,7 @@ class TestGlobSearch: }, } - result = middleware.glob_search.func(pattern="*.py", state=state) # type: ignore[attr-defined] + result = middleware._handle_glob_search(pattern="*.py", path="/", state=state) lines = result.split("\n") # Most recent first @@ -193,7 +197,13 @@ class TestGrepSearch: }, } - result = middleware.grep_search.func(pattern=r"def \w+\(\):", state=state) # type: ignore[attr-defined] + result = middleware._handle_grep_search( + pattern=r"def \w+\(\):", + path="/", + include=None, + output_mode="files_with_matches", + state=state, + ) assert isinstance(result, str) assert "/src/main.py" in result @@ -216,8 +226,12 @@ class TestGrepSearch: }, } - result = middleware.grep_search.func( # type: ignore[attr-defined] - pattern=r"def", include="*.{py", state=state + result = middleware._handle_grep_search( + pattern=r"def", + path="/", + include="*.{py", + output_mode="files_with_matches", + state=state, ) assert result == "Invalid include pattern" @@ -241,8 +255,12 @@ class TestFilesystemGrepSearch: }, } - result = middleware.grep_search.func( # type: ignore[attr-defined] - pattern=r"def \w+\(\):", output_mode="content", state=state + result = middleware._handle_grep_search( + pattern=r"def \w+\(\):", + path="/", + include=None, + output_mode="content", + state=state, ) assert isinstance(result, str) @@ -271,8 +289,8 @@ class TestFilesystemGrepSearch: }, } - result = middleware.grep_search.func( # type: ignore[attr-defined] - pattern=r"TODO", output_mode="count", state=state + result = middleware._handle_grep_search( + pattern=r"TODO", path="/", include=None, output_mode="count", state=state ) assert isinstance(result, str) @@ -300,8 +318,12 @@ class TestFilesystemGrepSearch: }, } - result = middleware.grep_search.func( # type: ignore[attr-defined] - pattern="import", include="*.py", state=state + result = middleware._handle_grep_search( + pattern="import", + path="/", + include="*.py", + output_mode="files_with_matches", + state=state, ) assert isinstance(result, str) @@ -333,8 +355,12 @@ class TestFilesystemGrepSearch: }, } - result = middleware.grep_search.func( # type: ignore[attr-defined] - pattern="const", include="*.{ts,tsx}", state=state + result = middleware._handle_grep_search( + pattern="const", + path="/", + include="*.{ts,tsx}", + output_mode="files_with_matches", + state=state, ) assert isinstance(result, str) @@ -362,7 +388,13 @@ class TestFilesystemGrepSearch: }, } - result = middleware.grep_search.func(pattern="import", path="/src", state=state) # type: ignore[attr-defined] + result = middleware._handle_grep_search( + pattern="import", + path="/src", + include=None, + output_mode="files_with_matches", + state=state, + ) assert isinstance(result, str) assert "/src/main.py" in result @@ -383,7 +415,13 @@ class TestFilesystemGrepSearch: }, } - result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined] + result = middleware._handle_grep_search( + pattern=r"TODO", + path="/", + include=None, + output_mode="files_with_matches", + state=state, + ) assert isinstance(result, str) assert result == "No matches found" @@ -397,7 +435,13 @@ class TestFilesystemGrepSearch: "text_editor_files": {}, } - result = middleware.grep_search.func(pattern=r"[unclosed", state=state) # type: ignore[attr-defined] + result = middleware._handle_grep_search( + pattern=r"[unclosed", + path="/", + include=None, + output_mode="files_with_matches", + state=state, + ) assert isinstance(result, str) assert "Invalid regex pattern" in result @@ -428,7 +472,7 @@ class TestSearchWithDifferentBackends: }, } - result = middleware.glob_search.func(pattern="**/*", state=state) # type: ignore[attr-defined] + result = middleware._handle_glob_search(pattern="**/*", path="/", state=state) assert isinstance(result, str) assert "/src/main.py" in result @@ -457,7 +501,13 @@ class TestSearchWithDifferentBackends: }, } - result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined] + result = middleware._handle_grep_search( + pattern=r"TODO", + path="/", + include=None, + output_mode="files_with_matches", + state=state, + ) assert isinstance(result, str) assert "/src/main.py" in result @@ -486,7 +536,13 @@ class TestSearchWithDifferentBackends: }, } - result = middleware.grep_search.func(pattern=r".*", state=state) # type: ignore[attr-defined] + result = middleware._handle_grep_search( + pattern=r".*", + path="/", + include=None, + output_mode="files_with_matches", + state=state, + ) assert isinstance(result, str) assert "/src/main.py" in result