fix(anthropic): execute bash + file tools via tool node (#33960)

* use `override` instead of directly patching things on `ModelRequest`
* rely on `ToolNode` for execution of tools related to said middleware,
using `wrap_model_call` to inject the relevant claude tool specs +
allowing tool node to forward them along to corresponding langchain tool
implementations
* making the same change for the native shell tool middleware
* allowing shell tool middleware to specify a name for the shell tool
(negative diff then for claude bash middleware)


long term I think the solution might be to attach metadata to a tool to
map the provider spec to a langchain implementation, which we could also
take some lessons from on the MCP front.
This commit is contained in:
Sydney Runkle
2025-11-14 13:17:01 -05:00
committed by GitHub
parent d2942351ce
commit 1bc88028e6
4 changed files with 312 additions and 412 deletions

View File

@@ -11,7 +11,6 @@ import subprocess
import tempfile import tempfile
import threading import threading
import time import time
import typing
import uuid import uuid
import weakref import weakref
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -19,9 +18,10 @@ from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal from typing import TYPE_CHECKING, Annotated, Any, Literal
from langchain_core.messages import ToolMessage from langchain_core.messages import ToolMessage
from langchain_core.tools.base import BaseTool, ToolException from langchain_core.tools.base import ToolException
from langgraph.channels.untracked_value import UntrackedValue from langgraph.channels.untracked_value import UntrackedValue
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from pydantic.json_schema import SkipJsonSchema
from typing_extensions import NotRequired from typing_extensions import NotRequired
from langchain.agents.middleware._execution import ( from langchain.agents.middleware._execution import (
@@ -38,14 +38,13 @@ from langchain.agents.middleware._redaction import (
ResolvedRedactionRule, ResolvedRedactionRule,
) )
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
from langchain.tools import ToolRuntime, tool
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from langgraph.types import Command
from langchain.agents.middleware.types import ToolCallRequest
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__" _DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
@@ -59,6 +58,7 @@ DEFAULT_TOOL_DESCRIPTION = (
"session remains stable. Outputs may be truncated when they become very large, and long " "session remains stable. Outputs may be truncated when they become very large, and long "
"running commands will be terminated once their configured timeout elapses." "running commands will be terminated once their configured timeout elapses."
) )
SHELL_TOOL_NAME = "shell"
def _cleanup_resources( def _cleanup_resources(
@@ -334,7 +334,17 @@ class _ShellToolInput(BaseModel):
"""Input schema for the persistent shell tool.""" """Input schema for the persistent shell tool."""
command: str | None = None command: str | None = None
"""The shell command to execute."""
restart: bool | None = None restart: bool | None = None
"""Whether to restart the shell session."""
runtime: Annotated[Any, SkipJsonSchema] = None
"""The runtime for the shell tool.
Included as a workaround at the moment bc args_schema doesn't work with
injected ToolRuntime.
"""
@model_validator(mode="after") @model_validator(mode="after")
def validate_payload(self) -> _ShellToolInput: def validate_payload(self) -> _ShellToolInput:
@@ -347,24 +357,6 @@ class _ShellToolInput(BaseModel):
return self return self
class _PersistentShellTool(BaseTool):
"""Tool wrapper that relies on middleware interception for execution."""
name: str = "shell"
description: str = DEFAULT_TOOL_DESCRIPTION
args_schema: type[BaseModel] = _ShellToolInput
def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
super().__init__()
self._middleware = middleware
if description is not None:
self.description = description
def _run(self, **_: Any) -> Any: # pragma: no cover - executed via middleware wrapper
msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
raise RuntimeError(msg)
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
"""Middleware that registers a persistent shell tool for agents. """Middleware that registers a persistent shell tool for agents.
@@ -393,6 +385,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
execution_policy: BaseExecutionPolicy | None = None, execution_policy: BaseExecutionPolicy | None = None,
redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None, redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
tool_description: str | None = None, tool_description: str | None = None,
tool_name: str = SHELL_TOOL_NAME,
shell_command: Sequence[str] | str | None = None, shell_command: Sequence[str] | str | None = None,
env: Mapping[str, Any] | None = None, env: Mapping[str, Any] | None = None,
) -> None: ) -> None:
@@ -414,6 +407,9 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
returning it to the model. returning it to the model.
tool_description: Optional override for the registered shell tool tool_description: Optional override for the registered shell tool
description. description.
tool_name: Name for the registered shell tool.
Defaults to `"shell"`.
shell_command: Optional shell executable (string) or argument sequence used shell_command: Optional shell executable (string) or argument sequence used
to launch the persistent session. to launch the persistent session.
@@ -425,6 +421,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
""" """
super().__init__() super().__init__()
self._workspace_root = Path(workspace_root) if workspace_root else None self._workspace_root = Path(workspace_root) if workspace_root else None
self._tool_name = tool_name
self._shell_command = self._normalize_shell_command(shell_command) self._shell_command = self._normalize_shell_command(shell_command)
self._environment = self._normalize_env(env) self._environment = self._normalize_env(env)
if execution_policy is not None: if execution_policy is not None:
@@ -438,9 +435,25 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
self._startup_commands = self._normalize_commands(startup_commands) self._startup_commands = self._normalize_commands(startup_commands)
self._shutdown_commands = self._normalize_commands(shutdown_commands) self._shutdown_commands = self._normalize_commands(shutdown_commands)
# Create a proper tool that executes directly (no interception needed)
description = tool_description or DEFAULT_TOOL_DESCRIPTION description = tool_description or DEFAULT_TOOL_DESCRIPTION
self._tool = _PersistentShellTool(self, description=description)
self.tools = [self._tool] @tool(self._tool_name, args_schema=_ShellToolInput, description=description)
def shell_tool(
*,
runtime: ToolRuntime[None, ShellToolState],
command: str | None = None,
restart: bool = False,
) -> ToolMessage | str:
resources = self._ensure_resources(runtime.state)
return self._run_shell_tool(
resources,
{"command": command, "restart": restart},
tool_call_id=runtime.tool_call_id,
)
self._shell_tool = shell_tool
self.tools = [self._shell_tool]
@staticmethod @staticmethod
def _normalize_commands( def _normalize_commands(
@@ -669,37 +682,6 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
artifact=artifact, artifact=artifact,
) )
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept local shell tool calls and execute them via the managed session."""
if isinstance(request.tool, _PersistentShellTool):
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
request.tool_call["args"],
tool_call_id=request.tool_call.get("id"),
)
return handler(request)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Async intercept local shell tool calls and execute them via the managed session."""
# The sync version already handles all the work, no need for async-specific logic
if isinstance(request.tool, _PersistentShellTool):
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
request.tool_call["args"],
tool_call_id=request.tool_call.get("id"),
)
return await handler(request)
def _format_tool_message( def _format_tool_message(
self, self,
content: str, content: str,
@@ -714,7 +696,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
return ToolMessage( return ToolMessage(
content=content, content=content,
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
name=self._tool.name, name=self._tool_name,
status=status, status=status,
artifact=artifact, artifact=artifact,
) )

View File

@@ -17,7 +17,9 @@ from langchain.agents.middleware.types import (
AgentState, AgentState,
ModelRequest, ModelRequest,
ModelResponse, ModelResponse,
_ModelRequestOverrides,
) )
from langchain.tools import ToolRuntime, tool
from langchain_core.messages import ToolMessage from langchain_core.messages import ToolMessage
from langgraph.types import Command from langgraph.types import Command
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict
@@ -25,7 +27,6 @@ from typing_extensions import NotRequired, TypedDict
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Sequence from collections.abc import Awaitable, Callable, Sequence
from langchain.agents.middleware.types import ToolCallRequest
# Tool type constants # Tool type constants
TEXT_EDITOR_TOOL_TYPE = "text_editor_20250728" TEXT_EDITOR_TOOL_TYPE = "text_editor_20250728"
@@ -184,149 +185,127 @@ class _StateClaudeFileToolMiddleware(AgentMiddleware):
self.allowed_prefixes = allowed_path_prefixes self.allowed_prefixes = allowed_path_prefixes
self.system_prompt = system_prompt self.system_prompt = system_prompt
# Create tool that will be executed by the tool node
@tool(tool_name)
def file_tool(
runtime: ToolRuntime[None, AnthropicToolsState],
command: str,
path: str,
file_text: str | None = None,
old_str: str | None = None,
new_str: str | None = None,
insert_line: int | None = None,
new_path: str | None = None,
view_range: list[int] | None = None,
) -> Command | str:
"""Execute file operations on virtual file system.
Args:
runtime: Tool runtime providing access to state.
command: Operation to perform.
path: File path to operate on.
file_text: Full file content for create command.
old_str: String to replace for str_replace command.
new_str: Replacement string for str_replace command.
insert_line: Line number for insert command.
new_path: New path for rename command.
view_range: Line range [start, end] for view command.
Returns:
Command for state update or string result.
"""
# Build args dict for handler methods
args: dict[str, Any] = {"path": path}
if file_text is not None:
args["file_text"] = file_text
if old_str is not None:
args["old_str"] = old_str
if new_str is not None:
args["new_str"] = new_str
if insert_line is not None:
args["insert_line"] = insert_line
if new_path is not None:
args["new_path"] = new_path
if view_range is not None:
args["view_range"] = view_range
# Route to appropriate handler based on command
try:
if command == "view":
return self._handle_view(args, runtime.state, runtime.tool_call_id)
if command == "create":
return self._handle_create(
args, runtime.state, runtime.tool_call_id
)
if command == "str_replace":
return self._handle_str_replace(
args, runtime.state, runtime.tool_call_id
)
if command == "insert":
return self._handle_insert(
args, runtime.state, runtime.tool_call_id
)
if command == "delete":
return self._handle_delete(
args, runtime.state, runtime.tool_call_id
)
if command == "rename":
return self._handle_rename(
args, runtime.state, runtime.tool_call_id
)
return f"Unknown command: {command}"
except (ValueError, FileNotFoundError) as e:
return str(e)
self.tools = [file_tool]
def wrap_model_call( def wrap_model_call(
self, self,
request: ModelRequest, request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse], handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse: ) -> ModelResponse:
"""Inject tool and optional system prompt.""" """Inject Anthropic tool descriptor and optional system prompt."""
# Add tool # Replace our BaseTool with Anthropic's native tool descriptor
tools = list(request.tools or []) tools = [
tools.append( t
{ for t in (request.tools or [])
"type": self.tool_type, if getattr(t, "name", None) != self.tool_name
"name": self.tool_name, ] + [{"type": self.tool_type, "name": self.tool_name}]
}
)
request.tools = tools
# Inject system prompt if provided # Inject system prompt if provided
overrides: _ModelRequestOverrides = {"tools": tools}
if self.system_prompt: if self.system_prompt:
request.system_prompt = ( overrides["system_prompt"] = (
request.system_prompt + "\n\n" + self.system_prompt request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt if request.system_prompt
else self.system_prompt else self.system_prompt
) )
return handler(request) return handler(request.override(**overrides))
async def awrap_model_call( async def awrap_model_call(
self, self,
request: ModelRequest, request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]], handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse: ) -> ModelResponse:
"""Inject tool and optional system prompt (async version).""" """Inject Anthropic tool descriptor and optional system prompt."""
# Add tool # Replace our BaseTool with Anthropic's native tool descriptor
tools = list(request.tools or []) tools = [
tools.append( t
{ for t in (request.tools or [])
"type": self.tool_type, if getattr(t, "name", None) != self.tool_name
"name": self.tool_name, ] + [{"type": self.tool_type, "name": self.tool_name}]
}
)
request.tools = tools
# Inject system prompt if provided # Inject system prompt if provided
overrides: _ModelRequestOverrides = {"tools": tools}
if self.system_prompt: if self.system_prompt:
request.system_prompt = ( overrides["system_prompt"] = (
request.system_prompt + "\n\n" + self.system_prompt request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt if request.system_prompt
else self.system_prompt else self.system_prompt
) )
return await handler(request) return await handler(request.override(**overrides))
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept tool calls."""
tool_call = request.tool_call
tool_name = tool_call.get("name")
if tool_name != self.tool_name:
return handler(request)
# Handle tool call
try:
args = tool_call.get("args", {})
command = args.get("command")
state = request.state
if command == "view":
return self._handle_view(args, state, tool_call["id"])
if command == "create":
return self._handle_create(args, state, tool_call["id"])
if command == "str_replace":
return self._handle_str_replace(args, state, tool_call["id"])
if command == "insert":
return self._handle_insert(args, state, tool_call["id"])
if command == "delete":
return self._handle_delete(args, state, tool_call["id"])
if command == "rename":
return self._handle_rename(args, state, tool_call["id"])
msg = f"Unknown command: {command}"
return ToolMessage(
content=msg,
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
except (ValueError, FileNotFoundError) as e:
return ToolMessage(
content=str(e),
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Intercept tool calls (async version)."""
tool_call = request.tool_call
tool_name = tool_call.get("name")
if tool_name != self.tool_name:
return await handler(request)
# Handle tool call
try:
args = tool_call.get("args", {})
command = args.get("command")
state = request.state
if command == "view":
return self._handle_view(args, state, tool_call["id"])
if command == "create":
return self._handle_create(args, state, tool_call["id"])
if command == "str_replace":
return self._handle_str_replace(args, state, tool_call["id"])
if command == "insert":
return self._handle_insert(args, state, tool_call["id"])
if command == "delete":
return self._handle_delete(args, state, tool_call["id"])
if command == "rename":
return self._handle_rename(args, state, tool_call["id"])
msg = f"Unknown command: {command}"
return ToolMessage(
content=msg,
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
except (ValueError, FileNotFoundError) as e:
return ToolMessage(
content=str(e),
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
def _handle_view( def _handle_view(
self, args: dict, state: AnthropicToolsState, tool_call_id: str | None self, args: dict, state: AnthropicToolsState, tool_call_id: str | None
@@ -692,146 +671,117 @@ class _FilesystemClaudeFileToolMiddleware(AgentMiddleware):
# Create root directory if it doesn't exist # Create root directory if it doesn't exist
self.root_path.mkdir(parents=True, exist_ok=True) self.root_path.mkdir(parents=True, exist_ok=True)
# Create tool that will be executed by the tool node
@tool(tool_name)
def file_tool(
runtime: ToolRuntime,
command: str,
path: str,
file_text: str | None = None,
old_str: str | None = None,
new_str: str | None = None,
insert_line: int | None = None,
new_path: str | None = None,
view_range: list[int] | None = None,
) -> Command | str:
"""Execute file operations on filesystem.
Args:
runtime: Tool runtime providing tool_call_id.
command: Operation to perform.
path: File path to operate on.
file_text: Full file content for create command.
old_str: String to replace for str_replace command.
new_str: Replacement string for str_replace command.
insert_line: Line number for insert command.
new_path: New path for rename command.
view_range: Line range [start, end] for view command.
Returns:
Command for message update or string result.
"""
# Build args dict for handler methods
args: dict[str, Any] = {"path": path}
if file_text is not None:
args["file_text"] = file_text
if old_str is not None:
args["old_str"] = old_str
if new_str is not None:
args["new_str"] = new_str
if insert_line is not None:
args["insert_line"] = insert_line
if new_path is not None:
args["new_path"] = new_path
if view_range is not None:
args["view_range"] = view_range
# Route to appropriate handler based on command
try:
if command == "view":
return self._handle_view(args, runtime.tool_call_id)
if command == "create":
return self._handle_create(args, runtime.tool_call_id)
if command == "str_replace":
return self._handle_str_replace(args, runtime.tool_call_id)
if command == "insert":
return self._handle_insert(args, runtime.tool_call_id)
if command == "delete":
return self._handle_delete(args, runtime.tool_call_id)
if command == "rename":
return self._handle_rename(args, runtime.tool_call_id)
return f"Unknown command: {command}"
except (ValueError, FileNotFoundError, PermissionError) as e:
return str(e)
self.tools = [file_tool]
def wrap_model_call( def wrap_model_call(
self, self,
request: ModelRequest, request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse], handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse: ) -> ModelResponse:
"""Inject tool and optional system prompt.""" """Inject Anthropic tool descriptor and optional system prompt."""
# Add tool # Replace our BaseTool with Anthropic's native tool descriptor
tools = list(request.tools or []) tools = [
tools.append( t
{ for t in (request.tools or [])
"type": self.tool_type, if getattr(t, "name", None) != self.tool_name
"name": self.tool_name, ] + [{"type": self.tool_type, "name": self.tool_name}]
}
)
request.tools = tools
# Inject system prompt if provided # Inject system prompt if provided
overrides: _ModelRequestOverrides = {"tools": tools}
if self.system_prompt: if self.system_prompt:
request.system_prompt = ( overrides["system_prompt"] = (
request.system_prompt + "\n\n" + self.system_prompt request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt if request.system_prompt
else self.system_prompt else self.system_prompt
) )
return handler(request)
return handler(request.override(**overrides))
async def awrap_model_call( async def awrap_model_call(
self, self,
request: ModelRequest, request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]], handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse: ) -> ModelResponse:
"""Inject tool and optional system prompt (async version).""" """Inject Anthropic tool descriptor and optional system prompt."""
# Add tool # Replace our BaseTool with Anthropic's native tool descriptor
tools = list(request.tools or []) tools = [
tools.append( t
{ for t in (request.tools or [])
"type": self.tool_type, if getattr(t, "name", None) != self.tool_name
"name": self.tool_name, ] + [{"type": self.tool_type, "name": self.tool_name}]
}
)
request.tools = tools
# Inject system prompt if provided # Inject system prompt if provided
overrides: _ModelRequestOverrides = {"tools": tools}
if self.system_prompt: if self.system_prompt:
request.system_prompt = ( overrides["system_prompt"] = (
request.system_prompt + "\n\n" + self.system_prompt request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt if request.system_prompt
else self.system_prompt else self.system_prompt
) )
return await handler(request) return await handler(request.override(**overrides))
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept tool calls."""
tool_call = request.tool_call
tool_name = tool_call.get("name")
if tool_name != self.tool_name:
return handler(request)
# Handle tool call
try:
args = tool_call.get("args", {})
command = args.get("command")
if command == "view":
return self._handle_view(args, tool_call["id"])
if command == "create":
return self._handle_create(args, tool_call["id"])
if command == "str_replace":
return self._handle_str_replace(args, tool_call["id"])
if command == "insert":
return self._handle_insert(args, tool_call["id"])
if command == "delete":
return self._handle_delete(args, tool_call["id"])
if command == "rename":
return self._handle_rename(args, tool_call["id"])
msg = f"Unknown command: {command}"
return ToolMessage(
content=msg,
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
except (ValueError, FileNotFoundError) as e:
return ToolMessage(
content=str(e),
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Intercept tool calls (async version)."""
tool_call = request.tool_call
tool_name = tool_call.get("name")
if tool_name != self.tool_name:
return await handler(request)
# Handle tool call
try:
args = tool_call.get("args", {})
command = args.get("command")
if command == "view":
return self._handle_view(args, tool_call["id"])
if command == "create":
return self._handle_create(args, tool_call["id"])
if command == "str_replace":
return self._handle_str_replace(args, tool_call["id"])
if command == "insert":
return self._handle_insert(args, tool_call["id"])
if command == "delete":
return self._handle_delete(args, tool_call["id"])
if command == "rename":
return self._handle_rename(args, tool_call["id"])
msg = f"Unknown command: {command}"
return ToolMessage(
content=msg,
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
except (ValueError, FileNotFoundError) as e:
return ToolMessage(
content=str(e),
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
def _validate_and_resolve_path(self, path: str) -> Path: def _validate_and_resolve_path(self, path: str) -> Path:
"""Validate and resolve a virtual path to filesystem path. """Validate and resolve a virtual path to filesystem path.

View File

@@ -3,105 +3,81 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Any, Literal from typing import Any
from langchain.agents.middleware.shell_tool import ShellToolMiddleware from langchain.agents.middleware.shell_tool import ShellToolMiddleware
from langchain.agents.middleware.types import ( from langchain.agents.middleware.types import (
ModelRequest, ModelRequest,
ModelResponse, ModelResponse,
ToolCallRequest,
) )
from langchain_core.messages import ToolMessage
from langgraph.types import Command
_CLAUDE_BASH_DESCRIPTOR = {"type": "bash_20250124", "name": "bash"} # Tool type constants for Anthropic
BASH_TOOL_TYPE = "bash_20250124"
BASH_TOOL_NAME = "bash"
class ClaudeBashToolMiddleware(ShellToolMiddleware): class ClaudeBashToolMiddleware(ShellToolMiddleware):
"""Middleware that exposes Anthropic's native bash tool to models.""" """Middleware that exposes Anthropic's native bash tool to models."""
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(
"""Initialize middleware without registering a client-side tool.""" self,
kwargs["shell_command"] = ("/bin/bash",) workspace_root: str | None = None,
super().__init__(*args, **kwargs) *,
# Remove the base tool so Claude's native descriptor is the sole entry. startup_commands: tuple[str, ...] | list[str] | str | None = None,
self._tool = None # type: ignore[assignment] shutdown_commands: tuple[str, ...] | list[str] | str | None = None,
self.tools = [] execution_policy: Any | None = None,
redaction_rules: tuple[Any, ...] | list[Any] | None = None,
tool_description: str | None = None,
env: dict[str, Any] | None = None,
) -> None:
"""Initialize middleware for Claude's native bash tool.
Args:
workspace_root: Base directory for the shell session.
If omitted, a temporary directory is created.
startup_commands: Optional commands executed after the session starts.
shutdown_commands: Optional commands executed before session shutdown.
execution_policy: Execution policy controlling timeouts and limits.
redaction_rules: Optional redaction rules to sanitize output.
tool_description: Optional override for tool description.
env: Optional environment variables for the shell session.
"""
super().__init__(
workspace_root=workspace_root,
startup_commands=startup_commands,
shutdown_commands=shutdown_commands,
execution_policy=execution_policy,
redaction_rules=redaction_rules,
tool_description=tool_description,
tool_name=BASH_TOOL_NAME,
shell_command=("/bin/bash",),
env=env,
)
# Parent class now creates the tool with name "bash" via tool_name parameter
def wrap_model_call( def wrap_model_call(
self, self,
request: ModelRequest, request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse], handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse: ) -> ModelResponse:
"""Ensure the Claude bash descriptor is available to the model.""" """Replace parent's shell tool with Claude's bash descriptor."""
tools = request.tools filtered = [
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools): t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR] ]
request = request.override(tools=tools) tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}]
return handler(request) return handler(request.override(tools=tools))
async def awrap_model_call( async def awrap_model_call(
self, self,
request: ModelRequest, request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]], handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse: ) -> ModelResponse:
"""Async: ensure the Claude bash descriptor is available to the model.""" """Async: replace parent's shell tool with Claude's bash descriptor."""
tools = request.tools filtered = [
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools): t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR] ]
request = request.override(tools=tools) tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}]
return await handler(request) return await handler(request.override(tools=tools))
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Command | ToolMessage],
) -> Command | ToolMessage:
"""Intercept Claude bash tool calls and execute them locally."""
tool_call = request.tool_call
if tool_call.get("name") != "bash":
return handler(request)
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
tool_call["args"],
tool_call_id=tool_call.get("id"),
)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[Command | ToolMessage]],
) -> Command | ToolMessage:
"""Async interception mirroring the synchronous implementation."""
tool_call = request.tool_call
if tool_call.get("name") != "bash":
return await handler(request)
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
tool_call["args"],
tool_call_id=tool_call.get("id"),
)
def _format_tool_message(
self,
content: str,
tool_call_id: str | None,
*,
status: Literal["success", "error"],
artifact: dict[str, Any] | None = None,
) -> ToolMessage | str:
"""Format tool responses using Claude's bash descriptor."""
if tool_call_id is None:
return content
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=_CLAUDE_BASH_DESCRIPTOR["name"],
status=status,
artifact=artifact or {},
)
__all__ = ["ClaudeBashToolMiddleware"] __all__ = ["ClaudeBashToolMiddleware"]

View File

@@ -3,67 +3,59 @@ from __future__ import annotations
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from langchain_core.messages.tool import ToolCall
pytest.importorskip( pytest.importorskip(
"anthropic", reason="Anthropic SDK is required for Claude middleware tests" "anthropic", reason="Anthropic SDK is required for Claude middleware tests"
) )
from langchain.agents.middleware.types import ToolCallRequest
from langchain_core.messages import ToolMessage
from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware
def test_wrap_tool_call_handles_claude_bash(monkeypatch: pytest.MonkeyPatch) -> None: def test_creates_bash_tool(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that ClaudeBashToolMiddleware creates a tool named 'bash'."""
middleware = ClaudeBashToolMiddleware() middleware = ClaudeBashToolMiddleware()
sentinel = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
monkeypatch.setattr(middleware, "_run_shell_tool", MagicMock(return_value=sentinel)) # Should have exactly one tool registered (from parent)
monkeypatch.setattr( assert len(middleware.tools) == 1
middleware, "_ensure_resources", MagicMock(return_value=MagicMock())
# Tool is named "bash" (via tool_name parameter)
bash_tool = middleware.tools[0]
assert bash_tool.name == "bash"
def test_replaces_tool_with_claude_descriptor() -> None:
"""Test wrap_model_call replaces bash tool with Claude's bash descriptor."""
from langchain.agents.middleware.types import ModelRequest
middleware = ClaudeBashToolMiddleware()
# Create a mock request with the bash tool (inherited from parent)
bash_tool = middleware.tools[0]
request = ModelRequest(
model=MagicMock(),
system_prompt=None,
messages=[],
tool_choice=None,
tools=[bash_tool],
response_format=None,
state={"messages": []},
runtime=MagicMock(),
) )
tool_call: ToolCall = { # Mock handler that captures the modified request
captured_request = None
def handler(req: ModelRequest) -> MagicMock:
nonlocal captured_request
captured_request = req
return MagicMock()
middleware.wrap_model_call(request, handler)
# The bash tool should be replaced with Claude's native bash descriptor
assert captured_request is not None
assert len(captured_request.tools) == 1
assert captured_request.tools[0] == {
"type": "bash_20250124",
"name": "bash", "name": "bash",
"args": {"command": "echo hi"},
"id": "call-1",
} }
request = ToolCallRequest(
tool_call=tool_call,
tool=MagicMock(),
state={},
runtime=None, # type: ignore[arg-type]
)
handler_called = False
def handler(_: ToolCallRequest) -> ToolMessage:
nonlocal handler_called
handler_called = True
return ToolMessage(content="should not be used", tool_call_id="call-1")
result = middleware.wrap_tool_call(request, handler)
assert result is sentinel
assert handler_called is False
def test_wrap_tool_call_passes_through_other_tools(
monkeypatch: pytest.MonkeyPatch,
) -> None:
middleware = ClaudeBashToolMiddleware()
tool_call: ToolCall = {"name": "other", "args": {}, "id": "call-2"}
request = ToolCallRequest(
tool_call=tool_call,
tool=MagicMock(),
state={},
runtime=None, # type: ignore[arg-type]
)
sentinel = ToolMessage(content="handled", tool_call_id="call-2", name="other")
def handler(_: ToolCallRequest) -> ToolMessage:
return sentinel
result = middleware.wrap_tool_call(request, handler)
assert result is sentinel