mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix: adding missing async hooks (#33957)
* filling in missing async gaps * using recommended tool runtime injection instead of injected state * updating tests to use helper function as well
This commit is contained in:
@@ -353,3 +353,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
last_ai_msg.tool_calls = revised_tool_calls
|
||||
|
||||
return {"messages": [last_ai_msg, *artificial_tool_messages]}
|
||||
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@@ -198,6 +198,29 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
||||
|
||||
return None
|
||||
|
||||
@hook_config(can_jump_to=["end"])
|
||||
async def abefore_model(
|
||||
self,
|
||||
state: ModelCallLimitState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async check model call limits before making a model call.
|
||||
|
||||
Args:
|
||||
state: The current agent state containing call counts.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
If limits are exceeded and exit_behavior is `'end'`, returns
|
||||
a `Command` to jump to the end with a limit exceeded message. Otherwise
|
||||
returns `None`.
|
||||
|
||||
Raises:
|
||||
ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
|
||||
is `'error'`.
|
||||
"""
|
||||
return self.before_model(state, runtime)
|
||||
|
||||
def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Increment model call counts after a model call.
|
||||
|
||||
@@ -212,3 +235,19 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
||||
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
|
||||
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
|
||||
}
|
||||
|
||||
async def aafter_model(
|
||||
self,
|
||||
state: ModelCallLimitState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async increment model call counts after a model call.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
State updates with incremented call counts.
|
||||
"""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@@ -252,6 +252,27 @@ class PIIMiddleware(AgentMiddleware):
|
||||
|
||||
return None
|
||||
|
||||
@hook_config(can_jump_to=["end"])
|
||||
async def abefore_model(
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async check user messages and tool results for PII before model invocation.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
Updated state with PII handled according to strategy, or `None` if no PII
|
||||
detected.
|
||||
|
||||
Raises:
|
||||
PIIDetectionError: If PII is detected and strategy is `'block'`.
|
||||
"""
|
||||
return self.before_model(state, runtime)
|
||||
|
||||
def after_model(
|
||||
self,
|
||||
state: AgentState,
|
||||
@@ -311,6 +332,26 @@ class PIIMiddleware(AgentMiddleware):
|
||||
|
||||
return {"messages": new_messages}
|
||||
|
||||
async def aafter_model(
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async check AI messages for PII after model invocation.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
Updated state with PII handled according to strategy, or None if no PII
|
||||
detected.
|
||||
|
||||
Raises:
|
||||
PIIDetectionError: If PII is detected and strategy is `'block'`.
|
||||
"""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PIIDetectionError",
|
||||
|
||||
@@ -482,7 +482,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
return {"shell_session_resources": resources}
|
||||
|
||||
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Async counterpart to `before_agent`."""
|
||||
"""Async start the shell session and run startup commands."""
|
||||
return self.before_agent(state, runtime)
|
||||
|
||||
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
|
||||
@@ -494,7 +494,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
resources._finalizer()
|
||||
|
||||
async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
|
||||
"""Async counterpart to `after_agent`."""
|
||||
"""Async run shutdown commands and release resources when an agent completes."""
|
||||
return self.after_agent(state, runtime)
|
||||
|
||||
def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
|
||||
@@ -689,7 +689,8 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
request: ToolCallRequest,
|
||||
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
"""Async interception mirroring the synchronous tool handler."""
|
||||
"""Async intercept local shell tool calls and execute them via the managed session."""
|
||||
# The sync version already handles all the work, no need for async-specific logic
|
||||
if isinstance(request.tool, _PersistentShellTool):
|
||||
resources = self._ensure_resources(request.state)
|
||||
return self._run_shell_tool(
|
||||
|
||||
@@ -451,3 +451,28 @@ class ToolCallLimitMiddleware(
|
||||
"run_tool_call_count": run_counts,
|
||||
"messages": artificial_messages,
|
||||
}
|
||||
|
||||
@hook_config(can_jump_to=["end"])
|
||||
async def aafter_model(
|
||||
self,
|
||||
state: ToolCallLimitState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async increment tool call counts after a model call and check limits.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
State updates with incremented tool call counts. If limits are exceeded
|
||||
and exit_behavior is `'end'`, also includes a jump to end with a
|
||||
`ToolMessage` and AI message for the single exceeded tool call.
|
||||
|
||||
Raises:
|
||||
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
|
||||
is `'error'`.
|
||||
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
|
||||
and there are multiple tool calls.
|
||||
"""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@@ -40,6 +40,18 @@ class ClaudeBashToolMiddleware(ShellToolMiddleware):
|
||||
request = request.override(tools=tools)
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
"""Async: ensure the Claude bash descriptor is available to the model."""
|
||||
tools = request.tools
|
||||
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools):
|
||||
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR]
|
||||
request = request.override(tools=tools)
|
||||
return await handler(request)
|
||||
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
|
||||
@@ -9,13 +9,13 @@ from __future__ import annotations
|
||||
import fnmatch
|
||||
import re
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import TYPE_CHECKING, Annotated, Literal, cast
|
||||
from typing import TYPE_CHECKING, Literal, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware
|
||||
from langchain_core.tools import InjectedToolArg, tool
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
|
||||
from langchain_anthropic.middleware.anthropic_tools import AnthropicToolsState
|
||||
|
||||
@@ -128,9 +128,9 @@ class StateFileSearchMiddleware(AgentMiddleware):
|
||||
# Create tool instances
|
||||
@tool
|
||||
def glob_search( # noqa: D417
|
||||
runtime: ToolRuntime[None, AnthropicToolsState],
|
||||
pattern: str,
|
||||
path: str = "/",
|
||||
state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment]
|
||||
) -> str:
|
||||
"""Fast file pattern matching tool that works with any codebase size.
|
||||
|
||||
@@ -142,6 +142,62 @@ class StateFileSearchMiddleware(AgentMiddleware):
|
||||
pattern: The glob pattern to match files against.
|
||||
path: The directory to search in. If not specified, searches from root.
|
||||
|
||||
Returns:
|
||||
Newline-separated list of matching file paths, sorted by modification
|
||||
time (most recently modified first). Returns "No files found" if no
|
||||
matches.
|
||||
"""
|
||||
return self._handle_glob_search(pattern, path, runtime.state)
|
||||
|
||||
@tool
|
||||
def grep_search( # noqa: D417
|
||||
runtime: ToolRuntime[None, AnthropicToolsState],
|
||||
pattern: str,
|
||||
path: str = "/",
|
||||
include: str | None = None,
|
||||
output_mode: Literal[
|
||||
"files_with_matches", "content", "count"
|
||||
] = "files_with_matches",
|
||||
) -> str:
|
||||
"""Fast content search tool that works with any codebase size.
|
||||
|
||||
Searches file contents using regular expressions. Supports full regex
|
||||
syntax and filters files by pattern with the include parameter.
|
||||
|
||||
Args:
|
||||
pattern: The regular expression pattern to search for in file contents.
|
||||
path: The directory to search in. If not specified, searches from root.
|
||||
include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
|
||||
output_mode: Output format:
|
||||
- "files_with_matches": Only file paths containing matches (default)
|
||||
- "content": Matching lines with file:line:content format
|
||||
- "count": Count of matches per file
|
||||
|
||||
Returns:
|
||||
Search results formatted according to output_mode. Returns "No matches
|
||||
found" if no results.
|
||||
"""
|
||||
return self._handle_grep_search(
|
||||
pattern, path, include, output_mode, runtime.state
|
||||
)
|
||||
|
||||
self.glob_search = glob_search
|
||||
self.grep_search = grep_search
|
||||
self.tools = [glob_search, grep_search]
|
||||
|
||||
def _handle_glob_search(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str,
|
||||
state: AnthropicToolsState,
|
||||
) -> str:
|
||||
"""Handle glob search operation.
|
||||
|
||||
Args:
|
||||
pattern: The glob pattern to match files against.
|
||||
path: The directory to search in.
|
||||
state: The current agent state.
|
||||
|
||||
Returns:
|
||||
Newline-separated list of matching file paths, sorted by modification
|
||||
time (most recently modified first). Returns "No files found" if no
|
||||
@@ -188,29 +244,22 @@ class StateFileSearchMiddleware(AgentMiddleware):
|
||||
|
||||
return "\n".join(file_paths)
|
||||
|
||||
@tool
|
||||
def grep_search( # noqa: D417
|
||||
def _handle_grep_search(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str = "/",
|
||||
include: str | None = None,
|
||||
output_mode: Literal[
|
||||
"files_with_matches", "content", "count"
|
||||
] = "files_with_matches",
|
||||
state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment]
|
||||
path: str,
|
||||
include: str | None,
|
||||
output_mode: str,
|
||||
state: AnthropicToolsState,
|
||||
) -> str:
|
||||
"""Fast content search tool that works with any codebase size.
|
||||
|
||||
Searches file contents using regular expressions. Supports full regex
|
||||
syntax and filters files by pattern with the include parameter.
|
||||
"""Handle grep search operation.
|
||||
|
||||
Args:
|
||||
pattern: The regular expression pattern to search for in file contents.
|
||||
path: The directory to search in. If not specified, searches from root.
|
||||
path: The directory to search in.
|
||||
include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
|
||||
output_mode: Output format:
|
||||
- "files_with_matches": Only file paths containing matches (default)
|
||||
- "content": Matching lines with file:line:content format
|
||||
- "count": Count of matches per file
|
||||
output_mode: Output format.
|
||||
state: The current agent state.
|
||||
|
||||
Returns:
|
||||
Search results formatted according to output_mode. Returns "No matches
|
||||
@@ -255,10 +304,6 @@ class StateFileSearchMiddleware(AgentMiddleware):
|
||||
# Format output based on mode
|
||||
return self._format_grep_results(results, output_mode)
|
||||
|
||||
self.glob_search = glob_search
|
||||
self.grep_search = grep_search
|
||||
self.tools = [glob_search, grep_search]
|
||||
|
||||
def _format_grep_results(
|
||||
self,
|
||||
results: dict[str, list[tuple[int, str]]],
|
||||
|
||||
@@ -49,8 +49,10 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
# Call tool function directly (state is injected in real usage)
|
||||
result = middleware.glob_search.func(pattern="*.py", state=test_state) # type: ignore[attr-defined]
|
||||
# Call internal handler method directly
|
||||
result = middleware._handle_glob_search(
|
||||
pattern="*.py", path="/", state=test_state
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -82,7 +84,9 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func(pattern="**/*.py", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(
|
||||
pattern="**/*.py", path="/", state=state
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
lines = result.split("\n")
|
||||
@@ -109,7 +113,7 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func( # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(
|
||||
pattern="**/*.py", path="/src", state=state
|
||||
)
|
||||
|
||||
@@ -132,7 +136,7 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func(pattern="*.ts", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(pattern="*.ts", path="/", state=state)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == "No files found"
|
||||
@@ -157,7 +161,7 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func(pattern="*.py", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(pattern="*.py", path="/", state=state)
|
||||
|
||||
lines = result.split("\n")
|
||||
# Most recent first
|
||||
@@ -193,7 +197,13 @@ class TestGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r"def \w+\(\):", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"def \w+\(\):",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -216,8 +226,12 @@ class TestGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern=r"def", include="*.{py", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"def",
|
||||
path="/",
|
||||
include="*.{py",
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert result == "Invalid include pattern"
|
||||
@@ -241,8 +255,12 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern=r"def \w+\(\):", output_mode="content", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"def \w+\(\):",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="content",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
@@ -271,8 +289,8 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern=r"TODO", output_mode="count", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"TODO", path="/", include=None, output_mode="count", state=state
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
@@ -300,8 +318,12 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern="import", include="*.py", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern="import",
|
||||
path="/",
|
||||
include="*.py",
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
@@ -333,8 +355,12 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern="const", include="*.{ts,tsx}", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern="const",
|
||||
path="/",
|
||||
include="*.{ts,tsx}",
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
@@ -362,7 +388,13 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern="import", path="/src", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern="import",
|
||||
path="/src",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -383,7 +415,13 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"TODO",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == "No matches found"
|
||||
@@ -397,7 +435,13 @@ class TestFilesystemGrepSearch:
|
||||
"text_editor_files": {},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r"[unclosed", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"[unclosed",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "Invalid regex pattern" in result
|
||||
@@ -428,7 +472,7 @@ class TestSearchWithDifferentBackends:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func(pattern="**/*", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(pattern="**/*", path="/", state=state)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -457,7 +501,13 @@ class TestSearchWithDifferentBackends:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"TODO",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -486,7 +536,13 @@ class TestSearchWithDifferentBackends:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r".*", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r".*",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
|
||||
Reference in New Issue
Block a user