mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(anthropic): execute bash + file tools via tool node (#33960)
* Use `override` instead of directly patching attributes on `ModelRequest`.
* Rely on `ToolNode` to execute the tools related to this middleware: `wrap_model_call` injects the relevant Claude tool specs, and the tool node forwards the resulting tool calls to the corresponding LangChain tool implementations.
* Make the same change for the native shell tool middleware.
* Allow the shell tool middleware to specify a name for the shell tool (which results in a net-negative diff for the Claude bash middleware).

Long term, I think the solution might be to attach metadata to a tool that maps the provider spec to a LangChain implementation — an approach for which we could also take some lessons from the MCP front.
This commit is contained in:
@@ -17,7 +17,9 @@ from langchain.agents.middleware.types import (
|
||||
AgentState,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
_ModelRequestOverrides,
|
||||
)
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
@@ -25,7 +27,6 @@ from typing_extensions import NotRequired, TypedDict
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
|
||||
# Tool type constants
|
||||
TEXT_EDITOR_TOOL_TYPE = "text_editor_20250728"
|
||||
@@ -184,149 +185,127 @@ class _StateClaudeFileToolMiddleware(AgentMiddleware):
|
||||
self.allowed_prefixes = allowed_path_prefixes
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
# Create tool that will be executed by the tool node
|
||||
@tool(tool_name)
|
||||
def file_tool(
|
||||
runtime: ToolRuntime[None, AnthropicToolsState],
|
||||
command: str,
|
||||
path: str,
|
||||
file_text: str | None = None,
|
||||
old_str: str | None = None,
|
||||
new_str: str | None = None,
|
||||
insert_line: int | None = None,
|
||||
new_path: str | None = None,
|
||||
view_range: list[int] | None = None,
|
||||
) -> Command | str:
|
||||
"""Execute file operations on virtual file system.
|
||||
|
||||
Args:
|
||||
runtime: Tool runtime providing access to state.
|
||||
command: Operation to perform.
|
||||
path: File path to operate on.
|
||||
file_text: Full file content for create command.
|
||||
old_str: String to replace for str_replace command.
|
||||
new_str: Replacement string for str_replace command.
|
||||
insert_line: Line number for insert command.
|
||||
new_path: New path for rename command.
|
||||
view_range: Line range [start, end] for view command.
|
||||
|
||||
Returns:
|
||||
Command for state update or string result.
|
||||
"""
|
||||
# Build args dict for handler methods
|
||||
args: dict[str, Any] = {"path": path}
|
||||
if file_text is not None:
|
||||
args["file_text"] = file_text
|
||||
if old_str is not None:
|
||||
args["old_str"] = old_str
|
||||
if new_str is not None:
|
||||
args["new_str"] = new_str
|
||||
if insert_line is not None:
|
||||
args["insert_line"] = insert_line
|
||||
if new_path is not None:
|
||||
args["new_path"] = new_path
|
||||
if view_range is not None:
|
||||
args["view_range"] = view_range
|
||||
|
||||
# Route to appropriate handler based on command
|
||||
try:
|
||||
if command == "view":
|
||||
return self._handle_view(args, runtime.state, runtime.tool_call_id)
|
||||
if command == "create":
|
||||
return self._handle_create(
|
||||
args, runtime.state, runtime.tool_call_id
|
||||
)
|
||||
if command == "str_replace":
|
||||
return self._handle_str_replace(
|
||||
args, runtime.state, runtime.tool_call_id
|
||||
)
|
||||
if command == "insert":
|
||||
return self._handle_insert(
|
||||
args, runtime.state, runtime.tool_call_id
|
||||
)
|
||||
if command == "delete":
|
||||
return self._handle_delete(
|
||||
args, runtime.state, runtime.tool_call_id
|
||||
)
|
||||
if command == "rename":
|
||||
return self._handle_rename(
|
||||
args, runtime.state, runtime.tool_call_id
|
||||
)
|
||||
return f"Unknown command: {command}"
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
return str(e)
|
||||
|
||||
self.tools = [file_tool]
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
"""Inject tool and optional system prompt."""
|
||||
# Add tool
|
||||
tools = list(request.tools or [])
|
||||
tools.append(
|
||||
{
|
||||
"type": self.tool_type,
|
||||
"name": self.tool_name,
|
||||
}
|
||||
)
|
||||
request.tools = tools
|
||||
"""Inject Anthropic tool descriptor and optional system prompt."""
|
||||
# Replace our BaseTool with Anthropic's native tool descriptor
|
||||
tools = [
|
||||
t
|
||||
for t in (request.tools or [])
|
||||
if getattr(t, "name", None) != self.tool_name
|
||||
] + [{"type": self.tool_type, "name": self.tool_name}]
|
||||
|
||||
# Inject system prompt if provided
|
||||
overrides: _ModelRequestOverrides = {"tools": tools}
|
||||
if self.system_prompt:
|
||||
request.system_prompt = (
|
||||
overrides["system_prompt"] = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
|
||||
return handler(request)
|
||||
return handler(request.override(**overrides))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
"""Inject tool and optional system prompt (async version)."""
|
||||
# Add tool
|
||||
tools = list(request.tools or [])
|
||||
tools.append(
|
||||
{
|
||||
"type": self.tool_type,
|
||||
"name": self.tool_name,
|
||||
}
|
||||
)
|
||||
request.tools = tools
|
||||
"""Inject Anthropic tool descriptor and optional system prompt."""
|
||||
# Replace our BaseTool with Anthropic's native tool descriptor
|
||||
tools = [
|
||||
t
|
||||
for t in (request.tools or [])
|
||||
if getattr(t, "name", None) != self.tool_name
|
||||
] + [{"type": self.tool_type, "name": self.tool_name}]
|
||||
|
||||
# Inject system prompt if provided
|
||||
overrides: _ModelRequestOverrides = {"tools": tools}
|
||||
if self.system_prompt:
|
||||
request.system_prompt = (
|
||||
overrides["system_prompt"] = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
|
||||
return await handler(request)
|
||||
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool calls."""
|
||||
tool_call = request.tool_call
|
||||
tool_name = tool_call.get("name")
|
||||
|
||||
if tool_name != self.tool_name:
|
||||
return handler(request)
|
||||
|
||||
# Handle tool call
|
||||
try:
|
||||
args = tool_call.get("args", {})
|
||||
command = args.get("command")
|
||||
state = request.state
|
||||
|
||||
if command == "view":
|
||||
return self._handle_view(args, state, tool_call["id"])
|
||||
if command == "create":
|
||||
return self._handle_create(args, state, tool_call["id"])
|
||||
if command == "str_replace":
|
||||
return self._handle_str_replace(args, state, tool_call["id"])
|
||||
if command == "insert":
|
||||
return self._handle_insert(args, state, tool_call["id"])
|
||||
if command == "delete":
|
||||
return self._handle_delete(args, state, tool_call["id"])
|
||||
if command == "rename":
|
||||
return self._handle_rename(args, state, tool_call["id"])
|
||||
|
||||
msg = f"Unknown command: {command}"
|
||||
return ToolMessage(
|
||||
content=msg,
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
return ToolMessage(
|
||||
content=str(e),
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool calls (async version)."""
|
||||
tool_call = request.tool_call
|
||||
tool_name = tool_call.get("name")
|
||||
|
||||
if tool_name != self.tool_name:
|
||||
return await handler(request)
|
||||
|
||||
# Handle tool call
|
||||
try:
|
||||
args = tool_call.get("args", {})
|
||||
command = args.get("command")
|
||||
state = request.state
|
||||
|
||||
if command == "view":
|
||||
return self._handle_view(args, state, tool_call["id"])
|
||||
if command == "create":
|
||||
return self._handle_create(args, state, tool_call["id"])
|
||||
if command == "str_replace":
|
||||
return self._handle_str_replace(args, state, tool_call["id"])
|
||||
if command == "insert":
|
||||
return self._handle_insert(args, state, tool_call["id"])
|
||||
if command == "delete":
|
||||
return self._handle_delete(args, state, tool_call["id"])
|
||||
if command == "rename":
|
||||
return self._handle_rename(args, state, tool_call["id"])
|
||||
|
||||
msg = f"Unknown command: {command}"
|
||||
return ToolMessage(
|
||||
content=msg,
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
return ToolMessage(
|
||||
content=str(e),
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
return await handler(request.override(**overrides))
|
||||
|
||||
def _handle_view(
|
||||
self, args: dict, state: AnthropicToolsState, tool_call_id: str | None
|
||||
@@ -692,146 +671,117 @@ class _FilesystemClaudeFileToolMiddleware(AgentMiddleware):
|
||||
# Create root directory if it doesn't exist
|
||||
self.root_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create tool that will be executed by the tool node
|
||||
@tool(tool_name)
|
||||
def file_tool(
|
||||
runtime: ToolRuntime,
|
||||
command: str,
|
||||
path: str,
|
||||
file_text: str | None = None,
|
||||
old_str: str | None = None,
|
||||
new_str: str | None = None,
|
||||
insert_line: int | None = None,
|
||||
new_path: str | None = None,
|
||||
view_range: list[int] | None = None,
|
||||
) -> Command | str:
|
||||
"""Execute file operations on filesystem.
|
||||
|
||||
Args:
|
||||
runtime: Tool runtime providing tool_call_id.
|
||||
command: Operation to perform.
|
||||
path: File path to operate on.
|
||||
file_text: Full file content for create command.
|
||||
old_str: String to replace for str_replace command.
|
||||
new_str: Replacement string for str_replace command.
|
||||
insert_line: Line number for insert command.
|
||||
new_path: New path for rename command.
|
||||
view_range: Line range [start, end] for view command.
|
||||
|
||||
Returns:
|
||||
Command for message update or string result.
|
||||
"""
|
||||
# Build args dict for handler methods
|
||||
args: dict[str, Any] = {"path": path}
|
||||
if file_text is not None:
|
||||
args["file_text"] = file_text
|
||||
if old_str is not None:
|
||||
args["old_str"] = old_str
|
||||
if new_str is not None:
|
||||
args["new_str"] = new_str
|
||||
if insert_line is not None:
|
||||
args["insert_line"] = insert_line
|
||||
if new_path is not None:
|
||||
args["new_path"] = new_path
|
||||
if view_range is not None:
|
||||
args["view_range"] = view_range
|
||||
|
||||
# Route to appropriate handler based on command
|
||||
try:
|
||||
if command == "view":
|
||||
return self._handle_view(args, runtime.tool_call_id)
|
||||
if command == "create":
|
||||
return self._handle_create(args, runtime.tool_call_id)
|
||||
if command == "str_replace":
|
||||
return self._handle_str_replace(args, runtime.tool_call_id)
|
||||
if command == "insert":
|
||||
return self._handle_insert(args, runtime.tool_call_id)
|
||||
if command == "delete":
|
||||
return self._handle_delete(args, runtime.tool_call_id)
|
||||
if command == "rename":
|
||||
return self._handle_rename(args, runtime.tool_call_id)
|
||||
return f"Unknown command: {command}"
|
||||
except (ValueError, FileNotFoundError, PermissionError) as e:
|
||||
return str(e)
|
||||
|
||||
self.tools = [file_tool]
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
"""Inject tool and optional system prompt."""
|
||||
# Add tool
|
||||
tools = list(request.tools or [])
|
||||
tools.append(
|
||||
{
|
||||
"type": self.tool_type,
|
||||
"name": self.tool_name,
|
||||
}
|
||||
)
|
||||
request.tools = tools
|
||||
"""Inject Anthropic tool descriptor and optional system prompt."""
|
||||
# Replace our BaseTool with Anthropic's native tool descriptor
|
||||
tools = [
|
||||
t
|
||||
for t in (request.tools or [])
|
||||
if getattr(t, "name", None) != self.tool_name
|
||||
] + [{"type": self.tool_type, "name": self.tool_name}]
|
||||
|
||||
# Inject system prompt if provided
|
||||
overrides: _ModelRequestOverrides = {"tools": tools}
|
||||
if self.system_prompt:
|
||||
request.system_prompt = (
|
||||
overrides["system_prompt"] = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
return handler(request)
|
||||
|
||||
return handler(request.override(**overrides))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
"""Inject tool and optional system prompt (async version)."""
|
||||
# Add tool
|
||||
tools = list(request.tools or [])
|
||||
tools.append(
|
||||
{
|
||||
"type": self.tool_type,
|
||||
"name": self.tool_name,
|
||||
}
|
||||
)
|
||||
request.tools = tools
|
||||
"""Inject Anthropic tool descriptor and optional system prompt."""
|
||||
# Replace our BaseTool with Anthropic's native tool descriptor
|
||||
tools = [
|
||||
t
|
||||
for t in (request.tools or [])
|
||||
if getattr(t, "name", None) != self.tool_name
|
||||
] + [{"type": self.tool_type, "name": self.tool_name}]
|
||||
|
||||
# Inject system prompt if provided
|
||||
overrides: _ModelRequestOverrides = {"tools": tools}
|
||||
if self.system_prompt:
|
||||
request.system_prompt = (
|
||||
overrides["system_prompt"] = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
|
||||
return await handler(request)
|
||||
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool calls."""
|
||||
tool_call = request.tool_call
|
||||
tool_name = tool_call.get("name")
|
||||
|
||||
if tool_name != self.tool_name:
|
||||
return handler(request)
|
||||
|
||||
# Handle tool call
|
||||
try:
|
||||
args = tool_call.get("args", {})
|
||||
command = args.get("command")
|
||||
|
||||
if command == "view":
|
||||
return self._handle_view(args, tool_call["id"])
|
||||
if command == "create":
|
||||
return self._handle_create(args, tool_call["id"])
|
||||
if command == "str_replace":
|
||||
return self._handle_str_replace(args, tool_call["id"])
|
||||
if command == "insert":
|
||||
return self._handle_insert(args, tool_call["id"])
|
||||
if command == "delete":
|
||||
return self._handle_delete(args, tool_call["id"])
|
||||
if command == "rename":
|
||||
return self._handle_rename(args, tool_call["id"])
|
||||
|
||||
msg = f"Unknown command: {command}"
|
||||
return ToolMessage(
|
||||
content=msg,
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
return ToolMessage(
|
||||
content=str(e),
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool calls (async version)."""
|
||||
tool_call = request.tool_call
|
||||
tool_name = tool_call.get("name")
|
||||
|
||||
if tool_name != self.tool_name:
|
||||
return await handler(request)
|
||||
|
||||
# Handle tool call
|
||||
try:
|
||||
args = tool_call.get("args", {})
|
||||
command = args.get("command")
|
||||
|
||||
if command == "view":
|
||||
return self._handle_view(args, tool_call["id"])
|
||||
if command == "create":
|
||||
return self._handle_create(args, tool_call["id"])
|
||||
if command == "str_replace":
|
||||
return self._handle_str_replace(args, tool_call["id"])
|
||||
if command == "insert":
|
||||
return self._handle_insert(args, tool_call["id"])
|
||||
if command == "delete":
|
||||
return self._handle_delete(args, tool_call["id"])
|
||||
if command == "rename":
|
||||
return self._handle_rename(args, tool_call["id"])
|
||||
|
||||
msg = f"Unknown command: {command}"
|
||||
return ToolMessage(
|
||||
content=msg,
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
return ToolMessage(
|
||||
content=str(e),
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
return await handler(request.override(**overrides))
|
||||
|
||||
def _validate_and_resolve_path(self, path: str) -> Path:
|
||||
"""Validate and resolve a virtual path to filesystem path.
|
||||
|
||||
@@ -3,105 +3,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ToolCallRequest,
|
||||
)
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
_CLAUDE_BASH_DESCRIPTOR = {"type": "bash_20250124", "name": "bash"}
|
||||
# Tool type constants for Anthropic
|
||||
BASH_TOOL_TYPE = "bash_20250124"
|
||||
BASH_TOOL_NAME = "bash"
|
||||
|
||||
|
||||
class ClaudeBashToolMiddleware(ShellToolMiddleware):
|
||||
"""Middleware that exposes Anthropic's native bash tool to models."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Initialize middleware without registering a client-side tool."""
|
||||
kwargs["shell_command"] = ("/bin/bash",)
|
||||
super().__init__(*args, **kwargs)
|
||||
# Remove the base tool so Claude's native descriptor is the sole entry.
|
||||
self._tool = None # type: ignore[assignment]
|
||||
self.tools = []
|
||||
def __init__(
|
||||
self,
|
||||
workspace_root: str | None = None,
|
||||
*,
|
||||
startup_commands: tuple[str, ...] | list[str] | str | None = None,
|
||||
shutdown_commands: tuple[str, ...] | list[str] | str | None = None,
|
||||
execution_policy: Any | None = None,
|
||||
redaction_rules: tuple[Any, ...] | list[Any] | None = None,
|
||||
tool_description: str | None = None,
|
||||
env: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize middleware for Claude's native bash tool.
|
||||
|
||||
Args:
|
||||
workspace_root: Base directory for the shell session.
|
||||
If omitted, a temporary directory is created.
|
||||
startup_commands: Optional commands executed after the session starts.
|
||||
shutdown_commands: Optional commands executed before session shutdown.
|
||||
execution_policy: Execution policy controlling timeouts and limits.
|
||||
redaction_rules: Optional redaction rules to sanitize output.
|
||||
tool_description: Optional override for tool description.
|
||||
env: Optional environment variables for the shell session.
|
||||
"""
|
||||
super().__init__(
|
||||
workspace_root=workspace_root,
|
||||
startup_commands=startup_commands,
|
||||
shutdown_commands=shutdown_commands,
|
||||
execution_policy=execution_policy,
|
||||
redaction_rules=redaction_rules,
|
||||
tool_description=tool_description,
|
||||
tool_name=BASH_TOOL_NAME,
|
||||
shell_command=("/bin/bash",),
|
||||
env=env,
|
||||
)
|
||||
# Parent class now creates the tool with name "bash" via tool_name parameter
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
"""Ensure the Claude bash descriptor is available to the model."""
|
||||
tools = request.tools
|
||||
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools):
|
||||
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR]
|
||||
request = request.override(tools=tools)
|
||||
return handler(request)
|
||||
"""Replace parent's shell tool with Claude's bash descriptor."""
|
||||
filtered = [
|
||||
t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME
|
||||
]
|
||||
tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}]
|
||||
return handler(request.override(tools=tools))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
"""Async: ensure the Claude bash descriptor is available to the model."""
|
||||
tools = request.tools
|
||||
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools):
|
||||
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR]
|
||||
request = request.override(tools=tools)
|
||||
return await handler(request)
|
||||
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Command | ToolMessage],
|
||||
) -> Command | ToolMessage:
|
||||
"""Intercept Claude bash tool calls and execute them locally."""
|
||||
tool_call = request.tool_call
|
||||
if tool_call.get("name") != "bash":
|
||||
return handler(request)
|
||||
resources = self._ensure_resources(request.state)
|
||||
return self._run_shell_tool(
|
||||
resources,
|
||||
tool_call["args"],
|
||||
tool_call_id=tool_call.get("id"),
|
||||
)
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[Command | ToolMessage]],
|
||||
) -> Command | ToolMessage:
|
||||
"""Async interception mirroring the synchronous implementation."""
|
||||
tool_call = request.tool_call
|
||||
if tool_call.get("name") != "bash":
|
||||
return await handler(request)
|
||||
resources = self._ensure_resources(request.state)
|
||||
return self._run_shell_tool(
|
||||
resources,
|
||||
tool_call["args"],
|
||||
tool_call_id=tool_call.get("id"),
|
||||
)
|
||||
|
||||
def _format_tool_message(
|
||||
self,
|
||||
content: str,
|
||||
tool_call_id: str | None,
|
||||
*,
|
||||
status: Literal["success", "error"],
|
||||
artifact: dict[str, Any] | None = None,
|
||||
) -> ToolMessage | str:
|
||||
"""Format tool responses using Claude's bash descriptor."""
|
||||
if tool_call_id is None:
|
||||
return content
|
||||
return ToolMessage(
|
||||
content=content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=_CLAUDE_BASH_DESCRIPTOR["name"],
|
||||
status=status,
|
||||
artifact=artifact or {},
|
||||
)
|
||||
"""Async: replace parent's shell tool with Claude's bash descriptor."""
|
||||
filtered = [
|
||||
t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME
|
||||
]
|
||||
tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}]
|
||||
return await handler(request.override(tools=tools))
|
||||
|
||||
|
||||
__all__ = ["ClaudeBashToolMiddleware"]
|
||||
|
||||
@@ -3,67 +3,59 @@ from __future__ import annotations
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
|
||||
pytest.importorskip(
|
||||
"anthropic", reason="Anthropic SDK is required for Claude middleware tests"
|
||||
)
|
||||
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware
|
||||
|
||||
|
||||
def test_wrap_tool_call_handles_claude_bash(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_creates_bash_tool(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test that ClaudeBashToolMiddleware creates a tool named 'bash'."""
|
||||
middleware = ClaudeBashToolMiddleware()
|
||||
sentinel = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
|
||||
|
||||
monkeypatch.setattr(middleware, "_run_shell_tool", MagicMock(return_value=sentinel))
|
||||
monkeypatch.setattr(
|
||||
middleware, "_ensure_resources", MagicMock(return_value=MagicMock())
|
||||
# Should have exactly one tool registered (from parent)
|
||||
assert len(middleware.tools) == 1
|
||||
|
||||
# Tool is named "bash" (via tool_name parameter)
|
||||
bash_tool = middleware.tools[0]
|
||||
assert bash_tool.name == "bash"
|
||||
|
||||
|
||||
def test_replaces_tool_with_claude_descriptor() -> None:
|
||||
"""Test wrap_model_call replaces bash tool with Claude's bash descriptor."""
|
||||
from langchain.agents.middleware.types import ModelRequest
|
||||
|
||||
middleware = ClaudeBashToolMiddleware()
|
||||
|
||||
# Create a mock request with the bash tool (inherited from parent)
|
||||
bash_tool = middleware.tools[0]
|
||||
request = ModelRequest(
|
||||
model=MagicMock(),
|
||||
system_prompt=None,
|
||||
messages=[],
|
||||
tool_choice=None,
|
||||
tools=[bash_tool],
|
||||
response_format=None,
|
||||
state={"messages": []},
|
||||
runtime=MagicMock(),
|
||||
)
|
||||
|
||||
tool_call: ToolCall = {
|
||||
# Mock handler that captures the modified request
|
||||
captured_request = None
|
||||
|
||||
def handler(req: ModelRequest) -> MagicMock:
|
||||
nonlocal captured_request
|
||||
captured_request = req
|
||||
return MagicMock()
|
||||
|
||||
middleware.wrap_model_call(request, handler)
|
||||
|
||||
# The bash tool should be replaced with Claude's native bash descriptor
|
||||
assert captured_request is not None
|
||||
assert len(captured_request.tools) == 1
|
||||
assert captured_request.tools[0] == {
|
||||
"type": "bash_20250124",
|
||||
"name": "bash",
|
||||
"args": {"command": "echo hi"},
|
||||
"id": "call-1",
|
||||
}
|
||||
request = ToolCallRequest(
|
||||
tool_call=tool_call,
|
||||
tool=MagicMock(),
|
||||
state={},
|
||||
runtime=None, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
handler_called = False
|
||||
|
||||
def handler(_: ToolCallRequest) -> ToolMessage:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
return ToolMessage(content="should not be used", tool_call_id="call-1")
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
assert result is sentinel
|
||||
assert handler_called is False
|
||||
|
||||
|
||||
def test_wrap_tool_call_passes_through_other_tools(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
middleware = ClaudeBashToolMiddleware()
|
||||
tool_call: ToolCall = {"name": "other", "args": {}, "id": "call-2"}
|
||||
request = ToolCallRequest(
|
||||
tool_call=tool_call,
|
||||
tool=MagicMock(),
|
||||
state={},
|
||||
runtime=None, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
sentinel = ToolMessage(content="handled", tool_call_id="call-2", name="other")
|
||||
|
||||
def handler(_: ToolCallRequest) -> ToolMessage:
|
||||
return sentinel
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
assert result is sentinel
|
||||
|
||||
Reference in New Issue
Block a user