Compare commits

...

3 Commits

Author SHA1 Message Date
Sydney Runkle
b67cd71d7d release: langchain 1.0.8 (#34019) 2025-11-19 09:12:37 -05:00
Sydney Runkle
e150b7c7e3 release: langchain 1.0.7 (#33979)
support resumable (ensure or create) resources for shell middleware
2025-11-14 15:50:11 -05:00
Sydney Runkle
ee3fc91e7a fix: cherry picking fixes for langchain + langchain-anthropic releases (#33975)
Co-authored-by: ccurme <chester.curme@gmail.com>
2025-11-14 13:28:30 -05:00
19 changed files with 1236 additions and 541 deletions

View File

@@ -9,6 +9,8 @@ on:
paths:
- "libs/core/pyproject.toml"
- "libs/core/langchain_core/version.py"
- "libs/langchain_v1/pyproject.toml"
- "libs/langchain_v1/langchain/__init__.py"
permissions:
contents: read
@@ -20,7 +22,7 @@ jobs:
steps:
- uses: actions/checkout@v5
- name: "✅ Verify pyproject.toml & version.py Match"
- name: "✅ Verify pyproject.toml & version files Match"
run: |
# Check core versions
CORE_PYPROJECT_VERSION=$(grep -Po '(?<=^version = ")[^"]*' libs/core/pyproject.toml)

View File

@@ -1,3 +1,3 @@
"""Main entrypoint into LangChain."""
__version__ = "1.0.5"
__version__ = "1.0.8"

View File

@@ -4,6 +4,7 @@ from .context_editing import (
ClearToolUsesEdit,
ContextEditingMiddleware,
)
from .file_search import FilesystemFileSearchMiddleware
from .human_in_the_loop import (
HumanInTheLoopMiddleware,
InterruptOnConfig,
@@ -46,6 +47,7 @@ __all__ = [
"CodexSandboxExecutionPolicy",
"ContextEditingMiddleware",
"DockerExecutionPolicy",
"FilesystemFileSearchMiddleware",
"HostExecutionPolicy",
"HumanInTheLoopMiddleware",
"InterruptOnConfig",

View File

@@ -353,3 +353,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
last_ai_msg.tool_calls = revised_tool_calls
return {"messages": [last_ai_msg, *artificial_tool_messages]}
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Async entry point for the post-model interrupt handling.

    Forwards to the synchronous `after_model` and returns whatever it returns.
    """
    # The synchronous hook is invoked inline; nothing is awaited here.
    outcome = self.after_model(state, runtime)
    return outcome

View File

@@ -198,6 +198,29 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
return None
@hook_config(can_jump_to=["end"])
async def abefore_model(
    self,
    state: ModelCallLimitState,
    runtime: Runtime,
) -> dict[str, Any] | None:
    """Async variant of the pre-model call-limit check.

    Forwards to the synchronous `before_model`, so its return value and
    raised exceptions are exactly those of that method.

    Args:
        state: The current agent state containing call counts.
        runtime: The langgraph runtime.

    Returns:
        Whatever `before_model` returns: a jump to the end carrying a
        limit-exceeded message when limits are exceeded and `exit_behavior`
        is `'end'`, otherwise `None`.

    Raises:
        ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
            is `'error'`.
    """
    # Invoked inline; nothing is awaited here.
    outcome = self.before_model(state, runtime)
    return outcome
def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Increment model call counts after a model call.
@@ -212,3 +235,19 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
}
async def aafter_model(
    self,
    state: ModelCallLimitState,
    runtime: Runtime,
) -> dict[str, Any] | None:
    """Async variant of the post-model call-count increment.

    Forwards to the synchronous `after_model`.

    Args:
        state: The current agent state.
        runtime: The langgraph runtime.

    Returns:
        State updates with incremented call counts, as produced by
        `after_model`.
    """
    # Invoked inline; nothing is awaited here.
    updated_counts = self.after_model(state, runtime)
    return updated_counts

View File

@@ -252,6 +252,27 @@ class PIIMiddleware(AgentMiddleware):
return None
@hook_config(can_jump_to=["end"])
async def abefore_model(
    self,
    state: AgentState,
    runtime: Runtime,
) -> dict[str, Any] | None:
    """Async variant of the pre-model PII check on user messages and tool results.

    Forwards to the synchronous `before_model`.

    Args:
        state: The current agent state.
        runtime: The langgraph runtime.

    Returns:
        Updated state with PII handled according to strategy, or `None` if no
        PII detected.

    Raises:
        PIIDetectionError: If PII is detected and strategy is `'block'`.
    """
    # Invoked inline; nothing is awaited here.
    outcome = self.before_model(state, runtime)
    return outcome
def after_model(
self,
state: AgentState,
@@ -311,6 +332,26 @@ class PIIMiddleware(AgentMiddleware):
return {"messages": new_messages}
async def aafter_model(
    self,
    state: AgentState,
    runtime: Runtime,
) -> dict[str, Any] | None:
    """Async variant of the post-model PII check on AI messages.

    Forwards to the synchronous `after_model`.

    Args:
        state: The current agent state.
        runtime: The langgraph runtime.

    Returns:
        Updated state with PII handled according to strategy, or `None` if no
        PII detected.

    Raises:
        PIIDetectionError: If PII is detected and strategy is `'block'`.
    """
    # Invoked inline; nothing is awaited here.
    outcome = self.after_model(state, runtime)
    return outcome
__all__ = [
"PIIDetectionError",

View File

@@ -11,17 +11,17 @@ import subprocess
import tempfile
import threading
import time
import typing
import uuid
import weakref
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import BaseTool, ToolException
from langchain_core.tools.base import ToolException
from langgraph.channels.untracked_value import UntrackedValue
from pydantic import BaseModel, model_validator
from pydantic.json_schema import SkipJsonSchema
from typing_extensions import NotRequired
from langchain.agents.middleware._execution import (
@@ -38,14 +38,13 @@ from langchain.agents.middleware._redaction import (
ResolvedRedactionRule,
)
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
from langchain.tools import ToolRuntime, tool
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from langgraph.runtime import Runtime
from langgraph.types import Command
from langchain.agents.middleware.types import ToolCallRequest
LOGGER = logging.getLogger(__name__)
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
@@ -59,6 +58,7 @@ DEFAULT_TOOL_DESCRIPTION = (
"session remains stable. Outputs may be truncated when they become very large, and long "
"running commands will be terminated once their configured timeout elapses."
)
SHELL_TOOL_NAME = "shell"
def _cleanup_resources(
@@ -334,7 +334,17 @@ class _ShellToolInput(BaseModel):
"""Input schema for the persistent shell tool."""
command: str | None = None
"""The shell command to execute."""
restart: bool | None = None
"""Whether to restart the shell session."""
runtime: Annotated[Any, SkipJsonSchema()] = None
"""The runtime for the shell tool.
Included as a workaround at the moment bc args_schema doesn't work with
injected ToolRuntime.
"""
@model_validator(mode="after")
def validate_payload(self) -> _ShellToolInput:
@@ -347,24 +357,6 @@ class _ShellToolInput(BaseModel):
return self
class _PersistentShellTool(BaseTool):
"""Tool wrapper that relies on middleware interception for execution."""
name: str = "shell"
description: str = DEFAULT_TOOL_DESCRIPTION
args_schema: type[BaseModel] = _ShellToolInput
def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
super().__init__()
self._middleware = middleware
if description is not None:
self.description = description
def _run(self, **_: Any) -> Any: # pragma: no cover - executed via middleware wrapper
msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
raise RuntimeError(msg)
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
"""Middleware that registers a persistent shell tool for agents.
@@ -393,6 +385,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
execution_policy: BaseExecutionPolicy | None = None,
redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
tool_description: str | None = None,
tool_name: str = SHELL_TOOL_NAME,
shell_command: Sequence[str] | str | None = None,
env: Mapping[str, Any] | None = None,
) -> None:
@@ -414,6 +407,9 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
returning it to the model.
tool_description: Optional override for the registered shell tool
description.
tool_name: Name for the registered shell tool.
Defaults to `"shell"`.
shell_command: Optional shell executable (string) or argument sequence used
to launch the persistent session.
@@ -425,6 +421,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
"""
super().__init__()
self._workspace_root = Path(workspace_root) if workspace_root else None
self._tool_name = tool_name
self._shell_command = self._normalize_shell_command(shell_command)
self._environment = self._normalize_env(env)
if execution_policy is not None:
@@ -438,9 +435,25 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
self._startup_commands = self._normalize_commands(startup_commands)
self._shutdown_commands = self._normalize_commands(shutdown_commands)
# Create a proper tool that executes directly (no interception needed)
description = tool_description or DEFAULT_TOOL_DESCRIPTION
self._tool = _PersistentShellTool(self, description=description)
self.tools = [self._tool]
@tool(self._tool_name, args_schema=_ShellToolInput, description=description)
def shell_tool(
*,
runtime: ToolRuntime[None, ShellToolState],
command: str | None = None,
restart: bool = False,
) -> ToolMessage | str:
resources = self._get_or_create_resources(runtime.state)
return self._run_shell_tool(
resources,
{"command": command, "restart": restart},
tool_call_id=runtime.tool_call_id,
)
self._shell_tool = shell_tool
self.tools = [self._shell_tool]
@staticmethod
def _normalize_commands(
@@ -478,36 +491,48 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Start the shell session and run startup commands."""
resources = self._create_resources()
resources = self._get_or_create_resources(state)
return {"shell_session_resources": resources}
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
"""Async counterpart to `before_agent`."""
"""Async start the shell session and run startup commands."""
return self.before_agent(state, runtime)
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
"""Run shutdown commands and release resources when an agent completes."""
resources = self._ensure_resources(state)
resources = state.get("shell_session_resources")
if not isinstance(resources, _SessionResources):
# Resources were never created, nothing to clean up
return
try:
self._run_shutdown_commands(resources.session)
finally:
resources._finalizer()
async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
"""Async counterpart to `after_agent`."""
"""Async run shutdown commands and release resources when an agent completes."""
return self.after_agent(state, runtime)
def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources:
"""Get existing resources from state or create new ones if they don't exist.
This method enables resumability by checking if resources already exist in the state
(e.g., after an interrupt), and only creating new resources if they're not present.
Args:
state: The agent state which may contain shell session resources.
Returns:
Session resources, either retrieved from state or newly created.
"""
resources = state.get("shell_session_resources")
if resources is not None and not isinstance(resources, _SessionResources):
resources = None
if resources is None:
msg = (
"Shell session resources are unavailable. Ensure `before_agent` ran successfully "
"before invoking the shell tool."
)
raise ToolException(msg)
return resources
if isinstance(resources, _SessionResources):
return resources
new_resources = self._create_resources()
# Cast needed to make state dict-like for mutation
cast("dict[str, Any]", state)["shell_session_resources"] = new_resources
return new_resources
def _create_resources(self) -> _SessionResources:
workspace = self._workspace_root
@@ -669,36 +694,6 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
artifact=artifact,
)
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept local shell tool calls and execute them via the managed session."""
if isinstance(request.tool, _PersistentShellTool):
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
request.tool_call["args"],
tool_call_id=request.tool_call.get("id"),
)
return handler(request)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Async interception mirroring the synchronous tool handler."""
if isinstance(request.tool, _PersistentShellTool):
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
request.tool_call["args"],
tool_call_id=request.tool_call.get("id"),
)
return await handler(request)
def _format_tool_message(
self,
content: str,
@@ -713,7 +708,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=self._tool.name,
name=self._tool_name,
status=status,
artifact=artifact,
)

View File

@@ -451,3 +451,28 @@ class ToolCallLimitMiddleware(
"run_tool_call_count": run_counts,
"messages": artificial_messages,
}
@hook_config(can_jump_to=["end"])
async def aafter_model(
    self,
    state: ToolCallLimitState[ResponseT],
    runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
    """Async variant of the post-model tool-call counting and limit check.

    Forwards to the synchronous `after_model`.

    Args:
        state: The current agent state.
        runtime: The langgraph runtime.

    Returns:
        State updates with incremented tool call counts. If limits are exceeded
        and exit_behavior is `'end'`, also includes a jump to end with a
        `ToolMessage` and AI message for the single exceeded tool call.

    Raises:
        ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
            is `'error'`.
        NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
            and there are multiple tool calls.
    """
    # Invoked inline; nothing is awaited here.
    outcome = self.after_model(state, runtime)
    return outcome

View File

@@ -9,10 +9,10 @@ license = { text = "MIT" }
readme = "README.md"
authors = []
version = "1.0.5"
version = "1.0.8"
requires-python = ">=3.10.0,<4.0.0"
dependencies = [
"langchain-core>=1.0.4,<2.0.0",
"langchain-core>=1.0.6,<2.0.0",
"langgraph>=1.0.2,<1.1.0",
"pydantic>=2.7.4,<3.0.0",
]

View File

@@ -0,0 +1,556 @@
from __future__ import annotations
import asyncio
import gc
import tempfile
import time
from pathlib import Path
import pytest
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.tools.base import ToolException
from langchain.agents.middleware.shell_tool import (
HostExecutionPolicy,
RedactionRule,
ShellToolMiddleware,
_SessionResources,
_ShellToolInput,
)
from langchain.agents.middleware.types import AgentState
def _empty_state() -> AgentState:
return {"messages": []} # type: ignore[return-value]
def test_executes_command_and_persists_state(tmp_path: Path) -> None:
    """A `cd` issued in one call is still in effect for the next call."""
    shell = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        session = shell._get_or_create_resources(agent_state)  # type: ignore[attr-defined]

        # Change directory first, then confirm the follow-up call sees it.
        shell._run_shell_tool(session, {"command": "cd /"}, tool_call_id=None)
        cwd = shell._run_shell_tool(session, {"command": "pwd"}, tool_call_id=None)
        assert isinstance(cwd, str)
        assert cwd.strip() == "/"

        greeting = shell._run_shell_tool(session, {"command": "echo ready"}, tool_call_id=None)
        assert "ready" in greeting
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
def test_restart_resets_session_environment(tmp_path: Path) -> None:
    """An exported variable is no longer set once the session is restarted."""
    shell = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        session = shell._get_or_create_resources(agent_state)  # type: ignore[attr-defined]

        shell._run_shell_tool(session, {"command": "export FOO=bar"}, tool_call_id=None)
        restart_reply = shell._run_shell_tool(session, {"restart": True}, tool_call_id=None)
        assert "restarted" in restart_reply.lower()

        # Look the resources up again now that the session has been restarted.
        session = shell._get_or_create_resources(agent_state)
        probe = shell._run_shell_tool(
            session,
            {"command": "echo ${FOO:-unset}"},
            tool_call_id=None,
        )
        assert "unset" in probe
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
def test_truncation_indicator_present(tmp_path: Path) -> None:
    """Printing 20 lines under a 5-line cap yields an "Output truncated" notice."""
    line_capped = HostExecutionPolicy(max_output_lines=5, command_timeout=5.0)
    shell = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        execution_policy=line_capped,
    )
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        session = shell._get_or_create_resources(agent_state)  # type: ignore[attr-defined]
        output = shell._run_shell_tool(session, {"command": "seq 1 20"}, tool_call_id=None)
        assert "Output truncated" in output
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
def test_timeout_returns_error(tmp_path: Path) -> None:
    """A command outliving `command_timeout` is reported as timed out."""
    policy = HostExecutionPolicy(command_timeout=0.5)
    shell = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        execution_policy=policy,
    )
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        session = shell._get_or_create_resources(agent_state)  # type: ignore[attr-defined]

        began_at = time.monotonic()
        outcome = shell._run_shell_tool(session, {"command": "sleep 2"}, tool_call_id=None)
        duration = time.monotonic() - began_at

        assert duration < policy.command_timeout + 2.0
        assert "timed out" in outcome.lower()
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
def test_redaction_policy_applies(tmp_path: Path) -> None:
    """An email address in command output is replaced by the redaction marker."""
    email_rule = RedactionRule(pii_type="email", strategy="redact")
    shell = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        redaction_rules=(email_rule,),
    )
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        session = shell._get_or_create_resources(agent_state)  # type: ignore[attr-defined]
        payload = {"command": "printf 'Contact: user@example.com\\n'"}
        redacted = shell._run_shell_tool(session, payload, tool_call_id=None)
        assert "[REDACTED_EMAIL]" in redacted
        assert "user@example.com" not in redacted
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
def test_startup_and_shutdown_commands(tmp_path: Path) -> None:
    """Startup commands run in `before_agent`; shutdown commands run in `after_agent`."""
    workspace_dir = tmp_path / "workspace"
    shell = ShellToolMiddleware(
        workspace_root=workspace_dir,
        startup_commands=("touch startup.txt",),
        shutdown_commands=("touch shutdown.txt",),
    )
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        startup_marker = workspace_dir / "startup.txt"
        assert startup_marker.exists()
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
    shutdown_marker = workspace_dir / "shutdown.txt"
    assert shutdown_marker.exists()
def test_session_resources_finalizer_cleans_up(tmp_path: Path) -> None:
    """Collecting `_SessionResources` stops the session and removes its temp dir."""

    class _RecordingSession:
        """Stand-in session that only records whether `stop` was called."""

        def __init__(self) -> None:
            self.stopped: bool = False

        def stop(self, timeout: float) -> None:  # noqa: ARG002
            self.stopped = True

    policy = HostExecutionPolicy(termination_timeout=0.1)
    fake_session = _RecordingSession()
    scratch = tempfile.TemporaryDirectory(dir=tmp_path)
    scratch_path = Path(scratch.name)

    resources = _SessionResources(session=fake_session, tempdir=scratch, policy=policy)  # type: ignore[arg-type]
    cleanup = resources._finalizer

    # Release the only strong reference, then force a collection pass.
    del resources
    gc.collect()

    assert not cleanup.alive
    assert fake_session.stopped
    assert not scratch_path.exists()
def test_shell_tool_input_validation() -> None:
    """Check which `command` / `restart` combinations `_ShellToolInput` accepts."""
    # Rejected: command and restart supplied together.
    with pytest.raises(ValueError, match="only one"):
        _ShellToolInput(command="ls", restart=True)

    # Rejected: neither field supplied.
    with pytest.raises(ValueError, match="requires either"):
        _ShellToolInput()

    # Accepted: a command on its own.
    command_only = _ShellToolInput(command="ls")
    assert command_only.command == "ls"
    assert not command_only.restart

    # Accepted: a restart on its own.
    restart_only = _ShellToolInput(restart=True)
    assert restart_only.restart is True
    assert restart_only.command is None
def test_normalize_shell_command_empty() -> None:
    """An empty `shell_command` sequence is rejected with a `ValueError`."""
    empty_command: list[str] = []
    with pytest.raises(ValueError, match="at least one argument"):
        ShellToolMiddleware(shell_command=empty_command)
def test_normalize_env_non_string_keys() -> None:
    """An `env` mapping with a non-string key is rejected with a `TypeError`."""
    bad_env = {123: "value"}
    with pytest.raises(TypeError, match="must be strings"):
        ShellToolMiddleware(env=bad_env)  # type: ignore[arg-type]
def test_normalize_env_coercion(tmp_path: Path) -> None:
    """Non-string `env` values reach the shell as their string form."""
    shell = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        env={"NUM": 42, "BOOL": True},
    )
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        session = shell._get_or_create_resources(agent_state)  # type: ignore[attr-defined]
        echoed = shell._run_shell_tool(
            session,
            {"command": "echo $NUM $BOOL"},
            tool_call_id=None,
        )
        for expected in ("42", "True"):
            assert expected in echoed
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
def test_shell_tool_missing_command_string(tmp_path: Path) -> None:
    """A `command` that is `None` or not a string raises `ToolException`."""
    shell = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        session = shell._get_or_create_resources(agent_state)  # type: ignore[attr-defined]

        # Missing command.
        with pytest.raises(ToolException, match="expects a 'command' string"):
            shell._run_shell_tool(session, {"command": None}, tool_call_id=None)

        # Command of the wrong type.
        with pytest.raises(ToolException, match="expects a 'command' string"):
            shell._run_shell_tool(
                session,
                {"command": 123},  # type: ignore[dict-item]
                tool_call_id=None,
            )
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
def test_tool_message_formatting_with_id(tmp_path: Path) -> None:
    """Passing a `tool_call_id` yields a populated `ToolMessage` instead of a string."""
    shell = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        session = shell._get_or_create_resources(agent_state)  # type: ignore[attr-defined]
        reply = shell._run_shell_tool(
            session,
            {"command": "echo test"},
            tool_call_id="test-id-123",
        )
        assert isinstance(reply, ToolMessage)
        assert (reply.tool_call_id, reply.name, reply.status) == (
            "test-id-123",
            "shell",
            "success",
        )
        assert "test" in reply.content
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
def test_nonzero_exit_code_returns_error(tmp_path: Path) -> None:
    """A command exiting with status 1 produces an error-status `ToolMessage`."""
    shell = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        session = shell._get_or_create_resources(agent_state)  # type: ignore[attr-defined]

        # `false` exits with 1 without terminating the shell itself.
        failing_call = {"command": "false"}
        reply = shell._run_shell_tool(session, failing_call, tool_call_id="test-id")

        assert isinstance(reply, ToolMessage)
        assert reply.status == "error"
        assert "Exit code: 1" in reply.content
        assert reply.artifact["exit_code"] == 1  # type: ignore[index]
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
def test_truncation_by_bytes(tmp_path: Path) -> None:
    """Output above `max_output_bytes` is cut and flagged with the byte limit."""
    byte_capped = HostExecutionPolicy(max_output_bytes=50, command_timeout=5.0)
    shell = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        execution_policy=byte_capped,
    )
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        session = shell._get_or_create_resources(agent_state)  # type: ignore[attr-defined]
        # Emits 100 characters, twice the configured byte cap.
        noisy_call = {"command": "python3 -c 'print(\"x\" * 100)'"}
        output = shell._run_shell_tool(session, noisy_call, tool_call_id=None)
        assert "truncated at 50 bytes" in output.lower()
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
def test_startup_command_failure(tmp_path: Path) -> None:
    """A failing startup command makes `before_agent` raise `RuntimeError`."""
    shell = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        startup_commands=("exit 1",),
        execution_policy=HostExecutionPolicy(startup_timeout=1.0),
    )
    agent_state: AgentState = _empty_state()
    with pytest.raises(RuntimeError, match="Startup command.*failed"):
        shell.before_agent(agent_state, None)
def test_shutdown_command_failure_logged(tmp_path: Path) -> None:
    """`after_agent` completes without raising when a shutdown command fails."""
    shell = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        shutdown_commands=("exit 1",),
        execution_policy=HostExecutionPolicy(command_timeout=1.0),
    )
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
    finally:
        # The failing shutdown command must not surface as an exception.
        shell.after_agent(agent_state, None)
def test_shutdown_command_timeout_logged(tmp_path: Path) -> None:
    """`after_agent` completes without raising when a shutdown command times out."""
    shell = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        execution_policy=HostExecutionPolicy(command_timeout=0.1),
        shutdown_commands=("sleep 2",),
    )
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
    finally:
        # The shutdown command outlives the timeout; that must not raise.
        shell.after_agent(agent_state, None)
def test_empty_output_replaced_with_no_output(tmp_path: Path) -> None:
    """A command that prints nothing is reported as `<no output>`."""
    shell = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        session = shell._get_or_create_resources(agent_state)  # type: ignore[attr-defined]
        # `true` succeeds and writes nothing.
        silent_call = {"command": "true"}
        output = shell._run_shell_tool(session, silent_call, tool_call_id=None)
        assert "<no output>" in output
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
def test_stderr_output_labeling(tmp_path: Path) -> None:
    """Text written to stderr is prefixed with `[stderr]` in the result."""
    shell = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    agent_state: AgentState = _empty_state()
    try:
        started = shell.before_agent(agent_state, None)
        if started:
            agent_state.update(started)
        session = shell._get_or_create_resources(agent_state)  # type: ignore[attr-defined]
        stderr_call = {"command": "echo error >&2"}
        output = shell._run_shell_tool(session, stderr_call, tool_call_id=None)
        assert "[stderr] error" in output
    finally:
        stopped = shell.after_agent(agent_state, None)
        if stopped:
            agent_state.update(stopped)
@pytest.mark.parametrize(
    ("startup_commands", "expected"),
    [
        ("echo test", ("echo test",)),  # a bare string
        (["echo test", "pwd"], ("echo test", "pwd")),  # a list
        (("echo test",), ("echo test",)),  # a tuple
        (None, ()),  # nothing supplied
    ],
)
def test_normalize_commands_string_tuple_list(
    tmp_path: Path,
    startup_commands: str | list[str] | tuple[str, ...] | None,
    expected: tuple[str, ...],
) -> None:
    """Each accepted `startup_commands` form is stored as a tuple of strings."""
    shell = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        startup_commands=startup_commands,
    )
    normalized = shell._startup_commands  # type: ignore[attr-defined]
    assert normalized == expected
def test_async_methods_delegate_to_sync(tmp_path: Path) -> None:
    """Test that async methods properly delegate to sync methods.

    `abefore_agent` must hand back the session resources that `before_agent`
    produces, and `aafter_agent` must release them the way `after_agent` does.
    """
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    state: AgentState = _empty_state()

    # abefore_agent: the update must carry real session resources.
    updates = asyncio.run(middleware.abefore_agent(state, None))
    assert updates is not None
    state.update(updates)
    resources = state.get("shell_session_resources")
    assert isinstance(resources, _SessionResources)

    # aafter_agent: the sync path calls the finalizer, so it must be spent.
    asyncio.run(middleware.aafter_agent(state, None))
    assert not resources._finalizer.alive
def test_shell_middleware_resumable_after_interrupt(tmp_path: Path) -> None:
    """Test that shell middleware is resumable after an interrupt.

    This test simulates a scenario where:

    1. The middleware creates a shell session
    2. A command is executed
    3. The agent is interrupted (state is preserved)
    4. The agent resumes with the same state
    5. The shell session is reused (not recreated)
    """
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)
    state: AgentState = _empty_state()

    try:
        # Simulate first execution (before interrupt).
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)

        # Get the resources and remember which session/tempdir they hold.
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        initial_session = resources.session
        initial_tempdir = resources.tempdir

        # Execute a command that leaves state behind in the session.
        middleware._run_shell_tool(
            resources, {"command": "export TEST_VAR=hello"}, tool_call_id=None
        )

        # Simulate interrupt: state is preserved and after_agent is NOT called.
        # Simulate resumption: before_agent runs again with the same state and
        # should reuse the existing resources instead of creating new ones.
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)

        resumed_resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

        # The session was reused (same object references).
        assert resumed_resources.session is initial_session
        assert resumed_resources.tempdir is initial_tempdir

        # The session state persisted (environment variable still set).
        result = middleware._run_shell_tool(
            resumed_resources, {"command": "echo ${TEST_VAR:-unset}"}, tool_call_id=None
        )
        assert "hello" in result
        assert "unset" not in result
    finally:
        # Always release the session, even when an assertion above fails.
        middleware.after_agent(state, None)
def test_get_or_create_resources_creates_when_missing(tmp_path: Path) -> None:
    """Test that _get_or_create_resources creates resources when they don't exist."""
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)
    state: AgentState = _empty_state()

    # State has no resources initially.
    assert "shell_session_resources" not in state

    # With nothing in state, the call should create new resources.
    resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
    try:
        assert isinstance(resources, _SessionResources)
        assert resources.session is not None
        # The new resources are also written back into the state.
        assert state.get("shell_session_resources") is resources
    finally:
        # Release the session even when an assertion above fails.
        resources._finalizer()
def test_get_or_create_resources_reuses_existing(tmp_path: Path) -> None:
    """Test that _get_or_create_resources reuses existing resources."""
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)
    state: AgentState = _empty_state()

    # The first call creates the resources.
    resources1 = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
    try:
        # The second call must return the very same resources.
        resources2 = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        assert resources1 is resources2
        assert resources1.session is resources2.session
    finally:
        # Release the session even when an assertion above fails.
        resources1._finalizer()

View File

@@ -28,7 +28,7 @@ def test_executes_command_and_persists_state(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
middleware._run_shell_tool(resources, {"command": "cd /"}, tool_call_id=None)
result = middleware._run_shell_tool(resources, {"command": "pwd"}, tool_call_id=None)
@@ -51,14 +51,14 @@ def test_restart_resets_session_environment(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
middleware._run_shell_tool(resources, {"command": "export FOO=bar"}, tool_call_id=None)
restart_message = middleware._run_shell_tool(
resources, {"restart": True}, tool_call_id=None
)
assert "restarted" in restart_message.lower()
resources = middleware._ensure_resources(state) # reacquire after restart
resources = middleware._get_or_create_resources(state) # reacquire after restart
result = middleware._run_shell_tool(
resources, {"command": "echo ${FOO:-unset}"}, tool_call_id=None
)
@@ -77,7 +77,7 @@ def test_truncation_indicator_present(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(resources, {"command": "seq 1 20"}, tool_call_id=None)
assert "Output truncated" in result
finally:
@@ -94,7 +94,7 @@ def test_timeout_returns_error(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
start = time.monotonic()
result = middleware._run_shell_tool(resources, {"command": "sleep 2"}, tool_call_id=None)
elapsed = time.monotonic() - start
@@ -116,7 +116,7 @@ def test_redaction_policy_applies(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
message = middleware._run_shell_tool(
resources,
{"command": "printf 'Contact: user@example.com\\n'"},

View File

@@ -1788,7 +1788,7 @@ wheels = [
[[package]]
name = "langchain"
version = "1.0.5"
version = "1.0.8"
source = { editable = "." }
dependencies = [
{ name = "langchain-core" },

View File

@@ -17,7 +17,9 @@ from langchain.agents.middleware.types import (
AgentState,
ModelRequest,
ModelResponse,
_ModelRequestOverrides,
)
from langchain.tools import ToolRuntime, tool
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from typing_extensions import NotRequired, TypedDict
@@ -25,7 +27,6 @@ from typing_extensions import NotRequired, TypedDict
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Sequence
from langchain.agents.middleware.types import ToolCallRequest
# Tool type constants
TEXT_EDITOR_TOOL_TYPE = "text_editor_20250728"
@@ -184,149 +185,127 @@ class _StateClaudeFileToolMiddleware(AgentMiddleware):
self.allowed_prefixes = allowed_path_prefixes
self.system_prompt = system_prompt
# Create tool that will be executed by the tool node
@tool(tool_name)
def file_tool(
    runtime: ToolRuntime[None, AnthropicToolsState],
    command: str,
    path: str,
    file_text: str | None = None,
    old_str: str | None = None,
    new_str: str | None = None,
    insert_line: int | None = None,
    new_path: str | None = None,
    view_range: list[int] | None = None,
) -> Command | str:
    """Execute file operations on virtual file system.

    Args:
        runtime: Tool runtime providing access to state.
        command: Operation to perform.
        path: File path to operate on.
        file_text: Full file content for create command.
        old_str: String to replace for str_replace command.
        new_str: Replacement string for str_replace command.
        insert_line: Line number for insert command.
        new_path: New path for rename command.
        view_range: Line range [start, end] for view command.

    Returns:
        Command for state update or string result.
    """
    # Forward only the optional arguments the caller actually supplied;
    # `path` is always present.
    optional_args: dict[str, Any] = {
        "file_text": file_text,
        "old_str": old_str,
        "new_str": new_str,
        "insert_line": insert_line,
        "new_path": new_path,
        "view_range": view_range,
    }
    args: dict[str, Any] = {"path": path}
    args.update(
        (key, value) for key, value in optional_args.items() if value is not None
    )

    try:
        # Command name -> handler method on the enclosing middleware.
        handlers = {
            "view": self._handle_view,
            "create": self._handle_create,
            "str_replace": self._handle_str_replace,
            "insert": self._handle_insert,
            "delete": self._handle_delete,
            "rename": self._handle_rename,
        }
        handle = handlers.get(command)
        if handle is None:
            return f"Unknown command: {command}"
        return handle(args, runtime.state, runtime.tool_call_id)
    except (ValueError, FileNotFoundError) as e:
        # Handler failures are reported back as the tool's string result.
        return str(e)
self.tools = [file_tool]
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
"""Inject tool and optional system prompt."""
# Add tool
tools = list(request.tools or [])
tools.append(
{
"type": self.tool_type,
"name": self.tool_name,
}
)
request.tools = tools
"""Inject Anthropic tool descriptor and optional system prompt."""
# Replace our BaseTool with Anthropic's native tool descriptor
tools = [
t
for t in (request.tools or [])
if getattr(t, "name", None) != self.tool_name
] + [{"type": self.tool_type, "name": self.tool_name}]
# Inject system prompt if provided
overrides: _ModelRequestOverrides = {"tools": tools}
if self.system_prompt:
request.system_prompt = (
overrides["system_prompt"] = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return handler(request)
return handler(request.override(**overrides))
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse:
"""Inject tool and optional system prompt (async version)."""
# Add tool
tools = list(request.tools or [])
tools.append(
{
"type": self.tool_type,
"name": self.tool_name,
}
)
request.tools = tools
"""Inject Anthropic tool descriptor and optional system prompt."""
# Replace our BaseTool with Anthropic's native tool descriptor
tools = [
t
for t in (request.tools or [])
if getattr(t, "name", None) != self.tool_name
] + [{"type": self.tool_type, "name": self.tool_name}]
# Inject system prompt if provided
overrides: _ModelRequestOverrides = {"tools": tools}
if self.system_prompt:
request.system_prompt = (
overrides["system_prompt"] = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return await handler(request)
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept tool calls."""
tool_call = request.tool_call
tool_name = tool_call.get("name")
if tool_name != self.tool_name:
return handler(request)
# Handle tool call
try:
args = tool_call.get("args", {})
command = args.get("command")
state = request.state
if command == "view":
return self._handle_view(args, state, tool_call["id"])
if command == "create":
return self._handle_create(args, state, tool_call["id"])
if command == "str_replace":
return self._handle_str_replace(args, state, tool_call["id"])
if command == "insert":
return self._handle_insert(args, state, tool_call["id"])
if command == "delete":
return self._handle_delete(args, state, tool_call["id"])
if command == "rename":
return self._handle_rename(args, state, tool_call["id"])
msg = f"Unknown command: {command}"
return ToolMessage(
content=msg,
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
except (ValueError, FileNotFoundError) as e:
return ToolMessage(
content=str(e),
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Intercept tool calls (async version)."""
tool_call = request.tool_call
tool_name = tool_call.get("name")
if tool_name != self.tool_name:
return await handler(request)
# Handle tool call
try:
args = tool_call.get("args", {})
command = args.get("command")
state = request.state
if command == "view":
return self._handle_view(args, state, tool_call["id"])
if command == "create":
return self._handle_create(args, state, tool_call["id"])
if command == "str_replace":
return self._handle_str_replace(args, state, tool_call["id"])
if command == "insert":
return self._handle_insert(args, state, tool_call["id"])
if command == "delete":
return self._handle_delete(args, state, tool_call["id"])
if command == "rename":
return self._handle_rename(args, state, tool_call["id"])
msg = f"Unknown command: {command}"
return ToolMessage(
content=msg,
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
except (ValueError, FileNotFoundError) as e:
return ToolMessage(
content=str(e),
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
return await handler(request.override(**overrides))
def _handle_view(
self, args: dict, state: AnthropicToolsState, tool_call_id: str | None
@@ -692,146 +671,117 @@ class _FilesystemClaudeFileToolMiddleware(AgentMiddleware):
# Create root directory if it doesn't exist
self.root_path.mkdir(parents=True, exist_ok=True)
# Create tool that will be executed by the tool node
@tool(tool_name)
def file_tool(
    runtime: ToolRuntime,
    command: str,
    path: str,
    file_text: str | None = None,
    old_str: str | None = None,
    new_str: str | None = None,
    insert_line: int | None = None,
    new_path: str | None = None,
    view_range: list[int] | None = None,
) -> Command | str:
    """Execute file operations on filesystem.

    Args:
        runtime: Tool runtime providing tool_call_id.
        command: Operation to perform.
        path: File path to operate on.
        file_text: Full file content for create command.
        old_str: String to replace for str_replace command.
        new_str: Replacement string for str_replace command.
        insert_line: Line number for insert command.
        new_path: New path for rename command.
        view_range: Line range [start, end] for view command.

    Returns:
        Command for message update or string result.
    """
    # Forward only the optional arguments the caller actually supplied;
    # `path` is always present.
    optional_args: dict[str, Any] = {
        "file_text": file_text,
        "old_str": old_str,
        "new_str": new_str,
        "insert_line": insert_line,
        "new_path": new_path,
        "view_range": view_range,
    }
    args: dict[str, Any] = {"path": path}
    args.update(
        (key, value) for key, value in optional_args.items() if value is not None
    )

    try:
        # Command name -> handler method on the enclosing middleware.
        handlers = {
            "view": self._handle_view,
            "create": self._handle_create,
            "str_replace": self._handle_str_replace,
            "insert": self._handle_insert,
            "delete": self._handle_delete,
            "rename": self._handle_rename,
        }
        handle = handlers.get(command)
        if handle is None:
            return f"Unknown command: {command}"
        return handle(args, runtime.tool_call_id)
    except (ValueError, FileNotFoundError, PermissionError) as e:
        # Handler failures are reported back as the tool's string result.
        return str(e)
self.tools = [file_tool]
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
"""Inject tool and optional system prompt."""
# Add tool
tools = list(request.tools or [])
tools.append(
{
"type": self.tool_type,
"name": self.tool_name,
}
)
request.tools = tools
"""Inject Anthropic tool descriptor and optional system prompt."""
# Replace our BaseTool with Anthropic's native tool descriptor
tools = [
t
for t in (request.tools or [])
if getattr(t, "name", None) != self.tool_name
] + [{"type": self.tool_type, "name": self.tool_name}]
# Inject system prompt if provided
overrides: _ModelRequestOverrides = {"tools": tools}
if self.system_prompt:
request.system_prompt = (
overrides["system_prompt"] = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return handler(request)
return handler(request.override(**overrides))
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse:
"""Inject tool and optional system prompt (async version)."""
# Add tool
tools = list(request.tools or [])
tools.append(
{
"type": self.tool_type,
"name": self.tool_name,
}
)
request.tools = tools
"""Inject Anthropic tool descriptor and optional system prompt."""
# Replace our BaseTool with Anthropic's native tool descriptor
tools = [
t
for t in (request.tools or [])
if getattr(t, "name", None) != self.tool_name
] + [{"type": self.tool_type, "name": self.tool_name}]
# Inject system prompt if provided
overrides: _ModelRequestOverrides = {"tools": tools}
if self.system_prompt:
request.system_prompt = (
overrides["system_prompt"] = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return await handler(request)
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept tool calls."""
tool_call = request.tool_call
tool_name = tool_call.get("name")
if tool_name != self.tool_name:
return handler(request)
# Handle tool call
try:
args = tool_call.get("args", {})
command = args.get("command")
if command == "view":
return self._handle_view(args, tool_call["id"])
if command == "create":
return self._handle_create(args, tool_call["id"])
if command == "str_replace":
return self._handle_str_replace(args, tool_call["id"])
if command == "insert":
return self._handle_insert(args, tool_call["id"])
if command == "delete":
return self._handle_delete(args, tool_call["id"])
if command == "rename":
return self._handle_rename(args, tool_call["id"])
msg = f"Unknown command: {command}"
return ToolMessage(
content=msg,
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
except (ValueError, FileNotFoundError) as e:
return ToolMessage(
content=str(e),
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
"""Intercept tool calls (async version)."""
tool_call = request.tool_call
tool_name = tool_call.get("name")
if tool_name != self.tool_name:
return await handler(request)
# Handle tool call
try:
args = tool_call.get("args", {})
command = args.get("command")
if command == "view":
return self._handle_view(args, tool_call["id"])
if command == "create":
return self._handle_create(args, tool_call["id"])
if command == "str_replace":
return self._handle_str_replace(args, tool_call["id"])
if command == "insert":
return self._handle_insert(args, tool_call["id"])
if command == "delete":
return self._handle_delete(args, tool_call["id"])
if command == "rename":
return self._handle_rename(args, tool_call["id"])
msg = f"Unknown command: {command}"
return ToolMessage(
content=msg,
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
except (ValueError, FileNotFoundError) as e:
return ToolMessage(
content=str(e),
tool_call_id=tool_call["id"],
name=tool_name,
status="error",
)
return await handler(request.override(**overrides))
def _validate_and_resolve_path(self, path: str) -> Path:
"""Validate and resolve a virtual path to filesystem path.

View File

@@ -3,93 +3,81 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Any, Literal
from typing import Any
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
from langchain.agents.middleware.types import (
ModelRequest,
ModelResponse,
ToolCallRequest,
)
from langchain_core.messages import ToolMessage
from langgraph.types import Command
_CLAUDE_BASH_DESCRIPTOR = {"type": "bash_20250124", "name": "bash"}
# Tool type constants for Anthropic
BASH_TOOL_TYPE = "bash_20250124"
BASH_TOOL_NAME = "bash"
class ClaudeBashToolMiddleware(ShellToolMiddleware):
"""Middleware that exposes Anthropic's native bash tool to models."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize middleware without registering a client-side tool."""
kwargs["shell_command"] = ("/bin/bash",)
super().__init__(*args, **kwargs)
# Remove the base tool so Claude's native descriptor is the sole entry.
self._tool = None # type: ignore[assignment]
self.tools = []
def __init__(
    self,
    workspace_root: str | None = None,
    *,
    startup_commands: tuple[str, ...] | list[str] | str | None = None,
    shutdown_commands: tuple[str, ...] | list[str] | str | None = None,
    execution_policy: Any | None = None,
    redaction_rules: tuple[Any, ...] | list[Any] | None = None,
    tool_description: str | None = None,
    env: dict[str, Any] | None = None,
) -> None:
    """Initialize middleware for Claude's native bash tool.

    All arguments are forwarded unchanged to the parent initializer; the
    tool name and shell command are fixed by this subclass and are not
    configurable by callers.

    Args:
        workspace_root: Base directory for the shell session.
            If omitted, a temporary directory is created.
        startup_commands: Optional commands executed after the session starts.
        shutdown_commands: Optional commands executed before session shutdown.
        execution_policy: Execution policy controlling timeouts and limits.
        redaction_rules: Optional redaction rules to sanitize output.
        tool_description: Optional override for tool description.
        env: Optional environment variables for the shell session.
    """
    super().__init__(
        workspace_root=workspace_root,
        startup_commands=startup_commands,
        shutdown_commands=shutdown_commands,
        execution_policy=execution_policy,
        redaction_rules=redaction_rules,
        tool_description=tool_description,
        # Fixed values: the tool is always registered under BASH_TOOL_NAME
        # and the session always runs /bin/bash.
        tool_name=BASH_TOOL_NAME,
        shell_command=("/bin/bash",),
        env=env,
    )
    # Parent class now creates the tool with name "bash" via tool_name parameter
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
"""Ensure the Claude bash descriptor is available to the model."""
tools = request.tools
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools):
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR]
request = request.override(tools=tools)
return handler(request)
"""Replace parent's shell tool with Claude's bash descriptor."""
filtered = [
t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME
]
tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}]
return handler(request.override(tools=tools))
def wrap_tool_call(
async def awrap_model_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Command | ToolMessage],
) -> Command | ToolMessage:
"""Intercept Claude bash tool calls and execute them locally."""
tool_call = request.tool_call
if tool_call.get("name") != "bash":
return handler(request)
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
tool_call["args"],
tool_call_id=tool_call.get("id"),
)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[Command | ToolMessage]],
) -> Command | ToolMessage:
"""Async interception mirroring the synchronous implementation."""
tool_call = request.tool_call
if tool_call.get("name") != "bash":
return await handler(request)
resources = self._ensure_resources(request.state)
return self._run_shell_tool(
resources,
tool_call["args"],
tool_call_id=tool_call.get("id"),
)
def _format_tool_message(
self,
content: str,
tool_call_id: str | None,
*,
status: Literal["success", "error"],
artifact: dict[str, Any] | None = None,
) -> ToolMessage | str:
"""Format tool responses using Claude's bash descriptor."""
if tool_call_id is None:
return content
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=_CLAUDE_BASH_DESCRIPTOR["name"],
status=status,
artifact=artifact or {},
)
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse:
"""Async: replace parent's shell tool with Claude's bash descriptor."""
filtered = [
t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME
]
tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}]
return await handler(request.override(tools=tools))
__all__ = ["ClaudeBashToolMiddleware"]

View File

@@ -9,13 +9,13 @@ from __future__ import annotations
import fnmatch
import re
from pathlib import Path, PurePosixPath
from typing import TYPE_CHECKING, Annotated, Literal, cast
from typing import TYPE_CHECKING, Literal, cast
if TYPE_CHECKING:
from typing import Any
from langchain.agents.middleware.types import AgentMiddleware
from langchain_core.tools import InjectedToolArg, tool
from langchain.tools import ToolRuntime, tool
from langchain_anthropic.middleware.anthropic_tools import AnthropicToolsState
@@ -128,9 +128,9 @@ class StateFileSearchMiddleware(AgentMiddleware):
# Create tool instances
@tool
def glob_search( # noqa: D417
runtime: ToolRuntime[None, AnthropicToolsState],
pattern: str,
path: str = "/",
state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment]
) -> str:
"""Fast file pattern matching tool that works with any codebase size.
@@ -147,56 +147,17 @@ class StateFileSearchMiddleware(AgentMiddleware):
time (most recently modified first). Returns "No files found" if no
matches.
"""
# Normalize base path
base_path = path if path.startswith("/") else "/" + path
# Get files from state
files = cast("dict[str, Any]", state.get(self.state_key, {}))
# Match files
matches = []
for file_path, file_data in files.items():
if file_path.startswith(base_path):
# Get relative path from base
if base_path == "/":
relative = file_path[1:] # Remove leading /
elif file_path == base_path:
relative = Path(file_path).name
elif file_path.startswith(base_path + "/"):
relative = file_path[len(base_path) + 1 :]
else:
continue
# Match against pattern
# Handle ** pattern which requires special care
# PurePosixPath.match doesn't match single-level paths
# against **/pattern
is_match = PurePosixPath(relative).match(pattern)
if not is_match and pattern.startswith("**/"):
# Also try matching without the **/ prefix for files in base dir
is_match = PurePosixPath(relative).match(pattern[3:])
if is_match:
matches.append((file_path, file_data["modified_at"]))
if not matches:
return "No files found"
# Sort by modification time
matches.sort(key=lambda x: x[1], reverse=True)
file_paths = [path for path, _ in matches]
return "\n".join(file_paths)
return self._handle_glob_search(pattern, path, runtime.state)
@tool
def grep_search( # noqa: D417
runtime: ToolRuntime[None, AnthropicToolsState],
pattern: str,
path: str = "/",
include: str | None = None,
output_mode: Literal[
"files_with_matches", "content", "count"
] = "files_with_matches",
state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment]
) -> str:
"""Fast content search tool that works with any codebase size.
@@ -216,49 +177,133 @@ class StateFileSearchMiddleware(AgentMiddleware):
Search results formatted according to output_mode. Returns "No matches
found" if no results.
"""
# Normalize base path
base_path = path if path.startswith("/") else "/" + path
# Compile regex pattern (for validation)
try:
regex = re.compile(pattern)
except re.error as e:
return f"Invalid regex pattern: {e}"
if include and not _is_valid_include_pattern(include):
return "Invalid include pattern"
# Search files
files = cast("dict[str, Any]", state.get(self.state_key, {}))
results: dict[str, list[tuple[int, str]]] = {}
for file_path, file_data in files.items():
if not file_path.startswith(base_path):
continue
# Check include filter
if include:
basename = Path(file_path).name
if not _match_include_pattern(basename, include):
continue
# Search file content
for line_num, line in enumerate(file_data["content"], 1):
if regex.search(line):
if file_path not in results:
results[file_path] = []
results[file_path].append((line_num, line))
if not results:
return "No matches found"
# Format output based on mode
return self._format_grep_results(results, output_mode)
return self._handle_grep_search(
pattern, path, include, output_mode, runtime.state
)
self.glob_search = glob_search
self.grep_search = grep_search
self.tools = [glob_search, grep_search]
def _handle_glob_search(
    self,
    pattern: str,
    path: str,
    state: AnthropicToolsState,
) -> str:
    """Handle glob search operation.

    Args:
        pattern: The glob pattern to match files against.
        path: The directory to search in. A missing leading slash is added
            and trailing slashes are ignored, so "src", "/src" and "/src/"
            all address the same directory.
        state: The current agent state.

    Returns:
        Newline-separated list of matching file paths, sorted by modification
        time (most recently modified first). Returns "No files found" if no
        matches.
    """
    # Normalize base path: absolute, and without a trailing slash (except for
    # the root itself). Without this, "/src/" would be compared against the
    # prefix "/src//" below and never match anything.
    base_path = path if path.startswith("/") else "/" + path
    base_path = base_path.rstrip("/") or "/"
    dir_prefix = "/" if base_path == "/" else base_path + "/"

    # Get files from state
    files = cast("dict[str, Any]", state.get(self.state_key, {}))

    matches = []
    for file_path, file_data in files.items():
        # Compute the path relative to the base, respecting directory
        # boundaries ("/src" must not capture "/src_other/...").
        if file_path == base_path:
            relative = Path(file_path).name
        elif file_path.startswith(dir_prefix):
            relative = file_path[len(dir_prefix) :]
        else:
            continue

        # PurePosixPath.match doesn't match single-level paths against
        # **/pattern, so also try the pattern without the **/ prefix to
        # cover files that sit directly in the base directory.
        is_match = PurePosixPath(relative).match(pattern)
        if not is_match and pattern.startswith("**/"):
            is_match = PurePosixPath(relative).match(pattern[3:])

        if is_match:
            matches.append((file_path, file_data["modified_at"]))

    if not matches:
        return "No files found"

    # Sort by modification time, most recent first
    matches.sort(key=lambda item: item[1], reverse=True)
    return "\n".join(matched_path for matched_path, _ in matches)
def _handle_grep_search(
    self,
    pattern: str,
    path: str,
    include: str | None,
    output_mode: str,
    state: AnthropicToolsState,
) -> str:
    """Handle grep search operation.

    Args:
        pattern: The regular expression pattern to search for in file contents.
        path: The directory to search in. A missing leading slash is added
            and trailing slashes are ignored. Only files at or below this
            path are searched; sibling directories that merely share the
            prefix (e.g. "/src_other" when searching "/src") are excluded.
        include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
        output_mode: Output format.
        state: The current agent state.

    Returns:
        Search results formatted according to output_mode. Returns "No matches
        found" if no results, or an error string for an invalid regex or
        include pattern.
    """
    # Normalize base path: absolute, and without a trailing slash (except for
    # the root itself).
    base_path = path if path.startswith("/") else "/" + path
    base_path = base_path.rstrip("/") or "/"
    # Prefix that a file must start with to be *inside* the base directory.
    # A bare startswith(base_path) would also accept "/src_other/x.py" for
    # base "/src".
    dir_prefix = "/" if base_path == "/" else base_path + "/"

    # Compile regex pattern (for validation)
    try:
        regex = re.compile(pattern)
    except re.error as e:
        return f"Invalid regex pattern: {e}"

    if include and not _is_valid_include_pattern(include):
        return "Invalid include pattern"

    # Search files
    files = cast("dict[str, Any]", state.get(self.state_key, {}))
    results: dict[str, list[tuple[int, str]]] = {}

    for file_path, file_data in files.items():
        if file_path != base_path and not file_path.startswith(dir_prefix):
            continue

        # The include filter is applied to the file's basename only.
        if include and not _match_include_pattern(Path(file_path).name, include):
            continue

        # Collect (1-based line number, line) for every matching line.
        hits = [
            (line_num, line)
            for line_num, line in enumerate(file_data["content"], 1)
            if regex.search(line)
        ]
        if hits:
            results[file_path] = hits

    if not results:
        return "No matches found"

    # Format output based on mode
    return self._format_grep_results(results, output_mode)
def _format_grep_results(
self,
results: dict[str, list[tuple[int, str]]],

View File

@@ -9,11 +9,11 @@ license = { text = "MIT" }
readme = "README.md"
authors = []
version = "1.0.2"
version = "1.0.4"
requires-python = ">=3.10.0,<4.0.0"
dependencies = [
"anthropic>=0.69.0,<1.0.0",
"langchain-core>=1.0.3,<2.0.0",
"langchain-core>=1.0.4,<2.0.0",
"pydantic>=2.7.4,<3.0.0",
]

View File

@@ -3,67 +3,59 @@ from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from langchain_core.messages.tool import ToolCall
pytest.importorskip(
"anthropic", reason="Anthropic SDK is required for Claude middleware tests"
)
from langchain.agents.middleware.types import ToolCallRequest
from langchain_core.messages import ToolMessage
from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware
def test_wrap_tool_call_handles_claude_bash(monkeypatch: pytest.MonkeyPatch) -> None:
def test_creates_bash_tool(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that ClaudeBashToolMiddleware creates a tool named 'bash'."""
middleware = ClaudeBashToolMiddleware()
sentinel = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
monkeypatch.setattr(middleware, "_run_shell_tool", MagicMock(return_value=sentinel))
monkeypatch.setattr(
middleware, "_ensure_resources", MagicMock(return_value=MagicMock())
# Should have exactly one tool registered (from parent)
assert len(middleware.tools) == 1
# Tool is named "bash" (via tool_name parameter)
bash_tool = middleware.tools[0]
assert bash_tool.name == "bash"
def test_replaces_tool_with_claude_descriptor() -> None:
"""Test wrap_model_call replaces bash tool with Claude's bash descriptor."""
from langchain.agents.middleware.types import ModelRequest
middleware = ClaudeBashToolMiddleware()
# Create a mock request with the bash tool (inherited from parent)
bash_tool = middleware.tools[0]
request = ModelRequest(
model=MagicMock(),
system_prompt=None,
messages=[],
tool_choice=None,
tools=[bash_tool],
response_format=None,
state={"messages": []},
runtime=MagicMock(),
)
tool_call: ToolCall = {
# Mock handler that captures the modified request
captured_request = None
def handler(req: ModelRequest) -> MagicMock:
nonlocal captured_request
captured_request = req
return MagicMock()
middleware.wrap_model_call(request, handler)
# The bash tool should be replaced with Claude's native bash descriptor
assert captured_request is not None
assert len(captured_request.tools) == 1
assert captured_request.tools[0] == {
"type": "bash_20250124",
"name": "bash",
"args": {"command": "echo hi"},
"id": "call-1",
}
request = ToolCallRequest(
tool_call=tool_call,
tool=MagicMock(),
state={},
runtime=None, # type: ignore[arg-type]
)
handler_called = False
def handler(_: ToolCallRequest) -> ToolMessage:
nonlocal handler_called
handler_called = True
return ToolMessage(content="should not be used", tool_call_id="call-1")
result = middleware.wrap_tool_call(request, handler)
assert result is sentinel
assert handler_called is False
def test_wrap_tool_call_passes_through_other_tools(
monkeypatch: pytest.MonkeyPatch,
) -> None:
middleware = ClaudeBashToolMiddleware()
tool_call: ToolCall = {"name": "other", "args": {}, "id": "call-2"}
request = ToolCallRequest(
tool_call=tool_call,
tool=MagicMock(),
state={},
runtime=None, # type: ignore[arg-type]
)
sentinel = ToolMessage(content="handled", tool_call_id="call-2", name="other")
def handler(_: ToolCallRequest) -> ToolMessage:
return sentinel
result = middleware.wrap_tool_call(request, handler)
assert result is sentinel

View File

@@ -49,8 +49,10 @@ class TestGlobSearch:
},
}
# Call tool function directly (state is injected in real usage)
result = middleware.glob_search.func(pattern="*.py", state=test_state) # type: ignore[attr-defined]
# Call internal handler method directly
result = middleware._handle_glob_search(
pattern="*.py", path="/", state=test_state
)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -82,7 +84,9 @@ class TestGlobSearch:
},
}
result = middleware.glob_search.func(pattern="**/*.py", state=state) # type: ignore[attr-defined]
result = middleware._handle_glob_search(
pattern="**/*.py", path="/", state=state
)
assert isinstance(result, str)
lines = result.split("\n")
@@ -109,7 +113,7 @@ class TestGlobSearch:
},
}
result = middleware.glob_search.func( # type: ignore[attr-defined]
result = middleware._handle_glob_search(
pattern="**/*.py", path="/src", state=state
)
@@ -132,7 +136,7 @@ class TestGlobSearch:
},
}
result = middleware.glob_search.func(pattern="*.ts", state=state) # type: ignore[attr-defined]
result = middleware._handle_glob_search(pattern="*.ts", path="/", state=state)
assert isinstance(result, str)
assert result == "No files found"
@@ -157,7 +161,7 @@ class TestGlobSearch:
},
}
result = middleware.glob_search.func(pattern="*.py", state=state) # type: ignore[attr-defined]
result = middleware._handle_glob_search(pattern="*.py", path="/", state=state)
lines = result.split("\n")
# Most recent first
@@ -193,7 +197,13 @@ class TestGrepSearch:
},
}
result = middleware.grep_search.func(pattern=r"def \w+\(\):", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r"def \w+\(\):",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -216,8 +226,12 @@ class TestGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern=r"def", include="*.{py", state=state
result = middleware._handle_grep_search(
pattern=r"def",
path="/",
include="*.{py",
output_mode="files_with_matches",
state=state,
)
assert result == "Invalid include pattern"
@@ -241,8 +255,12 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern=r"def \w+\(\):", output_mode="content", state=state
result = middleware._handle_grep_search(
pattern=r"def \w+\(\):",
path="/",
include=None,
output_mode="content",
state=state,
)
assert isinstance(result, str)
@@ -271,8 +289,8 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern=r"TODO", output_mode="count", state=state
result = middleware._handle_grep_search(
pattern=r"TODO", path="/", include=None, output_mode="count", state=state
)
assert isinstance(result, str)
@@ -300,8 +318,12 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern="import", include="*.py", state=state
result = middleware._handle_grep_search(
pattern="import",
path="/",
include="*.py",
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
@@ -333,8 +355,12 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func( # type: ignore[attr-defined]
pattern="const", include="*.{ts,tsx}", state=state
result = middleware._handle_grep_search(
pattern="const",
path="/",
include="*.{ts,tsx}",
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
@@ -362,7 +388,13 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func(pattern="import", path="/src", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern="import",
path="/src",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -383,7 +415,13 @@ class TestFilesystemGrepSearch:
},
}
result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r"TODO",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert result == "No matches found"
@@ -397,7 +435,13 @@ class TestFilesystemGrepSearch:
"text_editor_files": {},
}
result = middleware.grep_search.func(pattern=r"[unclosed", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r"[unclosed",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "Invalid regex pattern" in result
@@ -428,7 +472,7 @@ class TestSearchWithDifferentBackends:
},
}
result = middleware.glob_search.func(pattern="**/*", state=state) # type: ignore[attr-defined]
result = middleware._handle_glob_search(pattern="**/*", path="/", state=state)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -457,7 +501,13 @@ class TestSearchWithDifferentBackends:
},
}
result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r"TODO",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "/src/main.py" in result
@@ -486,7 +536,13 @@ class TestSearchWithDifferentBackends:
},
}
result = middleware.grep_search.func(pattern=r".*", state=state) # type: ignore[attr-defined]
result = middleware._handle_grep_search(
pattern=r".*",
path="/",
include=None,
output_mode="files_with_matches",
state=state,
)
assert isinstance(result, str)
assert "/src/main.py" in result

View File

@@ -21,7 +21,7 @@ wheels = [
[[package]]
name = "anthropic"
version = "0.71.0"
version = "0.72.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -33,9 +33,9 @@ dependencies = [
{ name = "sniffio" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/82/4f/70682b068d897841f43223df82d96ec1d617435a8b759c4a2d901a50158b/anthropic-0.71.0.tar.gz", hash = "sha256:eb8e6fa86d049061b3ef26eb4cbae0174ebbff21affa6de7b3098da857d8de6a", size = 489102, upload-time = "2025-10-16T15:54:40.08Z" }
sdist = { url = "https://files.pythonhosted.org/packages/dd/f3/feb750a21461090ecf48bbebcaa261cd09003cc1d14e2fa9643ad59edd4d/anthropic-0.72.1.tar.gz", hash = "sha256:a6d1d660e1f4af91dddc732f340786d19acaffa1ae8e69442e56be5fa6539d51", size = 415395, upload-time = "2025-11-11T16:53:29.001Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5d/77/073e8ac488f335aec7001952825275582fb8f433737e90f24eeef9d878f6/anthropic-0.71.0-py3-none-any.whl", hash = "sha256:85c5015fcdbdc728390f11b17642a65a4365d03b12b799b18b6cc57e71fdb327", size = 355035, upload-time = "2025-10-16T15:54:38.238Z" },
{ url = "https://files.pythonhosted.org/packages/51/05/d9d45edad1aa28330cea09a3b35e1590f7279f91bb5ab5237c70a0884ea3/anthropic-0.72.1-py3-none-any.whl", hash = "sha256:81e73cca55e8924776c8c4418003defe6bf9eaf0cd92beb94c8dbf537b95316f", size = 357373, upload-time = "2025-11-11T16:53:27.438Z" },
]
[[package]]
@@ -477,7 +477,7 @@ wheels = [
[[package]]
name = "langchain"
version = "1.0.4"
version = "1.0.6"
source = { editable = "../../langchain_v1" }
dependencies = [
{ name = "langchain-core" },
@@ -541,7 +541,7 @@ typing = [
[[package]]
name = "langchain-anthropic"
version = "1.0.2"
version = "1.0.4"
source = { editable = "." }
dependencies = [
{ name = "anthropic" },