fix(anthropic): execute bash + file tools via tool node (#33960)

* use `override` instead of directly patching things on `ModelRequest`
* rely on `ToolNode` for execution of tools related to said middleware,
using `wrap_model_call` to inject the relevant claude tool specs +
allowing tool node to forward them along to corresponding langchain tool
implementations
* making the same change for the native shell tool middleware
* allowing shell tool middleware to specify a name for the shell tool
(negative diff then for claude bash middleware)


long term I think the solution might be to attach metadata to a tool to
map the provider spec to a langchain implementation, which we could also
take some lessons from on the MCP front.
This commit is contained in:
Sydney Runkle
2025-11-14 13:17:01 -05:00
committed by GitHub
parent d2942351ce
commit 1bc88028e6
4 changed files with 312 additions and 412 deletions

View File

@@ -17,7 +17,9 @@ from langchain.agents.middleware.types import (
AgentState,
ModelRequest,
ModelResponse,
_ModelRequestOverrides,
)
from langchain.tools import ToolRuntime, tool
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from typing_extensions import NotRequired, TypedDict
@@ -25,7 +27,6 @@ from typing_extensions import NotRequired, TypedDict
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Sequence
from langchain.agents.middleware.types import ToolCallRequest
# Tool type constants
TEXT_EDITOR_TOOL_TYPE = "text_editor_20250728"
@@ -184,149 +185,127 @@ class _StateClaudeFileToolMiddleware(AgentMiddleware):
self.allowed_prefixes = allowed_path_prefixes
self.system_prompt = system_prompt
# Create tool that will be executed by the tool node
@tool(tool_name)
def file_tool(
runtime: ToolRuntime[None, AnthropicToolsState],
command: str,
path: str,
file_text: str | None = None,
old_str: str | None = None,
new_str: str | None = None,
insert_line: int | None = None,
new_path: str | None = None,
view_range: list[int] | None = None,
) -> Command | str:
"""Execute file operations on virtual file system.
Args:
runtime: Tool runtime providing access to state.
command: Operation to perform.
path: File path to operate on.
file_text: Full file content for create command.
old_str: String to replace for str_replace command.
new_str: Replacement string for str_replace command.
insert_line: Line number for insert command.
new_path: New path for rename command.
view_range: Line range [start, end] for view command.
Returns:
Command for state update or string result.
"""
# Build args dict for handler methods
args: dict[str, Any] = {"path": path}
if file_text is not None:
args["file_text"] = file_text
if old_str is not None:
args["old_str"] = old_str
if new_str is not None:
args["new_str"] = new_str
if insert_line is not None:
args["insert_line"] = insert_line
if new_path is not None:
args["new_path"] = new_path
if view_range is not None:
args["view_range"] = view_range
# Route to appropriate handler based on command
try:
if command == "view":
return self._handle_view(args, runtime.state, runtime.tool_call_id)
if command == "create":
return self._handle_create(
args, runtime.state, runtime.tool_call_id
)
if command == "str_replace":
return self._handle_str_replace(
args, runtime.state, runtime.tool_call_id
)
if command == "insert":
return self._handle_insert(
args, runtime.state, runtime.tool_call_id
)
if command == "delete":
return self._handle_delete(
args, runtime.state, runtime.tool_call_id
)
if command == "rename":
return self._handle_rename(
args, runtime.state, runtime.tool_call_id
)
return f"Unknown command: {command}"
except (ValueError, FileNotFoundError) as e:
return str(e)
self.tools = [file_tool]
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
"""Inject tool and optional system prompt."""
# Add tool
tools = list(request.tools or [])
tools.append(
{
"type": self.tool_type,
"name": self.tool_name,
}
)
request.tools = tools
"""Inject Anthropic tool descriptor and optional system prompt."""
# Replace our BaseTool with Anthropic's native tool descriptor
tools = [
t
for t in (request.tools or [])
if getattr(t, "name", None) != self.tool_name
] + [{"type": self.tool_type, "name": self.tool_name}]
# Inject system prompt if provided
overrides: _ModelRequestOverrides = {"tools": tools}
if self.system_prompt:
request.system_prompt = (
overrides["system_prompt"] = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return handler(request)
return handler(request.override(**overrides))
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse:
"""Inject tool and optional system prompt (async version)."""
# Add tool
tools = list(request.tools or [])
tools.append(
{
"type": self.tool_type,
"name": self.tool_name,
}
)
request.tools = tools
"""Inject Anthropic tool descriptor and optional system prompt."""
# Replace our BaseTool with Anthropic's native tool descriptor
tools = [
t
for t in (request.tools or [])
if getattr(t, "name", None) != self.tool_name
] + [{"type": self.tool_type, "name": self.tool_name}]
# Inject system prompt if provided
overrides: _ModelRequestOverrides = {"tools": tools}
if self.system_prompt:
request.system_prompt = (
overrides["system_prompt"] = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return await handler(request)
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept tool calls."""
tool_call = request.tool_call
tool_name = tool_call.get("name")
if tool_name != self.tool_name:
return handler(request)
# Handle tool call
try:
args = tool_call.get("args", {})
command = args.get("command")
state = request.state
if command == "view":
return self._handle_view(args, state, tool_call["id"])
if command == "create":
return self._handle_create(args, state, tool_call["id"])
if command == "str_replace":
return self._handle_str_replace(args, state, tool_call["id"])
if command == "insert":
return self._handle_insert(args, state, tool_call["id"])
if command == "delete":
return self._handle_delete(args, state, tool_call["id"])
if command == "rename":
return self._handle_rename(args, state, tool_call["id"])
msg = f"Unknown command: {command}"
return ToolMessage(
content=msg,
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
except (ValueError, FileNotFoundError) as e:
return ToolMessage(
content=str(e),
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Intercept tool calls (async version)."""
tool_call = request.tool_call
tool_name = tool_call.get("name")
if tool_name != self.tool_name:
return await handler(request)
# Handle tool call
try:
args = tool_call.get("args", {})
command = args.get("command")
state = request.state
if command == "view":
return self._handle_view(args, state, tool_call["id"])
if command == "create":
return self._handle_create(args, state, tool_call["id"])
if command == "str_replace":
return self._handle_str_replace(args, state, tool_call["id"])
if command == "insert":
return self._handle_insert(args, state, tool_call["id"])
if command == "delete":
return self._handle_delete(args, state, tool_call["id"])
if command == "rename":
return self._handle_rename(args, state, tool_call["id"])
msg = f"Unknown command: {command}"
return ToolMessage(
content=msg,
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
except (ValueError, FileNotFoundError) as e:
return ToolMessage(
content=str(e),
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
return await handler(request.override(**overrides))
def _handle_view(
self, args: dict, state: AnthropicToolsState, tool_call_id: str | None
@@ -692,146 +671,117 @@ class _FilesystemClaudeFileToolMiddleware(AgentMiddleware):
# Create root directory if it doesn't exist
self.root_path.mkdir(parents=True, exist_ok=True)
# Create tool that will be executed by the tool node
@tool(tool_name)
def file_tool(
runtime: ToolRuntime,
command: str,
path: str,
file_text: str | None = None,
old_str: str | None = None,
new_str: str | None = None,
insert_line: int | None = None,
new_path: str | None = None,
view_range: list[int] | None = None,
) -> Command | str:
"""Execute file operations on filesystem.
Args:
runtime: Tool runtime providing tool_call_id.
command: Operation to perform.
path: File path to operate on.
file_text: Full file content for create command.
old_str: String to replace for str_replace command.
new_str: Replacement string for str_replace command.
insert_line: Line number for insert command.
new_path: New path for rename command.
view_range: Line range [start, end] for view command.
Returns:
Command for message update or string result.
"""
# Build args dict for handler methods
args: dict[str, Any] = {"path": path}
if file_text is not None:
args["file_text"] = file_text
if old_str is not None:
args["old_str"] = old_str
if new_str is not None:
args["new_str"] = new_str
if insert_line is not None:
args["insert_line"] = insert_line
if new_path is not None:
args["new_path"] = new_path
if view_range is not None:
args["view_range"] = view_range
# Route to appropriate handler based on command
try:
if command == "view":
return self._handle_view(args, runtime.tool_call_id)
if command == "create":
return self._handle_create(args, runtime.tool_call_id)
if command == "str_replace":
return self._handle_str_replace(args, runtime.tool_call_id)
if command == "insert":
return self._handle_insert(args, runtime.tool_call_id)
if command == "delete":
return self._handle_delete(args, runtime.tool_call_id)
if command == "rename":
return self._handle_rename(args, runtime.tool_call_id)
return f"Unknown command: {command}"
except (ValueError, FileNotFoundError, PermissionError) as e:
return str(e)
self.tools = [file_tool]
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
"""Inject tool and optional system prompt."""
# Add tool
tools = list(request.tools or [])
tools.append(
{
"type": self.tool_type,
"name": self.tool_name,
}
)
request.tools = tools
"""Inject Anthropic tool descriptor and optional system prompt."""
# Replace our BaseTool with Anthropic's native tool descriptor
tools = [
t
for t in (request.tools or [])
if getattr(t, "name", None) != self.tool_name
] + [{"type": self.tool_type, "name": self.tool_name}]
# Inject system prompt if provided
overrides: _ModelRequestOverrides = {"tools": tools}
if self.system_prompt:
request.system_prompt = (
overrides["system_prompt"] = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return handler(request)
return handler(request.override(**overrides))
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse:
"""Inject tool and optional system prompt (async version)."""
# Add tool
tools = list(request.tools or [])
tools.append(
{
"type": self.tool_type,
"name": self.tool_name,
}
)
request.tools = tools
"""Inject Anthropic tool descriptor and optional system prompt."""
# Replace our BaseTool with Anthropic's native tool descriptor
tools = [
t
for t in (request.tools or [])
if getattr(t, "name", None) != self.tool_name
] + [{"type": self.tool_type, "name": self.tool_name}]
# Inject system prompt if provided
overrides: _ModelRequestOverrides = {"tools": tools}
if self.system_prompt:
request.system_prompt = (
overrides["system_prompt"] = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return await handler(request)
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept tool calls."""
tool_call = request.tool_call
tool_name = tool_call.get("name")
if tool_name != self.tool_name:
return handler(request)
# Handle tool call
try:
args = tool_call.get("args", {})
command = args.get("command")
if command == "view":
return self._handle_view(args, tool_call["id"])
if command == "create":
return self._handle_create(args, tool_call["id"])
if command == "str_replace":
return self._handle_str_replace(args, tool_call["id"])
if command == "insert":
return self._handle_insert(args, tool_call["id"])
if command == "delete":
return self._handle_delete(args, tool_call["id"])
if command == "rename":
return self._handle_rename(args, tool_call["id"])
msg = f"Unknown command: {command}"
return ToolMessage(
content=msg,
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
except (ValueError, FileNotFoundError) as e:
return ToolMessage(
content=str(e),
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Intercept tool calls (async version)."""
tool_call = request.tool_call
tool_name = tool_call.get("name")
if tool_name != self.tool_name:
return await handler(request)
# Handle tool call
try:
args = tool_call.get("args", {})
command = args.get("command")
if command == "view":
return self._handle_view(args, tool_call["id"])
if command == "create":
return self._handle_create(args, tool_call["id"])
if command == "str_replace":
return self._handle_str_replace(args, tool_call["id"])
if command == "insert":
return self._handle_insert(args, tool_call["id"])
if command == "delete":
return self._handle_delete(args, tool_call["id"])
if command == "rename":
return self._handle_rename(args, tool_call["id"])
msg = f"Unknown command: {command}"
return ToolMessage(
content=msg,
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
except (ValueError, FileNotFoundError) as e:
return ToolMessage(
content=str(e),
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
return await handler(request.override(**overrides))
def _validate_and_resolve_path(self, path: str) -> Path:
"""Validate and resolve a virtual path to filesystem path.

View File

@@ -3,105 +3,81 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Any, Literal
from typing import Any
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
from langchain.agents.middleware.types import (
ModelRequest,
ModelResponse,
ToolCallRequest,
)
from langchain_core.messages import ToolMessage
from langgraph.types import Command
_CLAUDE_BASH_DESCRIPTOR = {"type": "bash_20250124", "name": "bash"}
# Tool type constants for Anthropic
BASH_TOOL_TYPE = "bash_20250124"
BASH_TOOL_NAME = "bash"
class ClaudeBashToolMiddleware(ShellToolMiddleware):
"""Middleware that exposes Anthropic's native bash tool to models."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize middleware without registering a client-side tool."""
kwargs["shell_command"] = ("/bin/bash",)
super().__init__(*args, **kwargs)
# Remove the base tool so Claude's native descriptor is the sole entry.
self._tool = None # type: ignore[assignment]
self.tools = []
def __init__(
self,
workspace_root: str | None = None,
*,
startup_commands: tuple[str, ...] | list[str] | str | None = None,
shutdown_commands: tuple[str, ...] | list[str] | str | None = None,
execution_policy: Any | None = None,
redaction_rules: tuple[Any, ...] | list[Any] | None = None,
tool_description: str | None = None,
env: dict[str, Any] | None = None,
) -> None:
"""Initialize middleware for Claude's native bash tool.
Args:
workspace_root: Base directory for the shell session.
If omitted, a temporary directory is created.
startup_commands: Optional commands executed after the session starts.
shutdown_commands: Optional commands executed before session shutdown.
execution_policy: Execution policy controlling timeouts and limits.
redaction_rules: Optional redaction rules to sanitize output.
tool_description: Optional override for tool description.
env: Optional environment variables for the shell session.
"""
super().__init__(
workspace_root=workspace_root,
startup_commands=startup_commands,
shutdown_commands=shutdown_commands,
execution_policy=execution_policy,
redaction_rules=redaction_rules,
tool_description=tool_description,
tool_name=BASH_TOOL_NAME,
shell_command=("/bin/bash",),
env=env,
)
# Parent class now creates the tool with name "bash" via tool_name parameter
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
"""Ensure the Claude bash descriptor is available to the model."""
tools = request.tools
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools):
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR]
request = request.override(tools=tools)
return handler(request)
"""Replace parent's shell tool with Claude's bash descriptor."""
filtered = [
t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME
]
tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}]
return handler(request.override(tools=tools))
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse:
"""Async: ensure the Claude bash descriptor is available to the model."""
tools = request.tools
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools):
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR]
request = request.override(tools=tools)
return await handler(request)
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Command | ToolMessage],
) -> Command | ToolMessage:
"""Intercept Claude bash tool calls and execute them locally."""
tool_call = request.tool_call
if tool_call.get("name") != "bash":
return handler(request)
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
tool_call["args"],
tool_call_id=tool_call.get("id"),
)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[Command | ToolMessage]],
) -> Command | ToolMessage:
"""Async interception mirroring the synchronous implementation."""
tool_call = request.tool_call
if tool_call.get("name") != "bash":
return await handler(request)
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
tool_call["args"],
tool_call_id=tool_call.get("id"),
)
def _format_tool_message(
self,
content: str,
tool_call_id: str | None,
*,
status: Literal["success", "error"],
artifact: dict[str, Any] | None = None,
) -> ToolMessage | str:
"""Format tool responses using Claude's bash descriptor."""
if tool_call_id is None:
return content
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=_CLAUDE_BASH_DESCRIPTOR["name"],
status=status,
artifact=artifact or {},
)
"""Async: replace parent's shell tool with Claude's bash descriptor."""
filtered = [
t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME
]
tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}]
return await handler(request.override(tools=tools))
__all__ = ["ClaudeBashToolMiddleware"]

View File

@@ -3,67 +3,59 @@ from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from langchain_core.messages.tool import ToolCall
pytest.importorskip(
"anthropic", reason="Anthropic SDK is required for Claude middleware tests"
)
from langchain.agents.middleware.types import ToolCallRequest
from langchain_core.messages import ToolMessage
from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware
def test_wrap_tool_call_handles_claude_bash(monkeypatch: pytest.MonkeyPatch) -> None:
def test_creates_bash_tool(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that ClaudeBashToolMiddleware creates a tool named 'bash'."""
middleware = ClaudeBashToolMiddleware()
sentinel = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
monkeypatch.setattr(middleware, "_run_shell_tool", MagicMock(return_value=sentinel))
monkeypatch.setattr(
middleware, "_ensure_resources", MagicMock(return_value=MagicMock())
# Should have exactly one tool registered (from parent)
assert len(middleware.tools) == 1
# Tool is named "bash" (via tool_name parameter)
bash_tool = middleware.tools[0]
assert bash_tool.name == "bash"
def test_replaces_tool_with_claude_descriptor() -> None:
"""Test wrap_model_call replaces bash tool with Claude's bash descriptor."""
from langchain.agents.middleware.types import ModelRequest
middleware = ClaudeBashToolMiddleware()
# Create a mock request with the bash tool (inherited from parent)
bash_tool = middleware.tools[0]
request = ModelRequest(
model=MagicMock(),
system_prompt=None,
messages=[],
tool_choice=None,
tools=[bash_tool],
response_format=None,
state={"messages": []},
runtime=MagicMock(),
)
tool_call: ToolCall = {
# Mock handler that captures the modified request
captured_request = None
def handler(req: ModelRequest) -> MagicMock:
nonlocal captured_request
captured_request = req
return MagicMock()
middleware.wrap_model_call(request, handler)
# The bash tool should be replaced with Claude's native bash descriptor
assert captured_request is not None
assert len(captured_request.tools) == 1
assert captured_request.tools[0] == {
"type": "bash_20250124",
"name": "bash",
"args": {"command": "echo hi"},
"id": "call-1",
}
request = ToolCallRequest(
tool_call=tool_call,
tool=MagicMock(),
state={},
runtime=None, # type: ignore[arg-type]
)
handler_called = False
def handler(_: ToolCallRequest) -> ToolMessage:
nonlocal handler_called
handler_called = True
return ToolMessage(content="should not be used", tool_call_id="call-1")
result = middleware.wrap_tool_call(request, handler)
assert result is sentinel
assert handler_called is False
def test_wrap_tool_call_passes_through_other_tools(
monkeypatch: pytest.MonkeyPatch,
) -> None:
middleware = ClaudeBashToolMiddleware()
tool_call: ToolCall = {"name": "other", "args": {}, "id": "call-2"}
request = ToolCallRequest(
tool_call=tool_call,
tool=MagicMock(),
state={},
runtime=None, # type: ignore[arg-type]
)
sentinel = ToolMessage(content="handled", tool_call_id="call-2", name="other")
def handler(_: ToolCallRequest) -> ToolMessage:
return sentinel
result = middleware.wrap_tool_call(request, handler)
assert result is sentinel