mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(anthropic): execute bash + file tools via tool node (#33960)
* use `override` instead of directly patching things on `ModelRequest` * rely on `ToolNode` for execution of tools related to said middleware, using `wrap_model_call` to inject the relevant claude tool specs + allowing tool node to forward them along to corresponding langchain tool implementations * making the same change for the native shell tool middleware * allowing shell tool middleware to specify a name for the shell tool (negative diff then for claude bash middleware) long term I think the solution might be to attach metadata to a tool to map the provider spec to a langchain implementation, which we could also take some lessons from on the MCP front.
This commit is contained in:
@@ -11,7 +11,6 @@ import subprocess
|
|||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import typing
|
|
||||||
import uuid
|
import uuid
|
||||||
import weakref
|
import weakref
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@@ -19,9 +18,10 @@ from pathlib import Path
|
|||||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||||
|
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
from langchain_core.tools.base import BaseTool, ToolException
|
from langchain_core.tools.base import ToolException
|
||||||
from langgraph.channels.untracked_value import UntrackedValue
|
from langgraph.channels.untracked_value import UntrackedValue
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
from pydantic.json_schema import SkipJsonSchema
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
from langchain.agents.middleware._execution import (
|
from langchain.agents.middleware._execution import (
|
||||||
@@ -38,14 +38,13 @@ from langchain.agents.middleware._redaction import (
|
|||||||
ResolvedRedactionRule,
|
ResolvedRedactionRule,
|
||||||
)
|
)
|
||||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
|
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
|
||||||
|
from langchain.tools import ToolRuntime, tool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
|
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
from langgraph.types import Command
|
|
||||||
|
|
||||||
from langchain.agents.middleware.types import ToolCallRequest
|
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = logging.getLogger(__name__)
|
||||||
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
|
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
|
||||||
@@ -59,6 +58,7 @@ DEFAULT_TOOL_DESCRIPTION = (
|
|||||||
"session remains stable. Outputs may be truncated when they become very large, and long "
|
"session remains stable. Outputs may be truncated when they become very large, and long "
|
||||||
"running commands will be terminated once their configured timeout elapses."
|
"running commands will be terminated once their configured timeout elapses."
|
||||||
)
|
)
|
||||||
|
SHELL_TOOL_NAME = "shell"
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_resources(
|
def _cleanup_resources(
|
||||||
@@ -334,7 +334,17 @@ class _ShellToolInput(BaseModel):
|
|||||||
"""Input schema for the persistent shell tool."""
|
"""Input schema for the persistent shell tool."""
|
||||||
|
|
||||||
command: str | None = None
|
command: str | None = None
|
||||||
|
"""The shell command to execute."""
|
||||||
|
|
||||||
restart: bool | None = None
|
restart: bool | None = None
|
||||||
|
"""Whether to restart the shell session."""
|
||||||
|
|
||||||
|
runtime: Annotated[Any, SkipJsonSchema] = None
|
||||||
|
"""The runtime for the shell tool.
|
||||||
|
|
||||||
|
Included as a workaround at the moment bc args_schema doesn't work with
|
||||||
|
injected ToolRuntime.
|
||||||
|
"""
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_payload(self) -> _ShellToolInput:
|
def validate_payload(self) -> _ShellToolInput:
|
||||||
@@ -347,24 +357,6 @@ class _ShellToolInput(BaseModel):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class _PersistentShellTool(BaseTool):
|
|
||||||
"""Tool wrapper that relies on middleware interception for execution."""
|
|
||||||
|
|
||||||
name: str = "shell"
|
|
||||||
description: str = DEFAULT_TOOL_DESCRIPTION
|
|
||||||
args_schema: type[BaseModel] = _ShellToolInput
|
|
||||||
|
|
||||||
def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self._middleware = middleware
|
|
||||||
if description is not None:
|
|
||||||
self.description = description
|
|
||||||
|
|
||||||
def _run(self, **_: Any) -> Any: # pragma: no cover - executed via middleware wrapper
|
|
||||||
msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
|
|
||||||
raise RuntimeError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||||
"""Middleware that registers a persistent shell tool for agents.
|
"""Middleware that registers a persistent shell tool for agents.
|
||||||
|
|
||||||
@@ -393,6 +385,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|||||||
execution_policy: BaseExecutionPolicy | None = None,
|
execution_policy: BaseExecutionPolicy | None = None,
|
||||||
redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
|
redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
|
||||||
tool_description: str | None = None,
|
tool_description: str | None = None,
|
||||||
|
tool_name: str = SHELL_TOOL_NAME,
|
||||||
shell_command: Sequence[str] | str | None = None,
|
shell_command: Sequence[str] | str | None = None,
|
||||||
env: Mapping[str, Any] | None = None,
|
env: Mapping[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -414,6 +407,9 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|||||||
returning it to the model.
|
returning it to the model.
|
||||||
tool_description: Optional override for the registered shell tool
|
tool_description: Optional override for the registered shell tool
|
||||||
description.
|
description.
|
||||||
|
tool_name: Name for the registered shell tool.
|
||||||
|
|
||||||
|
Defaults to `"shell"`.
|
||||||
shell_command: Optional shell executable (string) or argument sequence used
|
shell_command: Optional shell executable (string) or argument sequence used
|
||||||
to launch the persistent session.
|
to launch the persistent session.
|
||||||
|
|
||||||
@@ -425,6 +421,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._workspace_root = Path(workspace_root) if workspace_root else None
|
self._workspace_root = Path(workspace_root) if workspace_root else None
|
||||||
|
self._tool_name = tool_name
|
||||||
self._shell_command = self._normalize_shell_command(shell_command)
|
self._shell_command = self._normalize_shell_command(shell_command)
|
||||||
self._environment = self._normalize_env(env)
|
self._environment = self._normalize_env(env)
|
||||||
if execution_policy is not None:
|
if execution_policy is not None:
|
||||||
@@ -438,9 +435,25 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|||||||
self._startup_commands = self._normalize_commands(startup_commands)
|
self._startup_commands = self._normalize_commands(startup_commands)
|
||||||
self._shutdown_commands = self._normalize_commands(shutdown_commands)
|
self._shutdown_commands = self._normalize_commands(shutdown_commands)
|
||||||
|
|
||||||
|
# Create a proper tool that executes directly (no interception needed)
|
||||||
description = tool_description or DEFAULT_TOOL_DESCRIPTION
|
description = tool_description or DEFAULT_TOOL_DESCRIPTION
|
||||||
self._tool = _PersistentShellTool(self, description=description)
|
|
||||||
self.tools = [self._tool]
|
@tool(self._tool_name, args_schema=_ShellToolInput, description=description)
|
||||||
|
def shell_tool(
|
||||||
|
*,
|
||||||
|
runtime: ToolRuntime[None, ShellToolState],
|
||||||
|
command: str | None = None,
|
||||||
|
restart: bool = False,
|
||||||
|
) -> ToolMessage | str:
|
||||||
|
resources = self._ensure_resources(runtime.state)
|
||||||
|
return self._run_shell_tool(
|
||||||
|
resources,
|
||||||
|
{"command": command, "restart": restart},
|
||||||
|
tool_call_id=runtime.tool_call_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._shell_tool = shell_tool
|
||||||
|
self.tools = [self._shell_tool]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _normalize_commands(
|
def _normalize_commands(
|
||||||
@@ -669,37 +682,6 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|||||||
artifact=artifact,
|
artifact=artifact,
|
||||||
)
|
)
|
||||||
|
|
||||||
def wrap_tool_call(
|
|
||||||
self,
|
|
||||||
request: ToolCallRequest,
|
|
||||||
handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
|
|
||||||
) -> ToolMessage | Command:
|
|
||||||
"""Intercept local shell tool calls and execute them via the managed session."""
|
|
||||||
if isinstance(request.tool, _PersistentShellTool):
|
|
||||||
resources = self._ensure_resources(request.state)
|
|
||||||
return self._run_shell_tool(
|
|
||||||
resources,
|
|
||||||
request.tool_call["args"],
|
|
||||||
tool_call_id=request.tool_call.get("id"),
|
|
||||||
)
|
|
||||||
return handler(request)
|
|
||||||
|
|
||||||
async def awrap_tool_call(
|
|
||||||
self,
|
|
||||||
request: ToolCallRequest,
|
|
||||||
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
|
|
||||||
) -> ToolMessage | Command:
|
|
||||||
"""Async intercept local shell tool calls and execute them via the managed session."""
|
|
||||||
# The sync version already handles all the work, no need for async-specific logic
|
|
||||||
if isinstance(request.tool, _PersistentShellTool):
|
|
||||||
resources = self._ensure_resources(request.state)
|
|
||||||
return self._run_shell_tool(
|
|
||||||
resources,
|
|
||||||
request.tool_call["args"],
|
|
||||||
tool_call_id=request.tool_call.get("id"),
|
|
||||||
)
|
|
||||||
return await handler(request)
|
|
||||||
|
|
||||||
def _format_tool_message(
|
def _format_tool_message(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
@@ -714,7 +696,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
|||||||
return ToolMessage(
|
return ToolMessage(
|
||||||
content=content,
|
content=content,
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
name=self._tool.name,
|
name=self._tool_name,
|
||||||
status=status,
|
status=status,
|
||||||
artifact=artifact,
|
artifact=artifact,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,9 @@ from langchain.agents.middleware.types import (
|
|||||||
AgentState,
|
AgentState,
|
||||||
ModelRequest,
|
ModelRequest,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
|
_ModelRequestOverrides,
|
||||||
)
|
)
|
||||||
|
from langchain.tools import ToolRuntime, tool
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
from typing_extensions import NotRequired, TypedDict
|
from typing_extensions import NotRequired, TypedDict
|
||||||
@@ -25,7 +27,6 @@ from typing_extensions import NotRequired, TypedDict
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Awaitable, Callable, Sequence
|
from collections.abc import Awaitable, Callable, Sequence
|
||||||
|
|
||||||
from langchain.agents.middleware.types import ToolCallRequest
|
|
||||||
|
|
||||||
# Tool type constants
|
# Tool type constants
|
||||||
TEXT_EDITOR_TOOL_TYPE = "text_editor_20250728"
|
TEXT_EDITOR_TOOL_TYPE = "text_editor_20250728"
|
||||||
@@ -184,149 +185,127 @@ class _StateClaudeFileToolMiddleware(AgentMiddleware):
|
|||||||
self.allowed_prefixes = allowed_path_prefixes
|
self.allowed_prefixes = allowed_path_prefixes
|
||||||
self.system_prompt = system_prompt
|
self.system_prompt = system_prompt
|
||||||
|
|
||||||
|
# Create tool that will be executed by the tool node
|
||||||
|
@tool(tool_name)
|
||||||
|
def file_tool(
|
||||||
|
runtime: ToolRuntime[None, AnthropicToolsState],
|
||||||
|
command: str,
|
||||||
|
path: str,
|
||||||
|
file_text: str | None = None,
|
||||||
|
old_str: str | None = None,
|
||||||
|
new_str: str | None = None,
|
||||||
|
insert_line: int | None = None,
|
||||||
|
new_path: str | None = None,
|
||||||
|
view_range: list[int] | None = None,
|
||||||
|
) -> Command | str:
|
||||||
|
"""Execute file operations on virtual file system.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runtime: Tool runtime providing access to state.
|
||||||
|
command: Operation to perform.
|
||||||
|
path: File path to operate on.
|
||||||
|
file_text: Full file content for create command.
|
||||||
|
old_str: String to replace for str_replace command.
|
||||||
|
new_str: Replacement string for str_replace command.
|
||||||
|
insert_line: Line number for insert command.
|
||||||
|
new_path: New path for rename command.
|
||||||
|
view_range: Line range [start, end] for view command.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Command for state update or string result.
|
||||||
|
"""
|
||||||
|
# Build args dict for handler methods
|
||||||
|
args: dict[str, Any] = {"path": path}
|
||||||
|
if file_text is not None:
|
||||||
|
args["file_text"] = file_text
|
||||||
|
if old_str is not None:
|
||||||
|
args["old_str"] = old_str
|
||||||
|
if new_str is not None:
|
||||||
|
args["new_str"] = new_str
|
||||||
|
if insert_line is not None:
|
||||||
|
args["insert_line"] = insert_line
|
||||||
|
if new_path is not None:
|
||||||
|
args["new_path"] = new_path
|
||||||
|
if view_range is not None:
|
||||||
|
args["view_range"] = view_range
|
||||||
|
|
||||||
|
# Route to appropriate handler based on command
|
||||||
|
try:
|
||||||
|
if command == "view":
|
||||||
|
return self._handle_view(args, runtime.state, runtime.tool_call_id)
|
||||||
|
if command == "create":
|
||||||
|
return self._handle_create(
|
||||||
|
args, runtime.state, runtime.tool_call_id
|
||||||
|
)
|
||||||
|
if command == "str_replace":
|
||||||
|
return self._handle_str_replace(
|
||||||
|
args, runtime.state, runtime.tool_call_id
|
||||||
|
)
|
||||||
|
if command == "insert":
|
||||||
|
return self._handle_insert(
|
||||||
|
args, runtime.state, runtime.tool_call_id
|
||||||
|
)
|
||||||
|
if command == "delete":
|
||||||
|
return self._handle_delete(
|
||||||
|
args, runtime.state, runtime.tool_call_id
|
||||||
|
)
|
||||||
|
if command == "rename":
|
||||||
|
return self._handle_rename(
|
||||||
|
args, runtime.state, runtime.tool_call_id
|
||||||
|
)
|
||||||
|
return f"Unknown command: {command}"
|
||||||
|
except (ValueError, FileNotFoundError) as e:
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
self.tools = [file_tool]
|
||||||
|
|
||||||
def wrap_model_call(
|
def wrap_model_call(
|
||||||
self,
|
self,
|
||||||
request: ModelRequest,
|
request: ModelRequest,
|
||||||
handler: Callable[[ModelRequest], ModelResponse],
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""Inject tool and optional system prompt."""
|
"""Inject Anthropic tool descriptor and optional system prompt."""
|
||||||
# Add tool
|
# Replace our BaseTool with Anthropic's native tool descriptor
|
||||||
tools = list(request.tools or [])
|
tools = [
|
||||||
tools.append(
|
t
|
||||||
{
|
for t in (request.tools or [])
|
||||||
"type": self.tool_type,
|
if getattr(t, "name", None) != self.tool_name
|
||||||
"name": self.tool_name,
|
] + [{"type": self.tool_type, "name": self.tool_name}]
|
||||||
}
|
|
||||||
)
|
|
||||||
request.tools = tools
|
|
||||||
|
|
||||||
# Inject system prompt if provided
|
# Inject system prompt if provided
|
||||||
|
overrides: _ModelRequestOverrides = {"tools": tools}
|
||||||
if self.system_prompt:
|
if self.system_prompt:
|
||||||
request.system_prompt = (
|
overrides["system_prompt"] = (
|
||||||
request.system_prompt + "\n\n" + self.system_prompt
|
request.system_prompt + "\n\n" + self.system_prompt
|
||||||
if request.system_prompt
|
if request.system_prompt
|
||||||
else self.system_prompt
|
else self.system_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
return handler(request)
|
return handler(request.override(**overrides))
|
||||||
|
|
||||||
async def awrap_model_call(
|
async def awrap_model_call(
|
||||||
self,
|
self,
|
||||||
request: ModelRequest,
|
request: ModelRequest,
|
||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""Inject tool and optional system prompt (async version)."""
|
"""Inject Anthropic tool descriptor and optional system prompt."""
|
||||||
# Add tool
|
# Replace our BaseTool with Anthropic's native tool descriptor
|
||||||
tools = list(request.tools or [])
|
tools = [
|
||||||
tools.append(
|
t
|
||||||
{
|
for t in (request.tools or [])
|
||||||
"type": self.tool_type,
|
if getattr(t, "name", None) != self.tool_name
|
||||||
"name": self.tool_name,
|
] + [{"type": self.tool_type, "name": self.tool_name}]
|
||||||
}
|
|
||||||
)
|
|
||||||
request.tools = tools
|
|
||||||
|
|
||||||
# Inject system prompt if provided
|
# Inject system prompt if provided
|
||||||
|
overrides: _ModelRequestOverrides = {"tools": tools}
|
||||||
if self.system_prompt:
|
if self.system_prompt:
|
||||||
request.system_prompt = (
|
overrides["system_prompt"] = (
|
||||||
request.system_prompt + "\n\n" + self.system_prompt
|
request.system_prompt + "\n\n" + self.system_prompt
|
||||||
if request.system_prompt
|
if request.system_prompt
|
||||||
else self.system_prompt
|
else self.system_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
return await handler(request)
|
return await handler(request.override(**overrides))
|
||||||
|
|
||||||
def wrap_tool_call(
|
|
||||||
self,
|
|
||||||
request: ToolCallRequest,
|
|
||||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
||||||
) -> ToolMessage | Command:
|
|
||||||
"""Intercept tool calls."""
|
|
||||||
tool_call = request.tool_call
|
|
||||||
tool_name = tool_call.get("name")
|
|
||||||
|
|
||||||
if tool_name != self.tool_name:
|
|
||||||
return handler(request)
|
|
||||||
|
|
||||||
# Handle tool call
|
|
||||||
try:
|
|
||||||
args = tool_call.get("args", {})
|
|
||||||
command = args.get("command")
|
|
||||||
state = request.state
|
|
||||||
|
|
||||||
if command == "view":
|
|
||||||
return self._handle_view(args, state, tool_call["id"])
|
|
||||||
if command == "create":
|
|
||||||
return self._handle_create(args, state, tool_call["id"])
|
|
||||||
if command == "str_replace":
|
|
||||||
return self._handle_str_replace(args, state, tool_call["id"])
|
|
||||||
if command == "insert":
|
|
||||||
return self._handle_insert(args, state, tool_call["id"])
|
|
||||||
if command == "delete":
|
|
||||||
return self._handle_delete(args, state, tool_call["id"])
|
|
||||||
if command == "rename":
|
|
||||||
return self._handle_rename(args, state, tool_call["id"])
|
|
||||||
|
|
||||||
msg = f"Unknown command: {command}"
|
|
||||||
return ToolMessage(
|
|
||||||
content=msg,
|
|
||||||
tool_call_id=tool_call["id"],
|
|
||||||
name=tool_name,
|
|
||||||
status="error",
|
|
||||||
)
|
|
||||||
except (ValueError, FileNotFoundError) as e:
|
|
||||||
return ToolMessage(
|
|
||||||
content=str(e),
|
|
||||||
tool_call_id=tool_call["id"],
|
|
||||||
name=tool_name,
|
|
||||||
status="error",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def awrap_tool_call(
|
|
||||||
self,
|
|
||||||
request: ToolCallRequest,
|
|
||||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
||||||
) -> ToolMessage | Command:
|
|
||||||
"""Intercept tool calls (async version)."""
|
|
||||||
tool_call = request.tool_call
|
|
||||||
tool_name = tool_call.get("name")
|
|
||||||
|
|
||||||
if tool_name != self.tool_name:
|
|
||||||
return await handler(request)
|
|
||||||
|
|
||||||
# Handle tool call
|
|
||||||
try:
|
|
||||||
args = tool_call.get("args", {})
|
|
||||||
command = args.get("command")
|
|
||||||
state = request.state
|
|
||||||
|
|
||||||
if command == "view":
|
|
||||||
return self._handle_view(args, state, tool_call["id"])
|
|
||||||
if command == "create":
|
|
||||||
return self._handle_create(args, state, tool_call["id"])
|
|
||||||
if command == "str_replace":
|
|
||||||
return self._handle_str_replace(args, state, tool_call["id"])
|
|
||||||
if command == "insert":
|
|
||||||
return self._handle_insert(args, state, tool_call["id"])
|
|
||||||
if command == "delete":
|
|
||||||
return self._handle_delete(args, state, tool_call["id"])
|
|
||||||
if command == "rename":
|
|
||||||
return self._handle_rename(args, state, tool_call["id"])
|
|
||||||
|
|
||||||
msg = f"Unknown command: {command}"
|
|
||||||
return ToolMessage(
|
|
||||||
content=msg,
|
|
||||||
tool_call_id=tool_call["id"],
|
|
||||||
name=tool_name,
|
|
||||||
status="error",
|
|
||||||
)
|
|
||||||
except (ValueError, FileNotFoundError) as e:
|
|
||||||
return ToolMessage(
|
|
||||||
content=str(e),
|
|
||||||
tool_call_id=tool_call["id"],
|
|
||||||
name=tool_name,
|
|
||||||
status="error",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_view(
|
def _handle_view(
|
||||||
self, args: dict, state: AnthropicToolsState, tool_call_id: str | None
|
self, args: dict, state: AnthropicToolsState, tool_call_id: str | None
|
||||||
@@ -692,146 +671,117 @@ class _FilesystemClaudeFileToolMiddleware(AgentMiddleware):
|
|||||||
# Create root directory if it doesn't exist
|
# Create root directory if it doesn't exist
|
||||||
self.root_path.mkdir(parents=True, exist_ok=True)
|
self.root_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Create tool that will be executed by the tool node
|
||||||
|
@tool(tool_name)
|
||||||
|
def file_tool(
|
||||||
|
runtime: ToolRuntime,
|
||||||
|
command: str,
|
||||||
|
path: str,
|
||||||
|
file_text: str | None = None,
|
||||||
|
old_str: str | None = None,
|
||||||
|
new_str: str | None = None,
|
||||||
|
insert_line: int | None = None,
|
||||||
|
new_path: str | None = None,
|
||||||
|
view_range: list[int] | None = None,
|
||||||
|
) -> Command | str:
|
||||||
|
"""Execute file operations on filesystem.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runtime: Tool runtime providing tool_call_id.
|
||||||
|
command: Operation to perform.
|
||||||
|
path: File path to operate on.
|
||||||
|
file_text: Full file content for create command.
|
||||||
|
old_str: String to replace for str_replace command.
|
||||||
|
new_str: Replacement string for str_replace command.
|
||||||
|
insert_line: Line number for insert command.
|
||||||
|
new_path: New path for rename command.
|
||||||
|
view_range: Line range [start, end] for view command.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Command for message update or string result.
|
||||||
|
"""
|
||||||
|
# Build args dict for handler methods
|
||||||
|
args: dict[str, Any] = {"path": path}
|
||||||
|
if file_text is not None:
|
||||||
|
args["file_text"] = file_text
|
||||||
|
if old_str is not None:
|
||||||
|
args["old_str"] = old_str
|
||||||
|
if new_str is not None:
|
||||||
|
args["new_str"] = new_str
|
||||||
|
if insert_line is not None:
|
||||||
|
args["insert_line"] = insert_line
|
||||||
|
if new_path is not None:
|
||||||
|
args["new_path"] = new_path
|
||||||
|
if view_range is not None:
|
||||||
|
args["view_range"] = view_range
|
||||||
|
|
||||||
|
# Route to appropriate handler based on command
|
||||||
|
try:
|
||||||
|
if command == "view":
|
||||||
|
return self._handle_view(args, runtime.tool_call_id)
|
||||||
|
if command == "create":
|
||||||
|
return self._handle_create(args, runtime.tool_call_id)
|
||||||
|
if command == "str_replace":
|
||||||
|
return self._handle_str_replace(args, runtime.tool_call_id)
|
||||||
|
if command == "insert":
|
||||||
|
return self._handle_insert(args, runtime.tool_call_id)
|
||||||
|
if command == "delete":
|
||||||
|
return self._handle_delete(args, runtime.tool_call_id)
|
||||||
|
if command == "rename":
|
||||||
|
return self._handle_rename(args, runtime.tool_call_id)
|
||||||
|
return f"Unknown command: {command}"
|
||||||
|
except (ValueError, FileNotFoundError, PermissionError) as e:
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
self.tools = [file_tool]
|
||||||
|
|
||||||
def wrap_model_call(
|
def wrap_model_call(
|
||||||
self,
|
self,
|
||||||
request: ModelRequest,
|
request: ModelRequest,
|
||||||
handler: Callable[[ModelRequest], ModelResponse],
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""Inject tool and optional system prompt."""
|
"""Inject Anthropic tool descriptor and optional system prompt."""
|
||||||
# Add tool
|
# Replace our BaseTool with Anthropic's native tool descriptor
|
||||||
tools = list(request.tools or [])
|
tools = [
|
||||||
tools.append(
|
t
|
||||||
{
|
for t in (request.tools or [])
|
||||||
"type": self.tool_type,
|
if getattr(t, "name", None) != self.tool_name
|
||||||
"name": self.tool_name,
|
] + [{"type": self.tool_type, "name": self.tool_name}]
|
||||||
}
|
|
||||||
)
|
|
||||||
request.tools = tools
|
|
||||||
|
|
||||||
# Inject system prompt if provided
|
# Inject system prompt if provided
|
||||||
|
overrides: _ModelRequestOverrides = {"tools": tools}
|
||||||
if self.system_prompt:
|
if self.system_prompt:
|
||||||
request.system_prompt = (
|
overrides["system_prompt"] = (
|
||||||
request.system_prompt + "\n\n" + self.system_prompt
|
request.system_prompt + "\n\n" + self.system_prompt
|
||||||
if request.system_prompt
|
if request.system_prompt
|
||||||
else self.system_prompt
|
else self.system_prompt
|
||||||
)
|
)
|
||||||
return handler(request)
|
|
||||||
|
return handler(request.override(**overrides))
|
||||||
|
|
||||||
async def awrap_model_call(
|
async def awrap_model_call(
|
||||||
self,
|
self,
|
||||||
request: ModelRequest,
|
request: ModelRequest,
|
||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""Inject tool and optional system prompt (async version)."""
|
"""Inject Anthropic tool descriptor and optional system prompt."""
|
||||||
# Add tool
|
# Replace our BaseTool with Anthropic's native tool descriptor
|
||||||
tools = list(request.tools or [])
|
tools = [
|
||||||
tools.append(
|
t
|
||||||
{
|
for t in (request.tools or [])
|
||||||
"type": self.tool_type,
|
if getattr(t, "name", None) != self.tool_name
|
||||||
"name": self.tool_name,
|
] + [{"type": self.tool_type, "name": self.tool_name}]
|
||||||
}
|
|
||||||
)
|
|
||||||
request.tools = tools
|
|
||||||
|
|
||||||
# Inject system prompt if provided
|
# Inject system prompt if provided
|
||||||
|
overrides: _ModelRequestOverrides = {"tools": tools}
|
||||||
if self.system_prompt:
|
if self.system_prompt:
|
||||||
request.system_prompt = (
|
overrides["system_prompt"] = (
|
||||||
request.system_prompt + "\n\n" + self.system_prompt
|
request.system_prompt + "\n\n" + self.system_prompt
|
||||||
if request.system_prompt
|
if request.system_prompt
|
||||||
else self.system_prompt
|
else self.system_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
return await handler(request)
|
return await handler(request.override(**overrides))
|
||||||
|
|
||||||
def wrap_tool_call(
|
|
||||||
self,
|
|
||||||
request: ToolCallRequest,
|
|
||||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
|
||||||
) -> ToolMessage | Command:
|
|
||||||
"""Intercept tool calls."""
|
|
||||||
tool_call = request.tool_call
|
|
||||||
tool_name = tool_call.get("name")
|
|
||||||
|
|
||||||
if tool_name != self.tool_name:
|
|
||||||
return handler(request)
|
|
||||||
|
|
||||||
# Handle tool call
|
|
||||||
try:
|
|
||||||
args = tool_call.get("args", {})
|
|
||||||
command = args.get("command")
|
|
||||||
|
|
||||||
if command == "view":
|
|
||||||
return self._handle_view(args, tool_call["id"])
|
|
||||||
if command == "create":
|
|
||||||
return self._handle_create(args, tool_call["id"])
|
|
||||||
if command == "str_replace":
|
|
||||||
return self._handle_str_replace(args, tool_call["id"])
|
|
||||||
if command == "insert":
|
|
||||||
return self._handle_insert(args, tool_call["id"])
|
|
||||||
if command == "delete":
|
|
||||||
return self._handle_delete(args, tool_call["id"])
|
|
||||||
if command == "rename":
|
|
||||||
return self._handle_rename(args, tool_call["id"])
|
|
||||||
|
|
||||||
msg = f"Unknown command: {command}"
|
|
||||||
return ToolMessage(
|
|
||||||
content=msg,
|
|
||||||
tool_call_id=tool_call["id"],
|
|
||||||
name=tool_name,
|
|
||||||
status="error",
|
|
||||||
)
|
|
||||||
except (ValueError, FileNotFoundError) as e:
|
|
||||||
return ToolMessage(
|
|
||||||
content=str(e),
|
|
||||||
tool_call_id=tool_call["id"],
|
|
||||||
name=tool_name,
|
|
||||||
status="error",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def awrap_tool_call(
|
|
||||||
self,
|
|
||||||
request: ToolCallRequest,
|
|
||||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
||||||
) -> ToolMessage | Command:
|
|
||||||
"""Intercept tool calls (async version)."""
|
|
||||||
tool_call = request.tool_call
|
|
||||||
tool_name = tool_call.get("name")
|
|
||||||
|
|
||||||
if tool_name != self.tool_name:
|
|
||||||
return await handler(request)
|
|
||||||
|
|
||||||
# Handle tool call
|
|
||||||
try:
|
|
||||||
args = tool_call.get("args", {})
|
|
||||||
command = args.get("command")
|
|
||||||
|
|
||||||
if command == "view":
|
|
||||||
return self._handle_view(args, tool_call["id"])
|
|
||||||
if command == "create":
|
|
||||||
return self._handle_create(args, tool_call["id"])
|
|
||||||
if command == "str_replace":
|
|
||||||
return self._handle_str_replace(args, tool_call["id"])
|
|
||||||
if command == "insert":
|
|
||||||
return self._handle_insert(args, tool_call["id"])
|
|
||||||
if command == "delete":
|
|
||||||
return self._handle_delete(args, tool_call["id"])
|
|
||||||
if command == "rename":
|
|
||||||
return self._handle_rename(args, tool_call["id"])
|
|
||||||
|
|
||||||
msg = f"Unknown command: {command}"
|
|
||||||
return ToolMessage(
|
|
||||||
content=msg,
|
|
||||||
tool_call_id=tool_call["id"],
|
|
||||||
name=tool_name,
|
|
||||||
status="error",
|
|
||||||
)
|
|
||||||
except (ValueError, FileNotFoundError) as e:
|
|
||||||
return ToolMessage(
|
|
||||||
content=str(e),
|
|
||||||
tool_call_id=tool_call["id"],
|
|
||||||
name=tool_name,
|
|
||||||
status="error",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _validate_and_resolve_path(self, path: str) -> Path:
|
def _validate_and_resolve_path(self, path: str) -> Path:
|
||||||
"""Validate and resolve a virtual path to filesystem path.
|
"""Validate and resolve a virtual path to filesystem path.
|
||||||
|
|||||||
@@ -3,105 +3,81 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any, Literal
|
from typing import Any
|
||||||
|
|
||||||
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
|
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
|
||||||
from langchain.agents.middleware.types import (
|
from langchain.agents.middleware.types import (
|
||||||
ModelRequest,
|
ModelRequest,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ToolCallRequest,
|
|
||||||
)
|
)
|
||||||
from langchain_core.messages import ToolMessage
|
|
||||||
from langgraph.types import Command
|
|
||||||
|
|
||||||
_CLAUDE_BASH_DESCRIPTOR = {"type": "bash_20250124", "name": "bash"}
|
# Tool type constants for Anthropic
|
||||||
|
BASH_TOOL_TYPE = "bash_20250124"
|
||||||
|
BASH_TOOL_NAME = "bash"
|
||||||
|
|
||||||
|
|
||||||
class ClaudeBashToolMiddleware(ShellToolMiddleware):
|
class ClaudeBashToolMiddleware(ShellToolMiddleware):
|
||||||
"""Middleware that exposes Anthropic's native bash tool to models."""
|
"""Middleware that exposes Anthropic's native bash tool to models."""
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
def __init__(
|
||||||
"""Initialize middleware without registering a client-side tool."""
|
self,
|
||||||
kwargs["shell_command"] = ("/bin/bash",)
|
workspace_root: str | None = None,
|
||||||
super().__init__(*args, **kwargs)
|
*,
|
||||||
# Remove the base tool so Claude's native descriptor is the sole entry.
|
startup_commands: tuple[str, ...] | list[str] | str | None = None,
|
||||||
self._tool = None # type: ignore[assignment]
|
shutdown_commands: tuple[str, ...] | list[str] | str | None = None,
|
||||||
self.tools = []
|
execution_policy: Any | None = None,
|
||||||
|
redaction_rules: tuple[Any, ...] | list[Any] | None = None,
|
||||||
|
tool_description: str | None = None,
|
||||||
|
env: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize middleware for Claude's native bash tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_root: Base directory for the shell session.
|
||||||
|
If omitted, a temporary directory is created.
|
||||||
|
startup_commands: Optional commands executed after the session starts.
|
||||||
|
shutdown_commands: Optional commands executed before session shutdown.
|
||||||
|
execution_policy: Execution policy controlling timeouts and limits.
|
||||||
|
redaction_rules: Optional redaction rules to sanitize output.
|
||||||
|
tool_description: Optional override for tool description.
|
||||||
|
env: Optional environment variables for the shell session.
|
||||||
|
"""
|
||||||
|
super().__init__(
|
||||||
|
workspace_root=workspace_root,
|
||||||
|
startup_commands=startup_commands,
|
||||||
|
shutdown_commands=shutdown_commands,
|
||||||
|
execution_policy=execution_policy,
|
||||||
|
redaction_rules=redaction_rules,
|
||||||
|
tool_description=tool_description,
|
||||||
|
tool_name=BASH_TOOL_NAME,
|
||||||
|
shell_command=("/bin/bash",),
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
# Parent class now creates the tool with name "bash" via tool_name parameter
|
||||||
|
|
||||||
def wrap_model_call(
|
def wrap_model_call(
|
||||||
self,
|
self,
|
||||||
request: ModelRequest,
|
request: ModelRequest,
|
||||||
handler: Callable[[ModelRequest], ModelResponse],
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""Ensure the Claude bash descriptor is available to the model."""
|
"""Replace parent's shell tool with Claude's bash descriptor."""
|
||||||
tools = request.tools
|
filtered = [
|
||||||
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools):
|
t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME
|
||||||
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR]
|
]
|
||||||
request = request.override(tools=tools)
|
tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}]
|
||||||
return handler(request)
|
return handler(request.override(tools=tools))
|
||||||
|
|
||||||
async def awrap_model_call(
|
async def awrap_model_call(
|
||||||
self,
|
self,
|
||||||
request: ModelRequest,
|
request: ModelRequest,
|
||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
) -> ModelResponse:
|
) -> ModelResponse:
|
||||||
"""Async: ensure the Claude bash descriptor is available to the model."""
|
"""Async: replace parent's shell tool with Claude's bash descriptor."""
|
||||||
tools = request.tools
|
filtered = [
|
||||||
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools):
|
t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME
|
||||||
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR]
|
]
|
||||||
request = request.override(tools=tools)
|
tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}]
|
||||||
return await handler(request)
|
return await handler(request.override(tools=tools))
|
||||||
|
|
||||||
def wrap_tool_call(
|
|
||||||
self,
|
|
||||||
request: ToolCallRequest,
|
|
||||||
handler: Callable[[ToolCallRequest], Command | ToolMessage],
|
|
||||||
) -> Command | ToolMessage:
|
|
||||||
"""Intercept Claude bash tool calls and execute them locally."""
|
|
||||||
tool_call = request.tool_call
|
|
||||||
if tool_call.get("name") != "bash":
|
|
||||||
return handler(request)
|
|
||||||
resources = self._ensure_resources(request.state)
|
|
||||||
return self._run_shell_tool(
|
|
||||||
resources,
|
|
||||||
tool_call["args"],
|
|
||||||
tool_call_id=tool_call.get("id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def awrap_tool_call(
|
|
||||||
self,
|
|
||||||
request: ToolCallRequest,
|
|
||||||
handler: Callable[[ToolCallRequest], Awaitable[Command | ToolMessage]],
|
|
||||||
) -> Command | ToolMessage:
|
|
||||||
"""Async interception mirroring the synchronous implementation."""
|
|
||||||
tool_call = request.tool_call
|
|
||||||
if tool_call.get("name") != "bash":
|
|
||||||
return await handler(request)
|
|
||||||
resources = self._ensure_resources(request.state)
|
|
||||||
return self._run_shell_tool(
|
|
||||||
resources,
|
|
||||||
tool_call["args"],
|
|
||||||
tool_call_id=tool_call.get("id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _format_tool_message(
|
|
||||||
self,
|
|
||||||
content: str,
|
|
||||||
tool_call_id: str | None,
|
|
||||||
*,
|
|
||||||
status: Literal["success", "error"],
|
|
||||||
artifact: dict[str, Any] | None = None,
|
|
||||||
) -> ToolMessage | str:
|
|
||||||
"""Format tool responses using Claude's bash descriptor."""
|
|
||||||
if tool_call_id is None:
|
|
||||||
return content
|
|
||||||
return ToolMessage(
|
|
||||||
content=content,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
name=_CLAUDE_BASH_DESCRIPTOR["name"],
|
|
||||||
status=status,
|
|
||||||
artifact=artifact or {},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["ClaudeBashToolMiddleware"]
|
__all__ = ["ClaudeBashToolMiddleware"]
|
||||||
|
|||||||
@@ -3,67 +3,59 @@ from __future__ import annotations
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages.tool import ToolCall
|
|
||||||
|
|
||||||
pytest.importorskip(
|
pytest.importorskip(
|
||||||
"anthropic", reason="Anthropic SDK is required for Claude middleware tests"
|
"anthropic", reason="Anthropic SDK is required for Claude middleware tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.agents.middleware.types import ToolCallRequest
|
|
||||||
from langchain_core.messages import ToolMessage
|
|
||||||
|
|
||||||
from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware
|
from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware
|
||||||
|
|
||||||
|
|
||||||
def test_wrap_tool_call_handles_claude_bash(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_creates_bash_tool(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""Test that ClaudeBashToolMiddleware creates a tool named 'bash'."""
|
||||||
middleware = ClaudeBashToolMiddleware()
|
middleware = ClaudeBashToolMiddleware()
|
||||||
sentinel = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
|
|
||||||
|
|
||||||
monkeypatch.setattr(middleware, "_run_shell_tool", MagicMock(return_value=sentinel))
|
# Should have exactly one tool registered (from parent)
|
||||||
monkeypatch.setattr(
|
assert len(middleware.tools) == 1
|
||||||
middleware, "_ensure_resources", MagicMock(return_value=MagicMock())
|
|
||||||
|
# Tool is named "bash" (via tool_name parameter)
|
||||||
|
bash_tool = middleware.tools[0]
|
||||||
|
assert bash_tool.name == "bash"
|
||||||
|
|
||||||
|
|
||||||
|
def test_replaces_tool_with_claude_descriptor() -> None:
|
||||||
|
"""Test wrap_model_call replaces bash tool with Claude's bash descriptor."""
|
||||||
|
from langchain.agents.middleware.types import ModelRequest
|
||||||
|
|
||||||
|
middleware = ClaudeBashToolMiddleware()
|
||||||
|
|
||||||
|
# Create a mock request with the bash tool (inherited from parent)
|
||||||
|
bash_tool = middleware.tools[0]
|
||||||
|
request = ModelRequest(
|
||||||
|
model=MagicMock(),
|
||||||
|
system_prompt=None,
|
||||||
|
messages=[],
|
||||||
|
tool_choice=None,
|
||||||
|
tools=[bash_tool],
|
||||||
|
response_format=None,
|
||||||
|
state={"messages": []},
|
||||||
|
runtime=MagicMock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_call: ToolCall = {
|
# Mock handler that captures the modified request
|
||||||
|
captured_request = None
|
||||||
|
|
||||||
|
def handler(req: ModelRequest) -> MagicMock:
|
||||||
|
nonlocal captured_request
|
||||||
|
captured_request = req
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
middleware.wrap_model_call(request, handler)
|
||||||
|
|
||||||
|
# The bash tool should be replaced with Claude's native bash descriptor
|
||||||
|
assert captured_request is not None
|
||||||
|
assert len(captured_request.tools) == 1
|
||||||
|
assert captured_request.tools[0] == {
|
||||||
|
"type": "bash_20250124",
|
||||||
"name": "bash",
|
"name": "bash",
|
||||||
"args": {"command": "echo hi"},
|
|
||||||
"id": "call-1",
|
|
||||||
}
|
}
|
||||||
request = ToolCallRequest(
|
|
||||||
tool_call=tool_call,
|
|
||||||
tool=MagicMock(),
|
|
||||||
state={},
|
|
||||||
runtime=None, # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
|
|
||||||
handler_called = False
|
|
||||||
|
|
||||||
def handler(_: ToolCallRequest) -> ToolMessage:
|
|
||||||
nonlocal handler_called
|
|
||||||
handler_called = True
|
|
||||||
return ToolMessage(content="should not be used", tool_call_id="call-1")
|
|
||||||
|
|
||||||
result = middleware.wrap_tool_call(request, handler)
|
|
||||||
assert result is sentinel
|
|
||||||
assert handler_called is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_wrap_tool_call_passes_through_other_tools(
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
middleware = ClaudeBashToolMiddleware()
|
|
||||||
tool_call: ToolCall = {"name": "other", "args": {}, "id": "call-2"}
|
|
||||||
request = ToolCallRequest(
|
|
||||||
tool_call=tool_call,
|
|
||||||
tool=MagicMock(),
|
|
||||||
state={},
|
|
||||||
runtime=None, # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
|
|
||||||
sentinel = ToolMessage(content="handled", tool_call_id="call-2", name="other")
|
|
||||||
|
|
||||||
def handler(_: ToolCallRequest) -> ToolMessage:
|
|
||||||
return sentinel
|
|
||||||
|
|
||||||
result = middleware.wrap_tool_call(request, handler)
|
|
||||||
assert result is sentinel
|
|
||||||
|
|||||||
Reference in New Issue
Block a user