diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py index e8a6fd2c524..7d30b867293 100644 --- a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py +++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py @@ -11,7 +11,6 @@ import subprocess import tempfile import threading import time -import typing import uuid import weakref from dataclasses import dataclass, field @@ -19,9 +18,10 @@ from pathlib import Path from typing import TYPE_CHECKING, Annotated, Any, Literal from langchain_core.messages import ToolMessage -from langchain_core.tools.base import BaseTool, ToolException +from langchain_core.tools.base import ToolException from langgraph.channels.untracked_value import UntrackedValue from pydantic import BaseModel, model_validator +from pydantic.json_schema import SkipJsonSchema from typing_extensions import NotRequired from langchain.agents.middleware._execution import ( @@ -38,14 +38,13 @@ from langchain.agents.middleware._redaction import ( ResolvedRedactionRule, ) from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr +from langchain.tools import ToolRuntime, tool if TYPE_CHECKING: from collections.abc import Mapping, Sequence from langgraph.runtime import Runtime - from langgraph.types import Command - from langchain.agents.middleware.types import ToolCallRequest LOGGER = logging.getLogger(__name__) _DONE_MARKER_PREFIX = "__LC_SHELL_DONE__" @@ -59,6 +58,7 @@ DEFAULT_TOOL_DESCRIPTION = ( "session remains stable. Outputs may be truncated when they become very large, and long " "running commands will be terminated once their configured timeout elapses." ) +SHELL_TOOL_NAME = "shell" def _cleanup_resources( @@ -334,7 +334,17 @@ class _ShellToolInput(BaseModel): """Input schema for the persistent shell tool.""" command: str | None = None + """The shell command to execute.""" + restart: bool | None = None + """Whether to restart the shell session.""" + + runtime: Annotated[Any, SkipJsonSchema] = None + """The runtime for the shell tool. + + Included as a workaround at the moment bc args_schema doesn't work with + injected ToolRuntime. + """ @model_validator(mode="after") def validate_payload(self) -> _ShellToolInput: @@ -347,24 +357,6 @@ class _ShellToolInput(BaseModel): return self -class _PersistentShellTool(BaseTool): - """Tool wrapper that relies on middleware interception for execution.""" - - name: str = "shell" - description: str = DEFAULT_TOOL_DESCRIPTION - args_schema: type[BaseModel] = _ShellToolInput - - def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None: - super().__init__() - self._middleware = middleware - if description is not None: - self.description = description - - def _run(self, **_: Any) -> Any: # pragma: no cover - executed via middleware wrapper - msg = "Persistent shell tool execution should be intercepted via middleware wrappers." - raise RuntimeError(msg) - - class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): """Middleware that registers a persistent shell tool for agents. @@ -393,6 +385,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): execution_policy: BaseExecutionPolicy | None = None, redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None, tool_description: str | None = None, + tool_name: str = SHELL_TOOL_NAME, shell_command: Sequence[str] | str | None = None, env: Mapping[str, Any] | None = None, ) -> None: @@ -414,6 +407,9 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): returning it to the model. tool_description: Optional override for the registered shell tool description. + tool_name: Name for the registered shell tool. + + Defaults to `"shell"`. shell_command: Optional shell executable (string) or argument sequence used to launch the persistent session. @@ -425,6 +421,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): """ super().__init__() self._workspace_root = Path(workspace_root) if workspace_root else None + self._tool_name = tool_name self._shell_command = self._normalize_shell_command(shell_command) self._environment = self._normalize_env(env) if execution_policy is not None: @@ -438,9 +435,25 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): self._startup_commands = self._normalize_commands(startup_commands) self._shutdown_commands = self._normalize_commands(shutdown_commands) + # Create a proper tool that executes directly (no interception needed) description = tool_description or DEFAULT_TOOL_DESCRIPTION - self._tool = _PersistentShellTool(self, description=description) - self.tools = [self._tool] + + @tool(self._tool_name, args_schema=_ShellToolInput, description=description) + def shell_tool( + *, + runtime: ToolRuntime[None, ShellToolState], + command: str | None = None, + restart: bool = False, + ) -> ToolMessage | str: + resources = self._ensure_resources(runtime.state) + return self._run_shell_tool( + resources, + {"command": command, "restart": restart}, + tool_call_id=runtime.tool_call_id, + ) + + self._shell_tool = shell_tool + self.tools = [self._shell_tool] @staticmethod def _normalize_commands( @@ -669,37 +682,6 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): artifact=artifact, ) - def wrap_tool_call( - self, - request: ToolCallRequest, - handler: typing.Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Intercept local shell tool calls and execute them via the managed session.""" - if isinstance(request.tool, _PersistentShellTool): - resources = self._ensure_resources(request.state) - return self._run_shell_tool( - resources, - request.tool_call["args"], - tool_call_id=request.tool_call.get("id"), - ) - return handler(request) - - async def awrap_tool_call( - self, - request: ToolCallRequest, - handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]], - ) -> ToolMessage | Command: - """Async intercept local shell tool calls and execute them via the managed session.""" - # The sync version already handles all the work, no need for async-specific logic - if isinstance(request.tool, _PersistentShellTool): - resources = self._ensure_resources(request.state) - return self._run_shell_tool( - resources, - request.tool_call["args"], - tool_call_id=request.tool_call.get("id"), - ) - return await handler(request) - def _format_tool_message( self, content: str, @@ -714,7 +696,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): return ToolMessage( content=content, tool_call_id=tool_call_id, - name=self._tool.name, + name=self._tool_name, status=status, artifact=artifact, ) diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/anthropic_tools.py b/libs/partners/anthropic/langchain_anthropic/middleware/anthropic_tools.py index b72dd1a62c8..eb19bc83383 100644 --- a/libs/partners/anthropic/langchain_anthropic/middleware/anthropic_tools.py +++ b/libs/partners/anthropic/langchain_anthropic/middleware/anthropic_tools.py @@ -17,7 +17,9 @@ from langchain.agents.middleware.types import ( AgentState, ModelRequest, ModelResponse, + _ModelRequestOverrides, ) +from langchain.tools import ToolRuntime, tool from langchain_core.messages import ToolMessage from langgraph.types import Command from typing_extensions import NotRequired, TypedDict @@ -25,7 +27,6 @@ from typing_extensions import NotRequired, TypedDict if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Sequence - from langchain.agents.middleware.types import ToolCallRequest # Tool type constants TEXT_EDITOR_TOOL_TYPE = "text_editor_20250728" @@ -184,149 +185,127 @@ class _StateClaudeFileToolMiddleware(AgentMiddleware): self.allowed_prefixes = allowed_path_prefixes self.system_prompt = system_prompt + # Create tool that will be executed by the tool node + @tool(tool_name) + def file_tool( + runtime: ToolRuntime[None, AnthropicToolsState], + command: str, + path: str, + file_text: str | None = None, + old_str: str | None = None, + new_str: str | None = None, + insert_line: int | None = None, + new_path: str | None = None, + view_range: list[int] | None = None, + ) -> Command | str: + """Execute file operations on virtual file system. + + Args: + runtime: Tool runtime providing access to state. + command: Operation to perform. + path: File path to operate on. + file_text: Full file content for create command. + old_str: String to replace for str_replace command. + new_str: Replacement string for str_replace command. + insert_line: Line number for insert command. + new_path: New path for rename command. + view_range: Line range [start, end] for view command. + + Returns: + Command for state update or string result. + """ + # Build args dict for handler methods + args: dict[str, Any] = {"path": path} + if file_text is not None: + args["file_text"] = file_text + if old_str is not None: + args["old_str"] = old_str + if new_str is not None: + args["new_str"] = new_str + if insert_line is not None: + args["insert_line"] = insert_line + if new_path is not None: + args["new_path"] = new_path + if view_range is not None: + args["view_range"] = view_range + + # Route to appropriate handler based on command + try: + if command == "view": + return self._handle_view(args, runtime.state, runtime.tool_call_id) + if command == "create": + return self._handle_create( + args, runtime.state, runtime.tool_call_id + ) + if command == "str_replace": + return self._handle_str_replace( + args, runtime.state, runtime.tool_call_id + ) + if command == "insert": + return self._handle_insert( + args, runtime.state, runtime.tool_call_id + ) + if command == "delete": + return self._handle_delete( + args, runtime.state, runtime.tool_call_id + ) + if command == "rename": + return self._handle_rename( + args, runtime.state, runtime.tool_call_id + ) + return f"Unknown command: {command}" + except (ValueError, FileNotFoundError) as e: + return str(e) + + self.tools = [file_tool] + def wrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse], ) -> ModelResponse: - """Inject tool and optional system prompt.""" - # Add tool - tools = list(request.tools or []) - tools.append( - { - "type": self.tool_type, - "name": self.tool_name, - } - ) - request.tools = tools + """Inject Anthropic tool descriptor and optional system prompt.""" + # Replace our BaseTool with Anthropic's native tool descriptor + tools = [ + t + for t in (request.tools or []) + if getattr(t, "name", None) != self.tool_name + ] + [{"type": self.tool_type, "name": self.tool_name}] # Inject system prompt if provided + overrides: _ModelRequestOverrides = {"tools": tools} if self.system_prompt: - request.system_prompt = ( + overrides["system_prompt"] = ( request.system_prompt + "\n\n" + self.system_prompt if request.system_prompt else self.system_prompt ) - return handler(request) + return handler(request.override(**overrides)) async def awrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelResponse: - """Inject tool and optional system prompt (async version).""" - # Add tool - tools = list(request.tools or []) - tools.append( - { - "type": self.tool_type, - "name": self.tool_name, - } - ) - request.tools = tools + """Inject Anthropic tool descriptor and optional system prompt.""" + # Replace our BaseTool with Anthropic's native tool descriptor + tools = [ + t + for t in (request.tools or []) + if getattr(t, "name", None) != self.tool_name + ] + [{"type": self.tool_type, "name": self.tool_name}] # Inject system prompt if provided + overrides: _ModelRequestOverrides = {"tools": tools} if self.system_prompt: - request.system_prompt = ( + overrides["system_prompt"] = ( request.system_prompt + "\n\n" + self.system_prompt if request.system_prompt else self.system_prompt ) - return await handler(request) - - def wrap_tool_call( - self, - request: ToolCallRequest, - handler: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Intercept tool calls.""" - tool_call = request.tool_call - tool_name = tool_call.get("name") - - if tool_name != self.tool_name: - return handler(request) - - # Handle tool call - try: - args = tool_call.get("args", {}) - command = args.get("command") - state = request.state - - if command == "view": - return self._handle_view(args, state, tool_call["id"]) - if command == "create": - return self._handle_create(args, state, tool_call["id"]) - if command == "str_replace": - return self._handle_str_replace(args, state, tool_call["id"]) - if command == "insert": - return self._handle_insert(args, state, tool_call["id"]) - if command == "delete": - return self._handle_delete(args, state, tool_call["id"]) - if command == "rename": - return self._handle_rename(args, state, tool_call["id"]) - - msg = f"Unknown command: {command}" - return ToolMessage( - content=msg, - tool_call_id=tool_call["id"], - name=tool_name, - status="error", - ) - except (ValueError, FileNotFoundError) as e: - return ToolMessage( - content=str(e), - tool_call_id=tool_call["id"], - name=tool_name, - status="error", - ) - - async def awrap_tool_call( - self, - request: ToolCallRequest, - handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], - ) -> ToolMessage | Command: - """Intercept tool calls (async version).""" - tool_call = request.tool_call - tool_name = tool_call.get("name") - - if tool_name != self.tool_name: - return await handler(request) - - # Handle tool call - try: - args = tool_call.get("args", {}) - command = args.get("command") - state = request.state - - if command == "view": - return self._handle_view(args, state, tool_call["id"]) - if command == "create": - return self._handle_create(args, state, tool_call["id"]) - if command == "str_replace": - return self._handle_str_replace(args, state, tool_call["id"]) - if command == "insert": - return self._handle_insert(args, state, tool_call["id"]) - if command == "delete": - return self._handle_delete(args, state, tool_call["id"]) - if command == "rename": - return self._handle_rename(args, state, tool_call["id"]) - - msg = f"Unknown command: {command}" - return ToolMessage( - content=msg, - tool_call_id=tool_call["id"], - name=tool_name, - status="error", - ) - except (ValueError, FileNotFoundError) as e: - return ToolMessage( - content=str(e), - tool_call_id=tool_call["id"], - name=tool_name, - status="error", - ) + return await handler(request.override(**overrides)) def _handle_view( self, args: dict, state: AnthropicToolsState, tool_call_id: str | None @@ -692,146 +671,117 @@ class _FilesystemClaudeFileToolMiddleware(AgentMiddleware): # Create root directory if it doesn't exist self.root_path.mkdir(parents=True, exist_ok=True) + # Create tool that will be executed by the tool node + @tool(tool_name) + def file_tool( + runtime: ToolRuntime, + command: str, + path: str, + file_text: str | None = None, + old_str: str | None = None, + new_str: str | None = None, + insert_line: int | None = None, + new_path: str | None = None, + view_range: list[int] | None = None, + ) -> Command | str: + """Execute file operations on filesystem. + + Args: + runtime: Tool runtime providing tool_call_id. + command: Operation to perform. + path: File path to operate on. + file_text: Full file content for create command. + old_str: String to replace for str_replace command. + new_str: Replacement string for str_replace command. + insert_line: Line number for insert command. + new_path: New path for rename command. + view_range: Line range [start, end] for view command. + + Returns: + Command for message update or string result. + """ + # Build args dict for handler methods + args: dict[str, Any] = {"path": path} + if file_text is not None: + args["file_text"] = file_text + if old_str is not None: + args["old_str"] = old_str + if new_str is not None: + args["new_str"] = new_str + if insert_line is not None: + args["insert_line"] = insert_line + if new_path is not None: + args["new_path"] = new_path + if view_range is not None: + args["view_range"] = view_range + + # Route to appropriate handler based on command + try: + if command == "view": + return self._handle_view(args, runtime.tool_call_id) + if command == "create": + return self._handle_create(args, runtime.tool_call_id) + if command == "str_replace": + return self._handle_str_replace(args, runtime.tool_call_id) + if command == "insert": + return self._handle_insert(args, runtime.tool_call_id) + if command == "delete": + return self._handle_delete(args, runtime.tool_call_id) + if command == "rename": + return self._handle_rename(args, runtime.tool_call_id) + return f"Unknown command: {command}" + except (ValueError, FileNotFoundError, PermissionError) as e: + return str(e) + + self.tools = [file_tool] + def wrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse], ) -> ModelResponse: - """Inject tool and optional system prompt.""" - # Add tool - tools = list(request.tools or []) - tools.append( - { - "type": self.tool_type, - "name": self.tool_name, - } - ) - request.tools = tools + """Inject Anthropic tool descriptor and optional system prompt.""" + # Replace our BaseTool with Anthropic's native tool descriptor + tools = [ + t + for t in (request.tools or []) + if getattr(t, "name", None) != self.tool_name + ] + [{"type": self.tool_type, "name": self.tool_name}] # Inject system prompt if provided + overrides: _ModelRequestOverrides = {"tools": tools} if self.system_prompt: - request.system_prompt = ( + overrides["system_prompt"] = ( request.system_prompt + "\n\n" + self.system_prompt if request.system_prompt else self.system_prompt ) - return handler(request) + + return handler(request.override(**overrides)) async def awrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelResponse: - """Inject tool and optional system prompt (async version).""" - # Add tool - tools = list(request.tools or []) - tools.append( - { - "type": self.tool_type, - "name": self.tool_name, - } - ) - request.tools = tools + """Inject Anthropic tool descriptor and optional system prompt.""" + # Replace our BaseTool with Anthropic's native tool descriptor + tools = [ + t + for t in (request.tools or []) + if getattr(t, "name", None) != self.tool_name + ] + [{"type": self.tool_type, "name": self.tool_name}] # Inject system prompt if provided + overrides: _ModelRequestOverrides = {"tools": tools} if self.system_prompt: - request.system_prompt = ( + overrides["system_prompt"] = ( request.system_prompt + "\n\n" + self.system_prompt if request.system_prompt else self.system_prompt ) - return await handler(request) - - def wrap_tool_call( - self, - request: ToolCallRequest, - handler: Callable[[ToolCallRequest], ToolMessage | Command], - ) -> ToolMessage | Command: - """Intercept tool calls.""" - tool_call = request.tool_call - tool_name = tool_call.get("name") - - if tool_name != self.tool_name: - return handler(request) - - # Handle tool call - try: - args = tool_call.get("args", {}) - command = args.get("command") - - if command == "view": - return self._handle_view(args, tool_call["id"]) - if command == "create": - return self._handle_create(args, tool_call["id"]) - if command == "str_replace": - return self._handle_str_replace(args, tool_call["id"]) - if command == "insert": - return self._handle_insert(args, tool_call["id"]) - if command == "delete": - return self._handle_delete(args, tool_call["id"]) - if command == "rename": - return self._handle_rename(args, tool_call["id"]) - - msg = f"Unknown command: {command}" - return ToolMessage( - content=msg, - tool_call_id=tool_call["id"], - name=tool_name, - status="error", - ) - except (ValueError, FileNotFoundError) as e: - return ToolMessage( - content=str(e), - tool_call_id=tool_call["id"], - name=tool_name, - status="error", - ) - - async def awrap_tool_call( - self, - request: ToolCallRequest, - handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], - ) -> ToolMessage | Command: - """Intercept tool calls (async version).""" - tool_call = request.tool_call - tool_name = tool_call.get("name") - - if tool_name != self.tool_name: - return await handler(request) - - # Handle tool call - try: - args = tool_call.get("args", {}) - command = args.get("command") - - if command == "view": - return self._handle_view(args, tool_call["id"]) - if command == "create": - return self._handle_create(args, tool_call["id"]) - if command == "str_replace": - return self._handle_str_replace(args, tool_call["id"]) - if command == "insert": - return self._handle_insert(args, tool_call["id"]) - if command == "delete": - return self._handle_delete(args, tool_call["id"]) - if command == "rename": - return self._handle_rename(args, tool_call["id"]) - - msg = f"Unknown command: {command}" - return ToolMessage( - content=msg, - tool_call_id=tool_call["id"], - name=tool_name, - status="error", - ) - except (ValueError, FileNotFoundError) as e: - return ToolMessage( - content=str(e), - tool_call_id=tool_call["id"], - name=tool_name, - status="error", - ) + return await handler(request.override(**overrides)) def _validate_and_resolve_path(self, path: str) -> Path: """Validate and resolve a virtual path to filesystem path. diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/bash.py b/libs/partners/anthropic/langchain_anthropic/middleware/bash.py index 2f8ef0c3135..dd49ae4a774 100644 --- a/libs/partners/anthropic/langchain_anthropic/middleware/bash.py +++ b/libs/partners/anthropic/langchain_anthropic/middleware/bash.py @@ -3,105 +3,81 @@ from __future__ import annotations from collections.abc import Awaitable, Callable -from typing import Any, Literal +from typing import Any from langchain.agents.middleware.shell_tool import ShellToolMiddleware from langchain.agents.middleware.types import ( ModelRequest, ModelResponse, - ToolCallRequest, ) -from langchain_core.messages import ToolMessage -from langgraph.types import Command -_CLAUDE_BASH_DESCRIPTOR = {"type": "bash_20250124", "name": "bash"} +# Tool type constants for Anthropic +BASH_TOOL_TYPE = "bash_20250124" +BASH_TOOL_NAME = "bash" class ClaudeBashToolMiddleware(ShellToolMiddleware): """Middleware that exposes Anthropic's native bash tool to models.""" - def __init__(self, *args: Any, **kwargs: Any) -> None: - """Initialize middleware without registering a client-side tool.""" - kwargs["shell_command"] = ("/bin/bash",) - super().__init__(*args, **kwargs) - # Remove the base tool so Claude's native descriptor is the sole entry. - self._tool = None # type: ignore[assignment] - self.tools = [] + def __init__( + self, + workspace_root: str | None = None, + *, + startup_commands: tuple[str, ...] | list[str] | str | None = None, + shutdown_commands: tuple[str, ...] | list[str] | str | None = None, + execution_policy: Any | None = None, + redaction_rules: tuple[Any, ...] | list[Any] | None = None, + tool_description: str | None = None, + env: dict[str, Any] | None = None, + ) -> None: + """Initialize middleware for Claude's native bash tool. + + Args: + workspace_root: Base directory for the shell session. + If omitted, a temporary directory is created. + startup_commands: Optional commands executed after the session starts. + shutdown_commands: Optional commands executed before session shutdown. + execution_policy: Execution policy controlling timeouts and limits. + redaction_rules: Optional redaction rules to sanitize output. + tool_description: Optional override for tool description. + env: Optional environment variables for the shell session. + """ + super().__init__( + workspace_root=workspace_root, + startup_commands=startup_commands, + shutdown_commands=shutdown_commands, + execution_policy=execution_policy, + redaction_rules=redaction_rules, + tool_description=tool_description, + tool_name=BASH_TOOL_NAME, + shell_command=("/bin/bash",), + env=env, + ) + # Parent class now creates the tool with name "bash" via tool_name parameter def wrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse], ) -> ModelResponse: - """Ensure the Claude bash descriptor is available to the model.""" - tools = request.tools - if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools): - tools = [*tools, _CLAUDE_BASH_DESCRIPTOR] - request = request.override(tools=tools) - return handler(request) + """Replace parent's shell tool with Claude's bash descriptor.""" + filtered = [ + t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME + ] + tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}] + return handler(request.override(tools=tools)) async def awrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelResponse: - """Async: ensure the Claude bash descriptor is available to the model.""" - tools = request.tools - if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools): - tools = [*tools, _CLAUDE_BASH_DESCRIPTOR] - request = request.override(tools=tools) - return await handler(request) - - def wrap_tool_call( - self, - request: ToolCallRequest, - handler: Callable[[ToolCallRequest], Command | ToolMessage], - ) -> Command | ToolMessage: - """Intercept Claude bash tool calls and execute them locally.""" - tool_call = request.tool_call - if tool_call.get("name") != "bash": - return handler(request) - resources = self._ensure_resources(request.state) - return self._run_shell_tool( - resources, - tool_call["args"], - tool_call_id=tool_call.get("id"), - ) - - async def awrap_tool_call( - self, - request: ToolCallRequest, - handler: Callable[[ToolCallRequest], Awaitable[Command | ToolMessage]], - ) -> Command | ToolMessage: - """Async interception mirroring the synchronous implementation.""" - tool_call = request.tool_call - if tool_call.get("name") != "bash": - return await handler(request) - resources = self._ensure_resources(request.state) - return self._run_shell_tool( - resources, - tool_call["args"], - tool_call_id=tool_call.get("id"), - ) - - def _format_tool_message( - self, - content: str, - tool_call_id: str | None, - *, - status: Literal["success", "error"], - artifact: dict[str, Any] | None = None, - ) -> ToolMessage | str: - """Format tool responses using Claude's bash descriptor.""" - if tool_call_id is None: - return content - return ToolMessage( - content=content, - tool_call_id=tool_call_id, - name=_CLAUDE_BASH_DESCRIPTOR["name"], - status=status, - artifact=artifact or {}, - ) + """Async: replace parent's shell tool with Claude's bash descriptor.""" + filtered = [ + t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME + ] + tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}] + return await handler(request.override(tools=tools)) __all__ = ["ClaudeBashToolMiddleware"] diff --git a/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py b/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py index ca47a241dfa..3d61acf7a9f 100644 --- a/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py +++ b/libs/partners/anthropic/tests/unit_tests/middleware/test_bash.py @@ -3,67 +3,59 @@ from __future__ import annotations from unittest.mock import MagicMock import pytest -from langchain_core.messages.tool import ToolCall pytest.importorskip( "anthropic", reason="Anthropic SDK is required for Claude middleware tests" ) -from langchain.agents.middleware.types import ToolCallRequest -from langchain_core.messages import ToolMessage - from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware -def test_wrap_tool_call_handles_claude_bash(monkeypatch: pytest.MonkeyPatch) -> None: +def test_creates_bash_tool(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that ClaudeBashToolMiddleware creates a tool named 'bash'.""" middleware = ClaudeBashToolMiddleware() - sentinel = ToolMessage(content="ok", tool_call_id="call-1", name="bash") - monkeypatch.setattr(middleware, "_run_shell_tool", MagicMock(return_value=sentinel)) - monkeypatch.setattr( - middleware, "_ensure_resources", MagicMock(return_value=MagicMock()) + # Should have exactly one tool registered (from parent) + assert len(middleware.tools) == 1 + + # Tool is named "bash" (via tool_name parameter) + bash_tool = middleware.tools[0] + assert bash_tool.name == "bash" + + +def test_replaces_tool_with_claude_descriptor() -> None: + """Test wrap_model_call replaces bash tool with Claude's bash descriptor.""" + from langchain.agents.middleware.types import ModelRequest + + middleware = ClaudeBashToolMiddleware() + + # Create a mock request with the bash tool (inherited from parent) + bash_tool = middleware.tools[0] + request = ModelRequest( + model=MagicMock(), + system_prompt=None, + messages=[], + tool_choice=None, + tools=[bash_tool], + response_format=None, + state={"messages": []}, + runtime=MagicMock(), ) - tool_call: ToolCall = { + # Mock handler that captures the modified request + captured_request = None + + def handler(req: ModelRequest) -> MagicMock: + nonlocal captured_request + captured_request = req + return MagicMock() + + middleware.wrap_model_call(request, handler) + + # The bash tool should be replaced with Claude's native bash descriptor + assert captured_request is not None + assert len(captured_request.tools) == 1 + assert captured_request.tools[0] == { + "type": "bash_20250124", "name": "bash", - "args": {"command": "echo hi"}, - "id": "call-1", } - request = ToolCallRequest( - tool_call=tool_call, - tool=MagicMock(), - state={}, - runtime=None, # type: ignore[arg-type] - ) - - handler_called = False - - def handler(_: ToolCallRequest) -> ToolMessage: - nonlocal handler_called - handler_called = True - return ToolMessage(content="should not be used", tool_call_id="call-1") - - result = middleware.wrap_tool_call(request, handler) - assert result is sentinel - assert handler_called is False - - -def test_wrap_tool_call_passes_through_other_tools( - monkeypatch: pytest.MonkeyPatch, -) -> None: - middleware = ClaudeBashToolMiddleware() - tool_call: ToolCall = {"name": "other", "args": {}, "id": "call-2"} - request = ToolCallRequest( - tool_call=tool_call, - tool=MagicMock(), - state={}, - runtime=None, # type: ignore[arg-type] - ) - - sentinel = ToolMessage(content="handled", tool_call_id="call-2", name="other") - - def handler(_: ToolCallRequest) -> ToolMessage: - return sentinel - - result = middleware.wrap_tool_call(request, handler) - assert result is sentinel