feat(langchain_v1): Add ShellToolMiddleware and ClaudeBashToolMiddleware (#33527)

- Both middleware share the same implementation; the only difference is that
one uses Claude's server-side tool definition, whereas the other uses a
generic tool definition compatible with all models
- Implemented 3 execution policies (responsible for actually running the
shell process)
- HostExecutionPolicy runs the shell as a subprocess, appropriate for
already sandboxed environments, e.g. when run inside a dedicated Docker
container
- CodexSandboxExecutionPolicy runs the shell using the sandbox command
from the Codex CLI, which implements sandboxing techniques for Linux and
macOS.
- DockerExecutionPolicy runs the shell inside a dedicated Docker
container for isolation.
- Implements all behaviours described in
https://docs.claude.com/en/docs/agents-and-tools/tool-use/bash-tool#handle-large-outputs
including timeouts, truncation, output redaction, etc

---------

Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Co-authored-by: Sydney Runkle <sydneymarierunkle@gmail.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Nuno Campos
2025-10-17 03:32:11 +01:00
committed by GitHub
parent e0e11423d9
commit a022e3c14d
10 changed files with 2250 additions and 477 deletions

View File

@@ -17,6 +17,13 @@ from .human_in_the_loop import (
from .model_call_limit import ModelCallLimitMiddleware from .model_call_limit import ModelCallLimitMiddleware
from .model_fallback import ModelFallbackMiddleware from .model_fallback import ModelFallbackMiddleware
from .pii import PIIDetectionError, PIIMiddleware from .pii import PIIDetectionError, PIIMiddleware
from .shell_tool import (
CodexSandboxExecutionPolicy,
DockerExecutionPolicy,
HostExecutionPolicy,
RedactionRule,
ShellToolMiddleware,
)
from .summarization import SummarizationMiddleware from .summarization import SummarizationMiddleware
from .todo import TodoListMiddleware from .todo import TodoListMiddleware
from .tool_call_limit import ToolCallLimitMiddleware from .tool_call_limit import ToolCallLimitMiddleware
@@ -42,7 +49,10 @@ __all__ = [
"AgentMiddleware", "AgentMiddleware",
"AgentState", "AgentState",
"ClearToolUsesEdit", "ClearToolUsesEdit",
"CodexSandboxExecutionPolicy",
"ContextEditingMiddleware", "ContextEditingMiddleware",
"DockerExecutionPolicy",
"HostExecutionPolicy",
"HumanInTheLoopMiddleware", "HumanInTheLoopMiddleware",
"InterruptOnConfig", "InterruptOnConfig",
"LLMToolEmulator", "LLMToolEmulator",
@@ -53,6 +63,8 @@ __all__ = [
"ModelResponse", "ModelResponse",
"PIIDetectionError", "PIIDetectionError",
"PIIMiddleware", "PIIMiddleware",
"RedactionRule",
"ShellToolMiddleware",
"SummarizationMiddleware", "SummarizationMiddleware",
"TodoListMiddleware", "TodoListMiddleware",
"ToolCallLimitMiddleware", "ToolCallLimitMiddleware",

View File

@@ -0,0 +1,388 @@
"""Execution policies for the persistent shell middleware."""
from __future__ import annotations
import abc
import json
import os
import shutil
import subprocess
import sys
import typing
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from pathlib import Path
try: # pragma: no cover - optional dependency on POSIX platforms
import resource
except ImportError: # pragma: no cover - non-POSIX systems
resource = None # type: ignore[assignment]
# Directory-name prefix that marks a temporary shell workspace. DockerExecutionPolicy
# skips the bind mount for any workspace whose final path component starts with it.
SHELL_TEMP_PREFIX = "langchain-shell-"
def _launch_subprocess(
command: Sequence[str],
*,
env: Mapping[str, str],
cwd: Path,
preexec_fn: typing.Callable[[], None] | None,
start_new_session: bool,
) -> subprocess.Popen[str]:
return subprocess.Popen( # noqa: S603
list(command),
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=cwd,
text=True,
encoding="utf-8",
errors="replace",
bufsize=1,
env=env,
preexec_fn=preexec_fn, # noqa: PLW1509
start_new_session=start_new_session,
)
if typing.TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from pathlib import Path
@dataclass
class BaseExecutionPolicy(abc.ABC):
    """Abstract configuration shared by every persistent-shell execution policy.

    A concrete policy decides how the shell process is launched and constrained by
    implementing :meth:`spawn`. Pick :class:`HostExecutionPolicy` for trusted,
    same-host execution, :class:`CodexSandboxExecutionPolicy` when the Codex CLI
    sandbox is available and extra syscall restrictions are wanted, or
    :class:`DockerExecutionPolicy` for container-level isolation using Docker.

    The timeout and output-limit fields are only stored here; nothing in this class
    enforces them.
    """

    # Time budgets; presumably seconds -- confirm against the code that consumes them.
    command_timeout: float = 30.0
    startup_timeout: float = 30.0
    termination_timeout: float = 10.0
    # Output caps; only ``max_output_lines`` is validated below.
    max_output_lines: int = 100
    max_output_bytes: int | None = None

    def __post_init__(self) -> None:
        """Reject a non-positive ``max_output_lines``.

        Raises:
            ValueError: If ``max_output_lines`` is zero or negative.
        """
        if self.max_output_lines <= 0:
            raise ValueError("max_output_lines must be positive.")

    @abc.abstractmethod
    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Launch the persistent shell process."""
@dataclass
class HostExecutionPolicy(BaseExecutionPolicy):
    """Run the shell directly on the host as a child process.

    This policy is best suited for trusted or single-tenant environments (CI jobs,
    developer workstations, pre-sandboxed containers) where the agent must access the
    host filesystem and tooling without additional isolation. It enforces optional CPU
    and memory limits to prevent runaway commands but offers **no** filesystem or network
    sandboxing; commands can modify anything the process user can reach.

    Where ``resource.prlimit`` is usable (Linux) the limits are applied to the shell
    right after it starts. On other platforms that provide the ``resource`` module they
    are installed by a ``preexec_fn`` in the child before ``exec``. While
    ``create_process_group`` keeps its default of ``True`` the shell is started in a new
    session, so it does not share the parent's process group.

    Attributes:
        cpu_time_seconds: Optional ``RLIMIT_CPU`` value (soft and hard limit).
        memory_bytes: Optional ``RLIMIT_AS`` value, falling back to ``RLIMIT_DATA``
            when the platform has no ``RLIMIT_AS``.
        create_process_group: Passed to ``subprocess.Popen`` as ``start_new_session``.
    """

    cpu_time_seconds: int | None = None
    memory_bytes: int | None = None
    create_process_group: bool = True
    # Derived in __post_init__: True when at least one of the two limits is set.
    _limits_requested: bool = field(init=False, repr=False, default=False)

    def __post_init__(self) -> None:
        """Validate the limits and record whether any limit was requested.

        Raises:
            ValueError: If a provided limit is zero or negative.
            RuntimeError: If a limit is requested but the ``resource`` module is missing.
        """
        super().__post_init__()
        if self.cpu_time_seconds is not None and self.cpu_time_seconds <= 0:
            msg = "cpu_time_seconds must be positive if provided."
            raise ValueError(msg)
        if self.memory_bytes is not None and self.memory_bytes <= 0:
            msg = "memory_bytes must be positive if provided."
            raise ValueError(msg)
        self._limits_requested = any(
            value is not None for value in (self.cpu_time_seconds, self.memory_bytes)
        )
        if self._limits_requested and resource is None:
            msg = (
                "HostExecutionPolicy cpu/memory limits require the Python 'resource' module. "
                "Either remove the limits or run on a POSIX platform."
            )
            raise RuntimeError(msg)

    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Start the shell in ``workspace`` and apply the configured limits.

        Raises:
            RuntimeError: If the limits cannot be applied with ``prlimit``. The shell
                that was already started is killed and reaped before the error
                propagates.
        """
        process = _launch_subprocess(
            list(command),
            env=env,
            cwd=workspace,
            preexec_fn=self._create_preexec_fn(),
            start_new_session=self.create_process_group,
        )
        try:
            self._apply_post_spawn_limits(process)
        except Exception:
            # The caller never receives the handle on this path, so nobody else could
            # stop the child or close its pipes.
            self._abandon(process)
            raise
        return process

    def _abandon(self, process: subprocess.Popen[str]) -> None:
        """Kill ``process``, close its pipes and reap it on a best-effort basis."""
        try:
            process.kill()
            process.communicate(timeout=self.termination_timeout)
        except (OSError, ValueError, subprocess.TimeoutExpired):
            # Cleanup must not mask the error that triggered it.
            pass

    def _create_preexec_fn(self) -> typing.Callable[[], None] | None:
        """Return a child-side hook that sets the limits, or ``None`` when not needed.

        No hook is needed when no limit was requested or when ``prlimit`` will be used
        after the spawn instead.
        """
        if not self._limits_requested or self._can_use_prlimit():
            return None

        def _configure() -> None:  # pragma: no cover - depends on OS
            if self.cpu_time_seconds is not None:
                limit = (self.cpu_time_seconds, self.cpu_time_seconds)
                resource.setrlimit(resource.RLIMIT_CPU, limit)
            if self.memory_bytes is not None:
                limit = (self.memory_bytes, self.memory_bytes)
                if hasattr(resource, "RLIMIT_AS"):
                    resource.setrlimit(resource.RLIMIT_AS, limit)
                elif hasattr(resource, "RLIMIT_DATA"):
                    resource.setrlimit(resource.RLIMIT_DATA, limit)

        return _configure

    def _apply_post_spawn_limits(self, process: subprocess.Popen[str]) -> None:
        """Apply the limits to an already running shell through ``resource.prlimit``.

        Does nothing when no limit was requested or ``prlimit`` is not usable.

        Raises:
            RuntimeError: If ``prlimit`` fails with an ``OSError``.
        """
        if not self._limits_requested or not self._can_use_prlimit():
            return
        if resource is None:  # pragma: no cover - defensive
            return
        pid = process.pid
        if pid is None:
            return
        try:
            # ``prlimit`` is missing from the stubs on some platforms, hence the cast.
            prlimit = typing.cast("typing.Any", resource).prlimit
            if self.cpu_time_seconds is not None:
                prlimit(pid, resource.RLIMIT_CPU, (self.cpu_time_seconds, self.cpu_time_seconds))
            if self.memory_bytes is not None:
                limit = (self.memory_bytes, self.memory_bytes)
                if hasattr(resource, "RLIMIT_AS"):
                    prlimit(pid, resource.RLIMIT_AS, limit)
                elif hasattr(resource, "RLIMIT_DATA"):
                    prlimit(pid, resource.RLIMIT_DATA, limit)
        except OSError as exc:  # pragma: no cover - depends on platform support
            msg = "Failed to apply resource limits via prlimit."
            raise RuntimeError(msg) from exc

    @staticmethod
    def _can_use_prlimit() -> bool:
        """Tell whether ``resource.prlimit`` exists and the platform is Linux."""
        return (
            resource is not None
            and hasattr(resource, "prlimit")
            and sys.platform.startswith("linux")
        )
@dataclass
class CodexSandboxExecutionPolicy(BaseExecutionPolicy):
    """Launch the shell through the Codex CLI ``sandbox`` subcommand.

    Use it when the Codex CLI is installed and you want the syscall and filesystem
    restrictions of its Seatbelt (macOS) or Landlock/seccomp (Linux) sandbox. Commands
    still run on the host, inside whatever sandbox the CLI sets up. A
    :class:`RuntimeError` is raised before launch when the binary is not on ``PATH`` or
    when the platform cannot be determined; sandbox start-up failures inside the CLI
    itself are not handled by this class.

    Each ``config_overrides`` entry is forwarded as ``-c key=value``. The policy adds no
    resource limits of its own; combine it with host-level guards (cgroups, container
    resource limits) as needed.
    """

    binary: str = "codex"
    platform: typing.Literal["auto", "macos", "linux"] = "auto"
    config_overrides: Mapping[str, typing.Any] = field(default_factory=dict)

    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Start ``command`` wrapped in the Codex sandbox, with ``workspace`` as cwd."""
        return _launch_subprocess(
            self._build_command(command),
            env=env,
            cwd=workspace,
            preexec_fn=None,
            start_new_session=False,
        )

    def _build_command(self, command: Sequence[str]) -> list[str]:
        """Compose ``<binary> sandbox <platform> [-c k=v ...] -- <command...>``."""
        argv: list[str] = [self._resolve_binary(), "sandbox", self._determine_platform()]
        # Sorted by key so the generated command line is deterministic.
        for name, raw in sorted(dict(self.config_overrides).items()):
            argv += ["-c", f"{name}={self._format_override(raw)}"]
        argv.append("--")
        argv += command
        return argv

    def _resolve_binary(self) -> str:
        """Return the absolute path of the Codex CLI.

        Raises:
            RuntimeError: If ``binary`` cannot be found on ``PATH``.
        """
        located = shutil.which(self.binary)
        if located is not None:
            return located
        raise RuntimeError(
            f"Codex sandbox policy requires the '{self.binary}' CLI to be installed"
            " and available on PATH."
        )

    def _determine_platform(self) -> str:
        """Return the sandbox platform argument, auto-detecting it when asked to.

        Raises:
            RuntimeError: If ``platform`` is ``"auto"`` on an OS other than Linux/macOS.
        """
        if self.platform != "auto":
            return self.platform
        current = sys.platform
        if current.startswith("linux"):
            return "linux"
        if current == "darwin":
            return "macos"
        raise RuntimeError(
            "Codex sandbox policy could not determine a supported platform; "
            "set 'platform' explicitly."
        )

    @staticmethod
    def _format_override(value: typing.Any) -> str:
        """Render an override value as JSON, or via ``str`` if it is not serialisable."""
        try:
            rendered = json.dumps(value)
        except TypeError:
            rendered = str(value)
        return rendered
@dataclass
class DockerExecutionPolicy(BaseExecutionPolicy):
    """Run the shell inside a dedicated Docker container.

    Choose this policy when commands originate from untrusted users or you require
    strong isolation between sessions. The workspace is bind-mounted at the same path
    inside the container unless its directory name starts with ``SHELL_TEMP_PREFIX``;
    such temporary workspaces get no mount and the container starts in ``/``.
    Networking is disabled by default (``--network none``), and ``read_only_rootfs``
    and ``user`` enable further hardening.

    The security guarantees depend on your Docker daemon configuration. Run the agent on
    a host where Docker is locked down (rootless mode, AppArmor/SELinux, etc.) and review
    any additional volumes or capabilities passed through ``extra_run_args``. The default
    image is ``python:3.12-alpine3.19``; supply a custom image if you need preinstalled
    tooling.
    """

    binary: str = "docker"
    image: str = "python:3.12-alpine3.19"
    remove_container_on_exit: bool = True
    network_enabled: bool = False
    extra_run_args: Sequence[str] | None = None
    memory_bytes: int | None = None
    # Accepted only so that a clear error can be raised; see __post_init__.
    cpu_time_seconds: typing.Any | None = None
    cpus: str | None = None
    read_only_rootfs: bool = False
    user: str | None = None

    def __post_init__(self) -> None:
        """Validate the options and normalise ``extra_run_args`` to a tuple.

        Raises:
            ValueError: If ``memory_bytes`` is not positive, or ``cpus`` / ``user`` is
                blank.
            RuntimeError: If ``cpu_time_seconds`` is set; it is unsupported here.
        """
        super().__post_init__()
        if self.memory_bytes is not None and self.memory_bytes <= 0:
            raise ValueError("memory_bytes must be positive if provided.")
        if self.cpu_time_seconds is not None:
            raise RuntimeError(
                "DockerExecutionPolicy does not support cpu_time_seconds; configure CPU limits "
                "using Docker run options such as '--cpus'."
            )
        if self.cpus is not None and not self.cpus.strip():
            raise ValueError("cpus must be a non-empty string when provided.")
        if self.user is not None and not self.user.strip():
            raise ValueError("user must be a non-empty string when provided.")
        self.extra_run_args = tuple(self.extra_run_args or ())

    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Start ``docker run`` for ``command``.

        ``env`` reaches the container through ``-e`` flags. The docker CLI process
        itself runs with a copy of the current process environment.
        """
        return _launch_subprocess(
            self._build_command(workspace, env, command),
            env=os.environ.copy(),
            cwd=workspace,
            preexec_fn=None,
            start_new_session=False,
        )

    def _build_command(
        self,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> list[str]:
        """Assemble the complete ``docker run -i ...`` argument list."""
        argv: list[str] = [self._resolve_binary(), "run", "-i"]
        if self.remove_container_on_exit:
            argv.append("--rm")
        if not self.network_enabled:
            argv += ["--network", "none"]
        if self.memory_bytes is not None:
            argv += ["--memory", str(self.memory_bytes)]
        if self._should_mount_workspace(workspace):
            # Same path on both sides of the mount, which also becomes the cwd.
            mount_point = str(workspace)
            argv += ["-v", f"{mount_point}:{mount_point}", "-w", mount_point]
        else:
            argv += ["-w", "/"]
        if self.read_only_rootfs:
            argv.append("--read-only")
        for name, value in env.items():
            argv += ["-e", f"{name}={value}"]
        if self.cpus is not None:
            argv += ["--cpus", self.cpus]
        if self.user is not None:
            argv += ["--user", self.user]
        if self.extra_run_args:
            argv += self.extra_run_args
        argv.append(self.image)
        argv += command
        return argv

    @staticmethod
    def _should_mount_workspace(workspace: Path) -> bool:
        """Mount every workspace except those named with the temporary prefix."""
        is_temporary = workspace.name.startswith(SHELL_TEMP_PREFIX)
        return not is_temporary

    def _resolve_binary(self) -> str:
        """Return the absolute path of the Docker CLI.

        Raises:
            RuntimeError: If ``binary`` cannot be found on ``PATH``.
        """
        located = shutil.which(self.binary)
        if located is not None:
            return located
        raise RuntimeError(
            f"Docker execution policy requires the '{self.binary}' CLI to be installed"
            " and available on PATH."
        )
# Names exported by a star-import of this module: the abstract base and the three
# concrete policies. _launch_subprocess and SHELL_TEMP_PREFIX are not listed.
__all__ = [
    "BaseExecutionPolicy",
    "CodexSandboxExecutionPolicy",
    "DockerExecutionPolicy",
    "HostExecutionPolicy",
]

View File

@@ -0,0 +1,350 @@
"""Shared redaction utilities for middleware components."""
from __future__ import annotations
import hashlib
import ipaddress
import re
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Literal
from urllib.parse import urlparse
from typing_extensions import TypedDict
# "block" makes apply_strategy raise PIIDetectionError; "redact", "mask" and "hash"
# replace the matched text with a placeholder, a partial mask or a short digest.
RedactionStrategy = Literal["block", "redact", "mask", "hash"]
"""Supported strategies for handling detected sensitive values."""
class PIIMatch(TypedDict):
    """Represents an individual match of sensitive data."""

    type: str  # name of the sensitive-data category, e.g. "email"
    value: str  # the matched text itself
    start: int  # offset of the first matched character in the scanned text
    end: int  # offset just past the last matched character (slice-style, exclusive)
class PIIDetectionError(Exception):
    """Signal that sensitive values were found while blocking is configured."""

    def __init__(self, pii_type: str, matches: Sequence[PIIMatch]) -> None:
        """Build the error and keep the detection details on the instance.

        Args:
            pii_type: Name of the detected sensitive type.
            matches: Every match found for that type; stored as a new list.
        """
        total = len(matches)
        super().__init__(f"Detected {total} instance(s) of {pii_type} in text content")
        self.pii_type = pii_type
        self.matches = [*matches]
# A detector takes the text to scan and returns one PIIMatch per occurrence found.
Detector = Callable[[str], list[PIIMatch]]
"""Callable signature for detectors that locate sensitive values."""
def detect_email(content: str) -> list[PIIMatch]:
    """Detect email addresses in content.

    Args:
        content: Text to scan.

    Returns:
        One match per address, in order of appearance.
    """
    # The top-level domain is ``[A-Za-z]{2,}``. Writing it as ``[A-Z|a-z]`` would make
    # ``|`` a literal member of the class, so "a@b.co|next" would match as one address.
    pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    # Plain dict literals are valid PIIMatch values.
    found: list[PIIMatch] = [
        {"type": "email", "value": hit.group(), "start": hit.start(), "end": hit.end()}
        for hit in re.finditer(pattern, content)
    ]
    return found
def detect_credit_card(content: str) -> list[PIIMatch]:
    """Detect 16-digit card numbers that pass the Luhn checksum.

    The digits may be written as four groups of four separated by a single
    whitespace character or hyphen, or without separators.

    Args:
        content: Text to scan.

    Returns:
        One match per candidate that passes Luhn validation, in order of appearance.
    """
    grouped_digits = re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b")
    return [
        PIIMatch(
            type="credit_card",
            value=hit.group(),
            start=hit.start(),
            end=hit.end(),
        )
        for hit in grouped_digits.finditer(content)
        if _passes_luhn(hit.group())
    ]
def detect_ip(content: str) -> list[PIIMatch]:
    """Detect dotted-quad IPv4 addresses in content.

    Candidates are found with a dotted-decimal regex and kept only when
    ``ipaddress.ip_address`` accepts them, which drops out-of-range octets. No IPv6
    pattern is scanned.

    Args:
        content: Text to scan.

    Returns:
        One match per valid address, in order of appearance.
    """
    dotted_quad = re.compile(r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b")
    found: list[PIIMatch] = []
    for hit in dotted_quad.finditer(content):
        text = hit.group()
        try:
            ipaddress.ip_address(text)
        except ValueError:
            # Shaped like an address but not a valid one (e.g. an octet above 255).
            pass
        else:
            begin, finish = hit.span()
            found.append(PIIMatch(type="ip", value=text, start=begin, end=finish))
    return found
def detect_mac_address(content: str) -> list[PIIMatch]:
    """Detect MAC addresses written as six hex pairs joined by ``:`` or ``-``.

    Args:
        content: Text to scan.

    Returns:
        One match per address, in order of appearance.
    """
    six_octets = re.compile(r"\b([0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b")
    found: list[PIIMatch] = []
    for hit in six_octets.finditer(content):
        begin, finish = hit.span()
        found.append(
            PIIMatch(type="mac_address", value=hit.group(), start=begin, end=finish)
        )
    return found
def detect_url(content: str) -> list[PIIMatch]:
    """Detect URLs in content using regex and stdlib validation.

    Two passes are made: first ``http://`` / ``https://`` URLs, then scheme-less
    candidates such as ``www.example.com`` or ``example.com/path`` that do not
    overlap a first-pass match. A candidate that ``urlparse`` rejects is skipped
    instead of aborting the scan.

    Args:
        content: Text to scan.

    Returns:
        Scheme-qualified matches first, then bare matches, each group in order of
        appearance.
    """
    matches: list[PIIMatch] = []

    # Pattern 1: URLs with scheme (http:// or https://)
    scheme_pattern = r"https?://[^\s<>\"{}|\\^`\[\]]+"
    for match in re.finditer(scheme_pattern, content):
        url = match.group()
        try:
            result = urlparse(url)
        except ValueError:
            # urlparse raises for some hosts, e.g. a netloc whose characters turn into
            # "/", "?", "#", "@" or ":" under NFKC normalisation. Treat as not a URL.
            continue
        if result.scheme in ("http", "https") and result.netloc:
            # Plain dict literals are valid PIIMatch values.
            matches.append(
                {"type": "url", "value": url, "start": match.start(), "end": match.end()}
            )

    # Pattern 2: URLs without scheme (www.example.com or example.com/path)
    # More conservative to avoid false positives
    bare_pattern = (
        r"\b(?:www\.)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?"
        r"(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+(?:/[^\s]*)?"
    )
    for match in re.finditer(bare_pattern, content):
        start, end = match.start(), match.end()
        # Skip if already matched with scheme
        if any(m["start"] <= start < m["end"] or m["start"] < end <= m["end"] for m in matches):
            continue
        url = match.group()
        # Only accept if it has a path or starts with www
        # This reduces false positives like "example.com" in prose
        if "/" not in url and not url.startswith("www."):
            continue
        try:
            # Add scheme for validation (required for urlparse to work correctly)
            result = urlparse(f"http://{url}")
        except ValueError:
            continue
        if result.netloc and "." in result.netloc:
            matches.append({"type": "url", "value": url, "start": start, "end": end})

    return matches
# resolve_detector falls back to this table when a rule supplies no detector of its own.
BUILTIN_DETECTORS: dict[str, Detector] = {
    "email": detect_email,
    "credit_card": detect_credit_card,
    "ip": detect_ip,
    "mac_address": detect_mac_address,
    "url": detect_url,
}
"""Registry of built-in detectors keyed by type name."""
def _passes_luhn(card_number: str) -> bool:
"""Validate credit card number using the Luhn checksum."""
digits = [int(d) for d in card_number if d.isdigit()]
if not 13 <= len(digits) <= 19:
return False
checksum = 0
for index, digit in enumerate(reversed(digits)):
value = digit
if index % 2 == 1:
value *= 2
if value > 9:
value -= 9
checksum += value
return checksum % 10 == 0
def _apply_redact_strategy(content: str, matches: list[PIIMatch]) -> str:
result = content
for match in sorted(matches, key=lambda item: item["start"], reverse=True):
replacement = f"[REDACTED_{match['type'].upper()}]"
result = result[: match["start"]] + replacement + result[match["end"] :]
return result
def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
result = content
for match in sorted(matches, key=lambda item: item["start"], reverse=True):
value = match["value"]
pii_type = match["type"]
if pii_type == "email":
parts = value.split("@")
if len(parts) == 2:
domain_parts = parts[1].split(".")
masked = (
f"{parts[0]}@****.{domain_parts[-1]}"
if len(domain_parts) >= 2
else f"{parts[0]}@****"
)
else:
masked = "****"
elif pii_type == "credit_card":
digits_only = "".join(c for c in value if c.isdigit())
separator = "-" if "-" in value else " " if " " in value else ""
if separator:
masked = f"****{separator}****{separator}****{separator}{digits_only[-4:]}"
else:
masked = f"************{digits_only[-4:]}"
elif pii_type == "ip":
octets = value.split(".")
masked = f"*.*.*.{octets[-1]}" if len(octets) == 4 else "****"
elif pii_type == "mac_address":
separator = ":" if ":" in value else "-"
masked = (
f"**{separator}**{separator}**{separator}**{separator}**{separator}{value[-2:]}"
)
elif pii_type == "url":
masked = "[MASKED_URL]"
else:
masked = f"****{value[-4:]}" if len(value) > 4 else "****"
result = result[: match["start"]] + masked + result[match["end"] :]
return result
def _apply_hash_strategy(content: str, matches: list[PIIMatch]) -> str:
result = content
for match in sorted(matches, key=lambda item: item["start"], reverse=True):
digest = hashlib.sha256(match["value"].encode()).hexdigest()[:8]
replacement = f"<{match['type']}_hash:{digest}>"
result = result[: match["start"]] + replacement + result[match["end"] :]
return result
def apply_strategy(
    content: str,
    matches: list[PIIMatch],
    strategy: RedactionStrategy,
) -> str:
    """Apply the configured strategy to matches within content.

    Args:
        content: The text in which the matches were found.
        matches: Detected matches; when empty, ``content`` is returned unchanged
            before the strategy is even inspected.
        strategy: One of ``"redact"``, ``"mask"``, ``"hash"`` or ``"block"``.

    Returns:
        The rewritten text.

    Raises:
        PIIDetectionError: For ``"block"``, carrying the type of the first match.
        ValueError: For any other strategy name.
    """
    if not matches:
        return content
    if strategy == "block":
        raise PIIDetectionError(matches[0]["type"], matches)
    if strategy == "redact":
        handler = _apply_redact_strategy
    elif strategy == "mask":
        handler = _apply_mask_strategy
    elif strategy == "hash":
        handler = _apply_hash_strategy
    else:
        raise ValueError(f"Unknown redaction strategy: {strategy}")
    return handler(content, matches)
def resolve_detector(pii_type: str, detector: Detector | str | None) -> Detector:
    """Turn a detector configuration into a callable detector.

    Args:
        pii_type: Type name; used to look up a built-in detector and to label matches
            produced by a regex detector.
        detector: ``None`` to use the built-in detector for ``pii_type``, a regex
            pattern string, or a ready-made callable (returned as is).

    Returns:
        A callable mapping text to a list of matches.

    Raises:
        ValueError: If ``detector`` is ``None`` and ``pii_type`` has no built-in.
    """
    if detector is None:
        builtin = BUILTIN_DETECTORS.get(pii_type)
        if builtin is not None:
            return builtin
        msg = (
            f"Unknown PII type: {pii_type}. "
            f"Must be one of {list(BUILTIN_DETECTORS.keys())} or provide a custom detector."
        )
        raise ValueError(msg)

    if not isinstance(detector, str):
        return detector

    # Compiled once here so every call of the returned detector reuses it.
    compiled = re.compile(detector)

    def regex_detector(content: str) -> list[PIIMatch]:
        found: list[PIIMatch] = []
        for hit in compiled.finditer(content):
            found.append(
                PIIMatch(type=pii_type, value=hit.group(), start=hit.start(), end=hit.end())
            )
        return found

    return regex_detector
@dataclass(frozen=True)
class RedactionRule:
    """User-facing description of how one PII type should be handled."""

    pii_type: str
    strategy: RedactionStrategy = "redact"
    # None selects the built-in detector; a string is compiled as a regex pattern.
    detector: Detector | str | None = None

    def resolve(self) -> ResolvedRedactionRule:
        """Return an immutable rule whose detector has been resolved to a callable."""
        return ResolvedRedactionRule(
            detector=resolve_detector(self.pii_type, self.detector),
            strategy=self.strategy,
            pii_type=self.pii_type,
        )
@dataclass(frozen=True)
class ResolvedRedactionRule:
    """A redaction rule whose detector is already a callable."""

    pii_type: str
    strategy: RedactionStrategy
    detector: Detector

    def apply(self, content: str) -> tuple[str, list[PIIMatch]]:
        """Run the detector on ``content`` and apply the strategy to what it finds.

        Returns:
            ``(new_content, matches)``. With no matches the original content and an
            empty list are returned and the strategy is not invoked.
        """
        found = self.detector(content)
        if found:
            return apply_strategy(content, found, self.strategy), found
        return content, []
# Names exported by a star-import of this module. BUILTIN_DETECTORS, resolve_detector
# and the two type aliases are not listed.
__all__ = [
    "PIIDetectionError",
    "PIIMatch",
    "RedactionRule",
    "ResolvedRedactionRule",
    "apply_strategy",
    "detect_credit_card",
    "detect_email",
    "detect_ip",
    "detect_mac_address",
    "detect_url",
]

View File

@@ -2,15 +2,22 @@
from __future__ import annotations from __future__ import annotations
import hashlib
import ipaddress
import re
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
from urllib.parse import urlparse
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
from typing_extensions import TypedDict
from langchain.agents.middleware._redaction import (
PIIDetectionError,
PIIMatch,
RedactionRule,
ResolvedRedactionRule,
apply_strategy,
detect_credit_card,
detect_email,
detect_ip,
detect_mac_address,
detect_url,
)
from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -19,396 +26,6 @@ if TYPE_CHECKING:
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
class PIIMatch(TypedDict):
    """Represents a detected PII match in text.

    ``start`` and ``end`` are offsets into the scanned text, with ``end`` exclusive;
    the detectors in this module fill them from ``re.Match.start()`` and ``end()``.
    """

    type: str
    """The type of PII detected (e.g., 'email', 'ssn', 'credit_card')."""
    value: str
    """The actual matched text."""
    start: int
    """Starting position of the match in the text."""
    end: int
    """Ending position of the match in the text."""
class PIIDetectionError(Exception):
    """Error signalling that PII was found while the 'block' strategy is in effect."""

    def __init__(self, pii_type: str, matches: list[PIIMatch]) -> None:
        """Build the error and keep the detection details on the instance.

        Args:
            pii_type: The type of PII that was detected.
            matches: List of PII matches found; stored by reference, not copied.
        """
        total = len(matches)
        super().__init__(f"Detected {total} instance(s) of {pii_type} in message content")
        self.matches = matches
        self.pii_type = pii_type
# ============================================================================
# PII Detection Functions
# ============================================================================
def _luhn_checksum(card_number: str) -> bool:
"""Validate credit card number using Luhn algorithm.
Args:
card_number: Credit card number string (digits only).
Returns:
True if the number passes Luhn validation, False otherwise.
"""
digits = [int(d) for d in card_number if d.isdigit()]
if len(digits) < 13 or len(digits) > 19:
return False
checksum = 0
for i, digit in enumerate(reversed(digits)):
d = digit
if i % 2 == 1:
d *= 2
if d > 9:
d -= 9
checksum += d
return checksum % 10 == 0
def detect_email(content: str) -> list[PIIMatch]:
    """Detect email addresses in content.

    Args:
        content: Text content to scan.

    Returns:
        List of detected email matches.
    """
    # The top-level domain is ``[A-Za-z]{2,}``. Writing it as ``[A-Z|a-z]`` would make
    # ``|`` a literal member of the class, so "a@b.co|next" would match as one address.
    pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    # Plain dict literals are valid PIIMatch values.
    found: list[PIIMatch] = [
        {"type": "email", "value": hit.group(), "start": hit.start(), "end": hit.end()}
        for hit in re.finditer(pattern, content)
    ]
    return found
def detect_credit_card(content: str) -> list[PIIMatch]:
    """Detect credit card numbers in content using Luhn validation.

    Detects cards in formats like:
    - 1234567890123456
    - 1234 5678 9012 3456
    - 1234-5678-9012-3456

    Args:
        content: Text content to scan.

    Returns:
        List of detected credit card matches; candidates failing Luhn are dropped.
    """
    # Four groups of four digits, optionally separated by whitespace or a hyphen.
    grouped_digits = re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b")
    return [
        PIIMatch(
            type="credit_card",
            value=hit.group(),
            start=hit.start(),
            end=hit.end(),
        )
        for hit in grouped_digits.finditer(content)
        if _luhn_checksum(hit.group())
    ]
def detect_ip(content: str) -> list[PIIMatch]:
    """Detect IPv4 addresses in content using stdlib validation.

    Only dotted-quad candidates are scanned for; each one is kept when
    ``ipaddress.ip_address`` accepts it. No IPv6 pattern is searched.

    Args:
        content: Text content to scan.

    Returns:
        List of detected IP address matches.
    """
    dotted_quad = re.compile(r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b")
    found = []
    for hit in dotted_quad.finditer(content):
        candidate = hit.group()
        try:
            # Validate with stdlib
            ipaddress.ip_address(candidate)
        except ValueError:
            # Not a valid IP address
            continue
        begin, finish = hit.span()
        found.append(PIIMatch(type="ip", value=candidate, start=begin, end=finish))
    return found
def detect_mac_address(content: str) -> list[PIIMatch]:
    """Detect MAC addresses in content.

    Detects formats like:
    - 00:1A:2B:3C:4D:5E
    - 00-1A-2B-3C-4D-5E

    Args:
        content: Text content to scan.

    Returns:
        List of detected MAC address matches.
    """
    six_octets = re.compile(r"\b([0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b")
    found = []
    for hit in six_octets.finditer(content):
        begin, finish = hit.span()
        found.append(
            PIIMatch(type="mac_address", value=hit.group(), start=begin, end=finish)
        )
    return found
def detect_url(content: str) -> list[PIIMatch]:
    """Detect URLs in content using regex and stdlib validation.

    Detects:
    - http://example.com
    - https://example.com/path
    - www.example.com
    - example.com/path

    Args:
        content: Text content to scan.

    Returns:
        List of detected URL matches: scheme-qualified ones first, then bare ones.
    """
    matches: list[PIIMatch] = []

    # Pattern 1: URLs with scheme (http:// or https://)
    scheme_pattern = r"https?://[^\s<>\"{}|\\^`\[\]]+"
    for match in re.finditer(scheme_pattern, content):
        url = match.group()
        try:
            result = urlparse(url)
        except ValueError:
            # Invalid URL, skip. Only the error urlparse raises for a malformed URL
            # is caught, so unrelated failures are no longer silently hidden.
            continue
        if result.scheme in ("http", "https") and result.netloc:
            # Plain dict literals are valid PIIMatch values.
            matches.append(
                {"type": "url", "value": url, "start": match.start(), "end": match.end()}
            )

    # Pattern 2: URLs without scheme (www.example.com or example.com/path)
    # More conservative to avoid false positives
    bare_pattern = (
        r"\b(?:www\.)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?"
        r"(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+(?:/[^\s]*)?"
    )
    for match in re.finditer(bare_pattern, content):
        # Skip if already matched with scheme
        if any(
            m["start"] <= match.start() < m["end"] or m["start"] < match.end() <= m["end"]
            for m in matches
        ):
            continue

        url = match.group()
        # Only accept if it has a path or starts with www
        # This reduces false positives like "example.com" in prose
        if "/" not in url and not url.startswith("www."):
            continue
        try:
            # Add scheme for validation (required for urlparse to work correctly)
            result = urlparse(f"http://{url}")
        except ValueError:
            # Invalid URL, skip
            continue
        if result.netloc and "." in result.netloc:
            matches.append(
                {"type": "url", "value": url, "start": match.start(), "end": match.end()}
            )

    return matches
# Built-in detector registry: maps each supported PII type name to the module-level
# function that detects it.
_BUILTIN_DETECTORS: dict[str, Callable[[str], list[PIIMatch]]] = {
    "email": detect_email,
    "credit_card": detect_credit_card,
    "ip": detect_ip,
    "mac_address": detect_mac_address,
    "url": detect_url,
}
# ============================================================================
# Strategy Implementations
# ============================================================================
def _apply_redact_strategy(content: str, matches: list[PIIMatch]) -> str:
"""Replace PII with [REDACTED_TYPE] placeholders.
Args:
content: Original content.
matches: List of PII matches to redact.
Returns:
Content with PII redacted.
"""
if not matches:
return content
# Sort matches by start position in reverse to avoid offset issues
sorted_matches = sorted(matches, key=lambda m: m["start"], reverse=True)
result = content
for match in sorted_matches:
replacement = f"[REDACTED_{match['type'].upper()}]"
result = result[: match["start"]] + replacement + result[match["end"] :]
return result
def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
"""Partially mask PII, showing only last few characters.
Args:
content: Original content.
matches: List of PII matches to mask.
Returns:
Content with PII masked.
"""
if not matches:
return content
# Sort matches by start position in reverse
sorted_matches = sorted(matches, key=lambda m: m["start"], reverse=True)
result = content
for match in sorted_matches:
value = match["value"]
pii_type = match["type"]
# Different masking strategies by type
if pii_type == "email":
# Show only domain: user@****.com
parts = value.split("@")
if len(parts) == 2:
domain_parts = parts[1].split(".")
if len(domain_parts) >= 2:
masked = f"{parts[0]}@****.{domain_parts[-1]}"
else:
masked = f"{parts[0]}@****"
else:
masked = "****"
elif pii_type == "credit_card":
# Show last 4: ****-****-****-1234
digits_only = "".join(c for c in value if c.isdigit())
separator = "-" if "-" in value else " " if " " in value else ""
if separator:
masked = f"****{separator}****{separator}****{separator}{digits_only[-4:]}"
else:
masked = f"************{digits_only[-4:]}"
elif pii_type == "ip":
# Show last octet: *.*.*. 123
parts = value.split(".")
masked = f"*.*.*.{parts[-1]}" if len(parts) == 4 else "****"
elif pii_type == "mac_address":
# Show last byte: **:**:**:**:**:5E
separator = ":" if ":" in value else "-"
masked = (
f"**{separator}**{separator}**{separator}**{separator}**{separator}{value[-2:]}"
)
elif pii_type == "url":
# Mask everything: [MASKED_URL]
masked = "[MASKED_URL]"
else:
# Default: show last 4 chars
masked = f"****{value[-4:]}" if len(value) > 4 else "****"
result = result[: match["start"]] + masked + result[match["end"] :]
return result
def _apply_hash_strategy(content: str, matches: list[PIIMatch]) -> str:
    """Swap each detected PII value for a short, deterministic, type-tagged digest.

    Args:
        content: Text in which the PII was detected.
        matches: Detected spans; each provides ``type``, ``value``, ``start`` and ``end``.

    Returns:
        ``content`` with each span replaced by ``<type_hash:digest>``, where ``digest``
        is the first eight hex characters of the SHA-256 of the matched value.
    """
    if not matches:
        return content

    hashed_text = content
    # Replace from the last span to the first so earlier offsets are not shifted.
    for found in sorted(matches, key=lambda m: m["start"], reverse=True):
        digest = hashlib.sha256(found["value"].encode()).hexdigest()[:8]
        token = f"<{found['type']}_hash:{digest}>"
        hashed_text = hashed_text[: found["start"]] + token + hashed_text[found["end"] :]
    return hashed_text
# ============================================================================
# PIIMiddleware
# ============================================================================
class PIIMiddleware(AgentMiddleware): class PIIMiddleware(AgentMiddleware):
"""Detect and handle Personally Identifiable Information (PII) in agent conversations. """Detect and handle Personally Identifiable Information (PII) in agent conversations.
@@ -510,50 +127,34 @@ class PIIMiddleware(AgentMiddleware):
""" """
super().__init__() super().__init__()
self.pii_type = pii_type
self.strategy = strategy
self.apply_to_input = apply_to_input self.apply_to_input = apply_to_input
self.apply_to_output = apply_to_output self.apply_to_output = apply_to_output
self.apply_to_tool_results = apply_to_tool_results self.apply_to_tool_results = apply_to_tool_results
# Resolve detector self._resolved_rule: ResolvedRedactionRule = RedactionRule(
if detector is None: pii_type=pii_type,
# Use built-in detector strategy=strategy,
if pii_type not in _BUILTIN_DETECTORS: detector=detector,
msg = ( ).resolve()
f"Unknown PII type: {pii_type}. " self.pii_type = self._resolved_rule.pii_type
f"Must be one of {list(_BUILTIN_DETECTORS.keys())} " self.strategy = self._resolved_rule.strategy
"or provide a custom detector." self.detector = self._resolved_rule.detector
)
raise ValueError(msg)
self.detector = _BUILTIN_DETECTORS[pii_type]
elif isinstance(detector, str):
# Custom regex pattern
pattern = detector
def regex_detector(content: str) -> list[PIIMatch]:
return [
PIIMatch(
type=pii_type,
value=match.group(),
start=match.start(),
end=match.end(),
)
for match in re.finditer(pattern, content)
]
self.detector = regex_detector
else:
# Custom callable detector
self.detector = detector
@property @property
def name(self) -> str: def name(self) -> str:
"""Name of the middleware.""" """Name of the middleware."""
return f"{self.__class__.__name__}[{self.pii_type}]" return f"{self.__class__.__name__}[{self.pii_type}]"
def _process_content(self, content: str) -> tuple[str, list[PIIMatch]]:
"""Apply the configured redaction rule to the provided content."""
matches = self.detector(content)
if not matches:
return content, []
sanitized = apply_strategy(content, matches, self.strategy)
return sanitized, matches
@hook_config(can_jump_to=["end"]) @hook_config(can_jump_to=["end"])
def before_model( # noqa: PLR0915 def before_model(
self, self,
state: AgentState, state: AgentState,
runtime: Runtime, # noqa: ARG002 runtime: Runtime, # noqa: ARG002
@@ -594,25 +195,9 @@ class PIIMiddleware(AgentMiddleware):
if last_user_idx is not None and last_user_msg and last_user_msg.content: if last_user_idx is not None and last_user_msg and last_user_msg.content:
# Detect PII in message content # Detect PII in message content
content = str(last_user_msg.content) content = str(last_user_msg.content)
matches = self.detector(content) new_content, matches = self._process_content(content)
if matches: if matches:
# Apply strategy
if self.strategy == "block":
raise PIIDetectionError(self.pii_type, matches)
if self.strategy == "redact":
new_content = _apply_redact_strategy(content, matches)
elif self.strategy == "mask":
new_content = _apply_mask_strategy(content, matches)
elif self.strategy == "hash":
new_content = _apply_hash_strategy(content, matches)
else:
# Should not reach here due to type hints
msg = f"Unknown strategy: {self.strategy}"
raise ValueError(msg)
# Create updated message
updated_message: AnyMessage = HumanMessage( updated_message: AnyMessage = HumanMessage(
content=new_content, content=new_content,
id=last_user_msg.id, id=last_user_msg.id,
@@ -641,26 +226,11 @@ class PIIMiddleware(AgentMiddleware):
continue continue
content = str(tool_msg.content) content = str(tool_msg.content)
matches = self.detector(content) new_content, matches = self._process_content(content)
if not matches: if not matches:
continue continue
# Apply strategy
if self.strategy == "block":
raise PIIDetectionError(self.pii_type, matches)
if self.strategy == "redact":
new_content = _apply_redact_strategy(content, matches)
elif self.strategy == "mask":
new_content = _apply_mask_strategy(content, matches)
elif self.strategy == "hash":
new_content = _apply_hash_strategy(content, matches)
else:
# Should not reach here due to type hints
msg = f"Unknown strategy: {self.strategy}"
raise ValueError(msg)
# Create updated tool message # Create updated tool message
updated_message = ToolMessage( updated_message = ToolMessage(
content=new_content, content=new_content,
@@ -716,26 +286,11 @@ class PIIMiddleware(AgentMiddleware):
# Detect PII in message content # Detect PII in message content
content = str(last_ai_msg.content) content = str(last_ai_msg.content)
matches = self.detector(content) new_content, matches = self._process_content(content)
if not matches: if not matches:
return None return None
# Apply strategy
if self.strategy == "block":
raise PIIDetectionError(self.pii_type, matches)
if self.strategy == "redact":
new_content = _apply_redact_strategy(content, matches)
elif self.strategy == "mask":
new_content = _apply_mask_strategy(content, matches)
elif self.strategy == "hash":
new_content = _apply_hash_strategy(content, matches)
else:
# Should not reach here due to type hints
msg = f"Unknown strategy: {self.strategy}"
raise ValueError(msg)
# Create updated message # Create updated message
updated_message = AIMessage( updated_message = AIMessage(
content=new_content, content=new_content,
@@ -749,3 +304,14 @@ class PIIMiddleware(AgentMiddleware):
new_messages[last_ai_idx] = updated_message new_messages[last_ai_idx] = updated_message
return {"messages": new_messages} return {"messages": new_messages}
# Names exported by this module on ``from ... import *``: the middleware, its error
# type, and the ``detect_*`` detector functions.
__all__ = [
    "PIIDetectionError",
    "PIIMiddleware",
    "detect_credit_card",
    "detect_email",
    "detect_ip",
    "detect_mac_address",
    "detect_url",
]

View File

@@ -0,0 +1,715 @@
"""Middleware that exposes a persistent shell tool to agents."""
from __future__ import annotations
import contextlib
import logging
import os
import queue
import signal
import subprocess
import tempfile
import threading
import time
import typing
import uuid
import weakref
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import BaseTool, ToolException
from langgraph.channels.untracked_value import UntrackedValue
from pydantic import BaseModel, model_validator
from typing_extensions import NotRequired
from langchain.agents.middleware._execution import (
SHELL_TEMP_PREFIX,
BaseExecutionPolicy,
CodexSandboxExecutionPolicy,
DockerExecutionPolicy,
HostExecutionPolicy,
)
from langchain.agents.middleware._redaction import (
PIIDetectionError,
PIIMatch,
RedactionRule,
ResolvedRedactionRule,
)
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from langgraph.types import Command
from langchain.tools.tool_node import ToolCallRequest
# Module-level logger shared by the session and the middleware.
LOGGER = logging.getLogger(__name__)
# Prefix of the per-command sentinel. `ShellSession.execute` appends a random hex
# suffix and has the shell print the result after each command, followed by `$?`.
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"

# Description attached to the registered shell tool unless the caller overrides it.
DEFAULT_TOOL_DESCRIPTION = (
    "Execute a shell command inside a persistent session. Before running a command, "
    "confirm the working directory is correct (e.g., inspect with `ls` or `pwd`) and ensure "
    "any parent directories exist. Prefer absolute paths and quote paths containing spaces, "
    'such as `cd "/path/with spaces"`. Chain multiple commands with `&&` or `;` instead of '
    "embedding newlines. Avoid unnecessary `cd` usage unless explicitly required so the "
    "session remains stable. Outputs may be truncated when they become very large, and long "
    "running commands will be terminated once their configured timeout elapses."
)
def _cleanup_resources(
    session: ShellSession, tempdir: tempfile.TemporaryDirectory[str] | None, timeout: float
) -> None:
    """Tear down a shell session and its temporary workspace on a best-effort basis.

    Exceptions raised while stopping the session or removing the directory are
    suppressed, and the directory is removed even if stopping the session failed.

    Args:
        session: Session to stop.
        tempdir: Temporary directory to remove, or ``None`` when there is none.
        timeout: Value forwarded to ``session.stop``.
    """
    with contextlib.suppress(Exception):
        session.stop(timeout)

    if tempdir is None:
        return

    with contextlib.suppress(Exception):
        tempdir.cleanup()
@dataclass
class _SessionResources:
    """Container for per-run shell resources.

    Constructing an instance registers a ``weakref.finalize`` callback that invokes
    ``_cleanup_resources`` with the session, the temporary directory and the policy's
    termination timeout.
    """

    # Shell session owned by this run.
    session: ShellSession
    # Temporary workspace directory to clean up, or None when there is none.
    tempdir: tempfile.TemporaryDirectory[str] | None
    # Execution policy; its `termination_timeout` is handed to the cleanup callback.
    policy: BaseExecutionPolicy
    # Finalizer created in `__post_init__`; not an init argument and hidden from repr.
    _finalizer: weakref.finalize = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Register the cleanup finalizer for this resource bundle."""
        # The callback gets the individual resources as arguments, not `self`, so the
        # finalizer itself does not need the container to stay alive.
        self._finalizer = weakref.finalize(
            self,
            _cleanup_resources,
            self.session,
            self.tempdir,
            self.policy.termination_timeout,
        )
class ShellToolState(AgentState):
    """Agent state extension for tracking shell session resources."""

    # Resources of the active shell session; the key may be absent or None.
    # NOTE(review): `UntrackedValue` / `PrivateStateAttr` presumably keep this entry out
    # of checkpoints and out of the public state schema - confirm against their docs.
    shell_session_resources: NotRequired[
        Annotated[_SessionResources | None, UntrackedValue, PrivateStateAttr]
    ]
@dataclass(frozen=True)
class CommandExecutionResult:
    """Structured result from command execution."""

    # Collected output text; stderr lines carry a "[stderr] " prefix. Empty on timeout.
    output: str
    # Exit status parsed from the completion marker line, or None if unavailable.
    exit_code: int | None
    # True when the completion marker was not seen before the deadline.
    timed_out: bool
    # True when lines were dropped because the policy's line limit was exceeded.
    truncated_by_lines: bool
    # True when lines were dropped because the policy's byte limit was exceeded.
    truncated_by_bytes: bool
    # Number of output lines observed, including any that were dropped.
    total_lines: int
    # UTF-8 size of all observed output lines, including any that were dropped.
    total_bytes: int
class ShellSession:
    """Persistent shell session that supports sequential command execution.

    One long-lived shell process is spawned through the execution policy. Two daemon
    threads forward its stdout and stderr, line by line, into a single queue. Each
    command is followed by a ``printf`` of a unique marker plus ``$?`` so the reader
    can tell where the command's output ends and what its exit status was.
    """

    def __init__(
        self,
        workspace: Path,
        policy: BaseExecutionPolicy,
        command: tuple[str, ...],
        environment: Mapping[str, str],
    ) -> None:
        """Store the session configuration; no process is started yet.

        Args:
            workspace: Directory passed to the policy as the session workspace.
            policy: Execution policy that spawns the process and supplies the output
                limits and termination timeout.
            command: Argument vector used to launch the shell.
            environment: Environment variables for the shell; copied into a new dict.
        """
        self._workspace = workspace
        self._policy = policy
        self._command = command
        self._environment = dict(environment)
        self._process: subprocess.Popen[str] | None = None
        self._stdin: Any = None
        self._queue: queue.Queue[tuple[str, str | None]] = queue.Queue()
        self._lock = threading.Lock()
        self._stdout_thread: threading.Thread | None = None
        self._stderr_thread: threading.Thread | None = None
        self._terminated = False

    def start(self) -> None:
        """Start the shell subprocess and reader threads.

        Does nothing when a process is already running.

        Raises:
            RuntimeError: If the spawned process lacks a stdin, stdout or stderr pipe.
        """
        if self._process and self._process.poll() is None:
            return

        self._process = self._policy.spawn(
            workspace=self._workspace,
            env=self._environment,
            command=self._command,
        )
        if (
            self._process.stdin is None
            or self._process.stdout is None
            or self._process.stderr is None
        ):
            msg = "Failed to initialize shell session pipes."
            raise RuntimeError(msg)

        self._stdin = self._process.stdin
        self._terminated = False
        # Fresh queue so output of a previous process is not attributed to this one.
        self._queue = queue.Queue()

        self._stdout_thread = threading.Thread(
            target=self._enqueue_stream,
            args=(self._process.stdout, "stdout"),
            daemon=True,
        )
        self._stderr_thread = threading.Thread(
            target=self._enqueue_stream,
            args=(self._process.stderr, "stderr"),
            daemon=True,
        )
        self._stdout_thread.start()
        self._stderr_thread.start()

    def restart(self) -> None:
        """Restart the shell process."""
        self.stop(self._policy.termination_timeout)
        self.start()

    def stop(self, timeout: float) -> None:
        """Stop the shell subprocess.

        Asks the shell to ``exit``, waits up to ``timeout`` seconds, and kills the
        process if it has not terminated by then. Does nothing when no process exists.

        Args:
            timeout: Seconds to wait for the shell to exit before killing it.
        """
        if not self._process:
            return

        if self._process.poll() is None and not self._terminated:
            try:
                self._stdin.write("exit\n")
                self._stdin.flush()
            except (BrokenPipeError, OSError):
                LOGGER.debug(
                    "Failed to write exit command; terminating shell session.",
                    exc_info=True,
                )

        try:
            if self._process.wait(timeout=timeout) is None:
                self._kill_process()
        except subprocess.TimeoutExpired:
            self._kill_process()
        finally:
            self._terminated = True
            with contextlib.suppress(Exception):
                self._stdin.close()
            self._process = None

    def execute(self, command: str, *, timeout: float) -> CommandExecutionResult:
        """Execute a command in the persistent shell.

        Args:
            command: Shell text to run; a trailing newline is added when missing.
            timeout: Seconds allowed for the command, measured from this call.

        Returns:
            The collected output and status. On timeout the session is restarted and
            the result has ``timed_out=True`` and empty output.

        Raises:
            RuntimeError: If the shell process is not running.
        """
        if not self._process or self._process.poll() is not None:
            msg = "Shell session is not running."
            raise RuntimeError(msg)

        marker = f"{_DONE_MARKER_PREFIX}{uuid.uuid4().hex}"
        deadline = time.monotonic() + timeout

        # Hold the lock until the marker is seen (or the deadline passes) so that a
        # concurrent call cannot drain or interleave with this command's output.
        with self._lock:
            self._drain_queue()
            payload = command if command.endswith("\n") else f"{command}\n"
            self._stdin.write(payload)
            # Print the sentinel followed by the exit status of the command above.
            self._stdin.write(f"printf '{marker} %s\\n' $?\n")
            self._stdin.flush()

            return self._collect_output(marker, deadline, timeout)

    def _collect_output(
        self,
        marker: str,
        deadline: float,
        timeout: float,
    ) -> CommandExecutionResult:
        """Read queued output until the completion marker or the deadline.

        Args:
            marker: Sentinel text the shell prints once the command has finished.
            deadline: ``time.monotonic()`` value after which the command is timed out.
            timeout: Original timeout in seconds; used only for the log message.

        Returns:
            The command result. On timeout the session is restarted and the result
            carries empty output with ``timed_out=True``.
        """
        collected: list[str] = []
        total_lines = 0
        total_bytes = 0
        truncated_by_lines = False
        truncated_by_bytes = False
        exit_code: int | None = None
        timed_out = False

        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                timed_out = True
                break

            try:
                source, data = self._queue.get(timeout=remaining)
            except queue.Empty:
                timed_out = True
                break

            # End-of-stream sentinel from a reader thread; keep waiting for the marker.
            if data is None:
                continue

            finished = False
            if source == "stdout":
                # Output without a trailing newline (e.g. `printf foo`) puts the
                # marker on the same line as that output, so look for it anywhere in
                # the line, not only at the start.
                marker_at = data.find(marker)
                if marker_at != -1:
                    finished = True
                    exit_code = self._safe_int(data[marker_at + len(marker) :].strip())
                    # Whatever precedes the marker is genuine command output.
                    data = data[:marker_at]

            if data:
                total_lines += 1
                total_bytes += len(data.encode("utf-8", "replace"))
                if total_lines > self._policy.max_output_lines:
                    truncated_by_lines = True
                elif (
                    self._policy.max_output_bytes is not None
                    and total_bytes > self._policy.max_output_bytes
                ):
                    truncated_by_bytes = True
                elif source == "stderr":
                    stripped = data.rstrip("\n")
                    collected.append(f"[stderr] {stripped}")
                    if data.endswith("\n"):
                        collected.append("\n")
                else:
                    collected.append(data)

            if finished:
                break

        if timed_out:
            LOGGER.warning(
                "Command timed out after %.2f seconds; restarting shell session.",
                timeout,
            )
            self.restart()
            return CommandExecutionResult(
                output="",
                exit_code=None,
                timed_out=True,
                truncated_by_lines=truncated_by_lines,
                truncated_by_bytes=truncated_by_bytes,
                total_lines=total_lines,
                total_bytes=total_bytes,
            )

        output = "".join(collected)
        return CommandExecutionResult(
            output=output,
            exit_code=exit_code,
            timed_out=False,
            truncated_by_lines=truncated_by_lines,
            truncated_by_bytes=truncated_by_bytes,
            total_lines=total_lines,
            total_bytes=total_bytes,
        )

    def _kill_process(self) -> None:
        """Forcefully terminate the shell and, when safe, its whole process group."""
        process = self._process
        if not process:
            return

        if hasattr(os, "killpg"):
            with contextlib.suppress(ProcessLookupError):
                pgid = os.getpgid(process.pid)
                if pgid == os.getpgrp():
                    # The shell was not started in its own session/group, so it
                    # shares ours; signalling the group would kill this process too.
                    process.kill()
                else:
                    os.killpg(pgid, signal.SIGKILL)
        else:  # pragma: no cover
            with contextlib.suppress(ProcessLookupError):
                process.kill()

    def _enqueue_stream(self, stream: Any, label: str) -> None:
        """Forward every line of ``stream`` to the queue, then a ``None`` sentinel."""
        for line in iter(stream.readline, ""):
            self._queue.put((label, line))
        self._queue.put((label, None))

    def _drain_queue(self) -> None:
        """Discard everything currently waiting in the output queue."""
        while True:
            try:
                self._queue.get_nowait()
            except queue.Empty:
                break

    @staticmethod
    def _safe_int(value: str) -> int | None:
        """Parse ``value`` as an int, returning ``None`` when it is not a number."""
        with contextlib.suppress(ValueError):
            return int(value)
        return None
class _ShellToolInput(BaseModel):
    """Arguments accepted by the persistent shell tool: a command or a restart request."""

    command: str | None = None
    restart: bool | None = None

    @model_validator(mode="after")
    def validate_payload(self) -> _ShellToolInput:
        """Require exactly one of ``command`` and a truthy ``restart``.

        Raises:
            ValueError: If neither or both are supplied.
        """
        has_command = self.command is not None
        wants_restart = bool(self.restart)

        if has_command and wants_restart:
            msg = "Specify only one of 'command' or 'restart'."
            raise ValueError(msg)
        if not has_command and not wants_restart:
            msg = "Shell tool requires either 'command' or 'restart'."
            raise ValueError(msg)
        return self
class _PersistentShellTool(BaseTool):
    """Tool wrapper that relies on middleware interception for execution.

    The tool only carries the name, description and argument schema. Calling `_run`
    directly raises, because `ShellToolMiddleware.wrap_tool_call` is expected to
    handle calls to this tool.
    """

    name: str = "shell"
    description: str = DEFAULT_TOOL_DESCRIPTION
    args_schema: type[BaseModel] = _ShellToolInput

    def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
        """Create the tool.

        Args:
            middleware: Middleware instance that owns this tool.
            description: Optional replacement for the default tool description.
        """
        super().__init__()
        self._middleware = middleware
        if description is not None:
            self.description = description

    def _run(self, **_: Any) -> Any:  # pragma: no cover - executed via middleware wrapper
        """Always raise: execution is supposed to be intercepted by the middleware.

        Raises:
            RuntimeError: Whenever this method is called.
        """
        msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
        raise RuntimeError(msg)
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
    """Middleware that registers a persistent shell tool for agents.

    The middleware exposes a single long-lived shell session. Use the execution policy to
    match your deployment's security posture:

    * ``HostExecutionPolicy`` - full host access; best for trusted environments where the
      agent already runs inside a container or VM that provides isolation.
    * ``CodexSandboxExecutionPolicy`` - reuses the Codex CLI sandbox for additional
      syscall/filesystem restrictions when the CLI is available.
    * ``DockerExecutionPolicy`` - launches a separate Docker container for each agent run,
      providing harder isolation, optional read-only root filesystems, and user remapping.

    When no policy is provided the middleware defaults to ``HostExecutionPolicy``.
    """

    def __init__(
        self,
        workspace_root: str | Path | None = None,
        *,
        startup_commands: tuple[str, ...] | list[str] | str | None = None,
        shutdown_commands: tuple[str, ...] | list[str] | str | None = None,
        execution_policy: BaseExecutionPolicy | None = None,
        redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
        tool_description: str | None = None,
        shell_command: Sequence[str] | str | None = None,
        env: Mapping[str, Any] | None = None,
    ) -> None:
        """Initialize the middleware.

        Args:
            workspace_root: Base directory for the shell session. If omitted, a temporary
                directory is created when the agent starts and removed when it ends.
            startup_commands: Optional commands executed sequentially after the session starts.
            shutdown_commands: Optional commands executed before the session shuts down.
            execution_policy: Execution policy controlling timeouts, output limits, and resource
                configuration. Defaults to :class:`HostExecutionPolicy` for native execution.
            redaction_rules: Optional redaction rules to sanitize command output before
                returning it to the model.
            tool_description: Optional override for the registered shell tool description.
            shell_command: Optional shell executable (string) or argument sequence used to
                launch the persistent session. Defaults to an implementation-defined bash command.
            env: Optional environment variables to supply to the shell session. Values are
                coerced to strings before command execution. If omitted, the session inherits the
                parent process environment.

        Raises:
            ValueError: If ``shell_command`` is an empty sequence.
            TypeError: If ``env`` contains a non-string key.
        """
        super().__init__()
        self._workspace_root = Path(workspace_root) if workspace_root else None
        self._shell_command = self._normalize_shell_command(shell_command)
        self._environment = self._normalize_env(env)
        if execution_policy is not None:
            self._execution_policy = execution_policy
        else:
            self._execution_policy = HostExecutionPolicy()
        rules = redaction_rules or ()
        self._redaction_rules: tuple[ResolvedRedactionRule, ...] = tuple(
            rule.resolve() for rule in rules
        )
        self._startup_commands = self._normalize_commands(startup_commands)
        self._shutdown_commands = self._normalize_commands(shutdown_commands)

        description = tool_description or DEFAULT_TOOL_DESCRIPTION
        self._tool = _PersistentShellTool(self, description=description)
        self.tools = [self._tool]

    @staticmethod
    def _normalize_commands(
        commands: tuple[str, ...] | list[str] | str | None,
    ) -> tuple[str, ...]:
        """Coerce ``None``, a single string or a sequence of commands into a tuple."""
        if commands is None:
            return ()
        if isinstance(commands, str):
            return (commands,)
        return tuple(commands)

    @staticmethod
    def _normalize_shell_command(
        shell_command: Sequence[str] | str | None,
    ) -> tuple[str, ...]:
        """Return the shell argument vector, defaulting to ``/bin/bash``.

        Raises:
            ValueError: If an empty sequence is supplied.
        """
        if shell_command is None:
            return ("/bin/bash",)
        normalized = (shell_command,) if isinstance(shell_command, str) else tuple(shell_command)
        if not normalized:
            msg = "Shell command must contain at least one argument."
            raise ValueError(msg)
        return normalized

    @staticmethod
    def _normalize_env(env: Mapping[str, Any] | None) -> dict[str, str] | None:
        """Copy ``env`` into a dict with stringified values; ``None`` stays ``None``.

        Raises:
            TypeError: If a key is not a string.
        """
        if env is None:
            return None
        normalized: dict[str, str] = {}
        for key, value in env.items():
            if not isinstance(key, str):
                msg = "Environment variable names must be strings."
                raise TypeError(msg)
            normalized[key] = str(value)
        return normalized

    def before_agent(self, _state: ShellToolState, _runtime: Any) -> dict[str, Any] | None:
        """Start the shell session and run startup commands."""
        resources = self._create_resources()
        return {"shell_session_resources": resources}

    async def abefore_agent(self, state: ShellToolState, _runtime: Any) -> dict[str, Any] | None:
        """Async counterpart to `before_agent`."""
        return self.before_agent(state, _runtime)

    def after_agent(self, state: ShellToolState, _runtime: Any) -> None:
        """Run shutdown commands and release resources when an agent completes."""
        resources = self._ensure_resources(state)
        try:
            self._run_shutdown_commands(resources.session)
        finally:
            # Run the registered finalizer now, even if a shutdown command raised.
            resources._finalizer()

    async def aafter_agent(self, state: ShellToolState, _runtime: Any) -> None:
        """Async counterpart to `after_agent`."""
        return self.after_agent(state, _runtime)

    def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
        """Fetch the session resources from ``state``.

        Raises:
            ToolException: If the entry is missing, ``None`` or of an unexpected type.
        """
        resources = state.get("shell_session_resources")
        if resources is not None and not isinstance(resources, _SessionResources):
            resources = None
        if resources is None:
            msg = (
                "Shell session resources are unavailable. Ensure `before_agent` ran successfully "
                "before invoking the shell tool."
            )
            raise ToolException(msg)
        return resources

    def _create_resources(self) -> _SessionResources:
        """Create the workspace, start a session and run the startup commands.

        Returns:
            The resource bundle for the new session.

        Raises:
            BaseException: Whatever starting the session or a startup command raised,
                re-raised after the session and temporary directory are cleaned up.
        """
        workspace = self._workspace_root
        tempdir: tempfile.TemporaryDirectory[str] | None = None
        if workspace is None:
            tempdir = tempfile.TemporaryDirectory(prefix=SHELL_TEMP_PREFIX)
            workspace_path = Path(tempdir.name)
        else:
            workspace_path = workspace
            workspace_path.mkdir(parents=True, exist_ok=True)

        # NOTE(review): an empty mapping is passed when no `env` was configured; whether
        # the shell then inherits the parent environment depends on the policy's
        # `spawn` - confirm this matches the `env` docstring.
        session = ShellSession(
            workspace_path,
            self._execution_policy,
            self._shell_command,
            self._environment or {},
        )
        try:
            session.start()
            LOGGER.info("Started shell session in %s", workspace_path)
            self._run_startup_commands(session)
        except BaseException:
            LOGGER.exception("Starting shell session failed; cleaning up resources.")
            # Best-effort cleanup: a failure while stopping the session must neither
            # hide the original error nor leave the temporary directory behind.
            _cleanup_resources(session, tempdir, self._execution_policy.termination_timeout)
            raise

        return _SessionResources(session=session, tempdir=tempdir, policy=self._execution_policy)

    def _run_startup_commands(self, session: ShellSession) -> None:
        """Run the configured startup commands in order.

        Raises:
            RuntimeError: If a command times out or exits with a non-zero status.
        """
        if not self._startup_commands:
            return
        for command in self._startup_commands:
            result = session.execute(command, timeout=self._execution_policy.startup_timeout)
            if result.timed_out or (result.exit_code not in (0, None)):
                msg = f"Startup command '{command}' failed with exit code {result.exit_code}"
                raise RuntimeError(msg)

    def _run_shutdown_commands(self, session: ShellSession) -> None:
        """Run the configured shutdown commands; failures are logged, never raised."""
        if not self._shutdown_commands:
            return
        for command in self._shutdown_commands:
            try:
                result = session.execute(command, timeout=self._execution_policy.command_timeout)
                if result.timed_out:
                    LOGGER.warning("Shutdown command '%s' timed out.", command)
                elif result.exit_code not in (0, None):
                    LOGGER.warning(
                        "Shutdown command '%s' exited with %s.", command, result.exit_code
                    )
            except (RuntimeError, ToolException, OSError) as exc:
                LOGGER.warning(
                    "Failed to run shutdown command '%s': %s", command, exc, exc_info=True
                )

    def _apply_redactions(self, content: str) -> tuple[str, dict[str, list[PIIMatch]]]:
        """Apply configured redaction rules to command output.

        Rules run in order, each on the output of the previous one.

        Returns:
            The sanitized text and the matches grouped by PII type.
        """
        matches_by_type: dict[str, list[PIIMatch]] = {}
        updated = content
        for rule in self._redaction_rules:
            updated, matches = rule.apply(updated)
            if matches:
                matches_by_type.setdefault(rule.pii_type, []).extend(matches)
        return updated, matches_by_type

    def _run_shell_tool(
        self,
        resources: _SessionResources,
        payload: dict[str, Any],
        *,
        tool_call_id: str | None,
    ) -> Any:
        """Handle one shell tool call: restart the session or run a command.

        Args:
            resources: Resources of the active session.
            payload: Tool arguments; either a truthy ``restart`` or a ``command`` string.
            tool_call_id: Id of the originating tool call, if any.

        Returns:
            A ``ToolMessage`` when ``tool_call_id`` is set, otherwise the message text.

        Raises:
            ToolException: If the restart fails or no usable ``command`` was given.
        """
        session = resources.session

        if payload.get("restart"):
            LOGGER.info("Restarting shell session on request.")
            try:
                session.restart()
                self._run_startup_commands(session)
            except Exception as err:
                # Only ordinary errors become tool errors; KeyboardInterrupt and
                # SystemExit must propagate instead of being reported to the model.
                LOGGER.exception("Restarting shell session failed; session remains unavailable.")
                msg = "Failed to restart shell session."
                raise ToolException(msg) from err
            message = "Shell session restarted."
            return self._format_tool_message(message, tool_call_id, status="success")

        command = payload.get("command")
        if not command or not isinstance(command, str):
            msg = "Shell tool expects a 'command' string when restart is not requested."
            raise ToolException(msg)

        LOGGER.info("Executing shell command: %s", command)
        result = session.execute(command, timeout=self._execution_policy.command_timeout)

        if result.timed_out:
            timeout_seconds = self._execution_policy.command_timeout
            message = f"Error: Command timed out after {timeout_seconds:.1f} seconds."
            return self._format_tool_message(
                message,
                tool_call_id,
                status="error",
                artifact={
                    "timed_out": True,
                    "exit_code": None,
                },
            )

        try:
            sanitized_output, matches = self._apply_redactions(result.output)
        except PIIDetectionError as error:
            # A rule rejected the output; report the block instead of the content.
            LOGGER.warning("Blocking command output due to detected %s.", error.pii_type)
            message = f"Output blocked: detected {error.pii_type}."
            return self._format_tool_message(
                message,
                tool_call_id,
                status="error",
                artifact={
                    "timed_out": False,
                    "exit_code": result.exit_code,
                    "matches": {error.pii_type: error.matches},
                },
            )

        sanitized_output = sanitized_output or "<no output>"
        if result.truncated_by_lines:
            sanitized_output = (
                f"{sanitized_output.rstrip()}\n\n"
                f"... Output truncated at {self._execution_policy.max_output_lines} lines "
                f"(observed {result.total_lines})."
            )
        if result.truncated_by_bytes and self._execution_policy.max_output_bytes is not None:
            sanitized_output = (
                f"{sanitized_output.rstrip()}\n\n"
                f"... Output truncated at {self._execution_policy.max_output_bytes} bytes "
                f"(observed {result.total_bytes})."
            )

        if result.exit_code not in (0, None):
            sanitized_output = f"{sanitized_output.rstrip()}\n\nExit code: {result.exit_code}"
            final_status: Literal["success", "error"] = "error"
        else:
            final_status = "success"

        artifact = {
            "timed_out": False,
            "exit_code": result.exit_code,
            "truncated_by_lines": result.truncated_by_lines,
            "truncated_by_bytes": result.truncated_by_bytes,
            "total_lines": result.total_lines,
            "total_bytes": result.total_bytes,
            "redaction_matches": matches,
        }

        return self._format_tool_message(
            sanitized_output,
            tool_call_id,
            status=final_status,
            artifact=artifact,
        )

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Intercept local shell tool calls and execute them via the managed session."""
        if isinstance(request.tool, _PersistentShellTool):
            resources = self._ensure_resources(request.state)
            return self._run_shell_tool(
                resources,
                request.tool_call["args"],
                tool_call_id=request.tool_call.get("id"),
            )
        return handler(request)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        """Async interception mirroring the synchronous tool handler."""
        if isinstance(request.tool, _PersistentShellTool):
            resources = self._ensure_resources(request.state)
            return self._run_shell_tool(
                resources,
                request.tool_call["args"],
                tool_call_id=request.tool_call.get("id"),
            )
        return await handler(request)

    def _format_tool_message(
        self,
        content: str,
        tool_call_id: str | None,
        *,
        status: Literal["success", "error"],
        artifact: dict[str, Any] | None = None,
    ) -> ToolMessage | str:
        """Wrap ``content`` in a ``ToolMessage``, or return it as-is without a call id."""
        artifact = artifact or {}
        if tool_call_id is None:
            return content
        return ToolMessage(
            content=content,
            tool_call_id=tool_call_id,
            name=self._tool.name,
            status=status,
            artifact=artifact,
        )
# Names exported on ``from ... import *``: the middleware plus the execution policies
# and `RedactionRule`, which are imported from sibling modules above.
__all__ = [
    "CodexSandboxExecutionPolicy",
    "DockerExecutionPolicy",
    "HostExecutionPolicy",
    "RedactionRule",
    "ShellToolMiddleware",
]

View File

@@ -0,0 +1,404 @@
from __future__ import annotations
import os
import shutil
from pathlib import Path
from typing import Any
import pytest
from langchain.agents.middleware.shell_tool import (
HostExecutionPolicy,
CodexSandboxExecutionPolicy,
DockerExecutionPolicy,
)
from langchain.agents.middleware import _execution
def _make_resource(
    *,
    with_prlimit: bool,
    has_rlimit_as: bool = True,
) -> Any:
    """Build a stand-in for the ``resource`` module that records limit calls.

    Args:
        with_prlimit: Whether the fake exposes a ``prlimit`` method.
        has_rlimit_as: Whether the fake defines ``RLIMIT_AS``.

    Returns:
        An object with ``RLIMIT_*`` constants, ``setrlimit`` (and optionally
        ``prlimit``), and ``setrlimit_calls`` / ``prlimit_calls`` lists of the
        arguments each call received.
    """

    class _Resource:
        RLIMIT_CPU = 0
        RLIMIT_DATA = 2

        def __init__(self) -> None:
            self.prlimit_calls: list[tuple[int, int, tuple[int, int]]] = []
            self.setrlimit_calls: list[tuple[int, tuple[int, int]]] = []

        def setrlimit(self, resource_name: int, limits: tuple[int, int]) -> None:
            self.setrlimit_calls.append((resource_name, limits))

    # The class is created per call, so the optional members never leak between fakes.
    if has_rlimit_as:
        _Resource.RLIMIT_AS = 1

    if with_prlimit:

        def _prlimit(self: Any, pid: int, resource_name: int, limits: tuple[int, int]) -> None:
            self.prlimit_calls.append((pid, resource_name, limits))

        _Resource.prlimit = _prlimit

    return _Resource()
def test_host_policy_validations() -> None:
    """HostExecutionPolicy rejects a zero line limit, zero CPU time and negative memory."""
    with pytest.raises(ValueError):
        HostExecutionPolicy(max_output_lines=0)

    with pytest.raises(ValueError):
        HostExecutionPolicy(cpu_time_seconds=0)

    with pytest.raises(ValueError):
        HostExecutionPolicy(memory_bytes=-1)
def test_host_policy_requires_resource_for_limits(monkeypatch: pytest.MonkeyPatch) -> None:
    """Requesting a CPU limit raises RuntimeError when `_execution.resource` is None."""
    monkeypatch.setattr(_execution, "resource", None, raising=False)
    with pytest.raises(RuntimeError):
        HostExecutionPolicy(cpu_time_seconds=1)
def test_host_policy_applies_prlimit(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """On linux with `prlimit` available, limits are applied to the child via prlimit.

    Expects no `preexec_fn`, a new session, one prlimit call each for CPU and address
    space on the spawned pid, and no `setrlimit` calls.
    """
    fake_resource = _make_resource(with_prlimit=True)
    monkeypatch.setattr(_execution, "resource", fake_resource, raising=False)
    monkeypatch.setattr(_execution.sys, "platform", "linux")

    # Arguments the policy passes to the (stubbed) subprocess launcher.
    recorded: dict[str, Any] = {}

    class DummyProcess:
        pid = 1234

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        recorded["command"] = list(command)
        recorded["env"] = dict(env)
        recorded["cwd"] = cwd
        recorded["preexec_fn"] = preexec_fn
        recorded["start_new_session"] = start_new_session
        return DummyProcess()

    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = HostExecutionPolicy(cpu_time_seconds=2, memory_bytes=4096)
    env = {"PATH": os.environ.get("PATH", ""), "VAR": "1"}
    process = policy.spawn(workspace=tmp_path, env=env, command=("/bin/sh",))

    assert process is not None
    assert recorded["preexec_fn"] is None
    assert recorded["start_new_session"] is True
    assert fake_resource.prlimit_calls == [
        (1234, fake_resource.RLIMIT_CPU, (2, 2)),
        (1234, fake_resource.RLIMIT_AS, (4096, 4096)),
    ]
    assert fake_resource.setrlimit_calls == []
def test_host_policy_uses_preexec_on_macos(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """On darwin without `prlimit`, limits are applied by a `preexec_fn` via setrlimit."""
    fake_resource = _make_resource(with_prlimit=False)
    monkeypatch.setattr(_execution, "resource", fake_resource, raising=False)
    monkeypatch.setattr(_execution.sys, "platform", "darwin")

    # Launcher arguments captured for inspection after spawn.
    captured: dict[str, Any] = {}

    class DummyProcess:
        pid = 4321

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        captured["preexec_fn"] = preexec_fn
        captured["start_new_session"] = start_new_session
        return DummyProcess()

    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = HostExecutionPolicy(cpu_time_seconds=5, memory_bytes=8192)
    env = {"PATH": os.environ.get("PATH", "")}
    policy.spawn(workspace=tmp_path, env=env, command=("/bin/sh",))

    preexec_fn = captured["preexec_fn"]
    assert callable(preexec_fn)
    assert captured["start_new_session"] is True

    # Invoke the hook by hand, since the stubbed launcher never runs it.
    preexec_fn()
    # macOS fallback should use setrlimit
    assert fake_resource.setrlimit_calls == [
        (fake_resource.RLIMIT_CPU, (5, 5)),
        (fake_resource.RLIMIT_AS, (8192, 8192)),
    ]
def test_host_policy_respects_process_group_flag(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """`create_process_group=False` is forwarded as `start_new_session=False`."""
    fake_resource = _make_resource(with_prlimit=True)
    monkeypatch.setattr(_execution, "resource", fake_resource, raising=False)
    monkeypatch.setattr(_execution.sys, "platform", "linux")

    recorded: dict[str, Any] = {}

    class DummyProcess:
        pid = 1111

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        recorded["start_new_session"] = start_new_session
        return DummyProcess()

    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = HostExecutionPolicy(create_process_group=False)
    env = {"PATH": os.environ.get("PATH", "")}
    policy.spawn(workspace=tmp_path, env=env, command=("/bin/sh",))

    assert recorded["start_new_session"] is False
def test_host_policy_falls_back_to_rlimit_data(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """Without ``RLIMIT_AS`` the memory limit is recorded against ``RLIMIT_DATA``."""
    resource_stub = _make_resource(with_prlimit=True, has_rlimit_as=False)

    class DummyProcess:
        pid = 2222

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        return DummyProcess()

    monkeypatch.setattr(_execution, "resource", resource_stub, raising=False)
    monkeypatch.setattr(_execution.sys, "platform", "linux")
    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = HostExecutionPolicy(cpu_time_seconds=7, memory_bytes=2048)
    policy.spawn(
        workspace=tmp_path,
        env={"PATH": os.environ.get("PATH", "")},
        command=("/bin/sh",),
    )

    # Both limits target the pid reported by the dummy process.
    expected_calls = [
        (2222, resource_stub.RLIMIT_CPU, (7, 7)),
        (2222, resource_stub.RLIMIT_DATA, (2048, 2048)),
    ]
    assert resource_stub.prlimit_calls == expected_calls
def test_codex_policy_spawns_codex_cli(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """``spawn`` builds the ``codex sandbox`` command line and forwards env/cwd.

    The process launcher is replaced by a fake, so the codex binary is never
    executed. Binary discovery is patched as well (the same way the other codex
    tests in this module do it) so the test runs on machines where the CLI is
    not installed instead of being skipped.
    """
    codex_path = "/usr/bin/codex"
    monkeypatch.setattr(_execution.shutil, "which", lambda _: codex_path)

    recorded: dict[str, list[str]] = {}

    class DummyProcess:
        pass

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        recorded["command"] = list(command)
        assert cwd == tmp_path
        assert env["TEST_VAR"] == "1"
        assert preexec_fn is None
        assert not start_new_session
        return DummyProcess()

    monkeypatch.setattr(
        "langchain.agents.middleware._execution._launch_subprocess",
        fake_launch,
    )

    policy = CodexSandboxExecutionPolicy(
        platform="linux",
        config_overrides={"sandbox_permissions": ["disk-full-read-access"]},
    )

    env = {"TEST_VAR": "1"}
    policy.spawn(workspace=tmp_path, env=env, command=("/bin/bash",))

    expected = [
        codex_path,
        "sandbox",
        "linux",
        "-c",
        'sandbox_permissions=["disk-full-read-access"]',
        "--",
        "/bin/bash",
    ]
    assert recorded["command"] == expected
def test_codex_policy_auto_platform_linux(monkeypatch: pytest.MonkeyPatch) -> None:
    """``platform="auto"`` resolves to ``"linux"`` when ``sys.platform`` is ``"linux"``."""
    monkeypatch.setattr(_execution.sys, "platform", "linux")
    resolved = CodexSandboxExecutionPolicy(platform="auto")._determine_platform()
    assert resolved == "linux"
def test_codex_policy_auto_platform_macos(monkeypatch: pytest.MonkeyPatch) -> None:
    """``platform="auto"`` resolves to ``"macos"`` when ``sys.platform`` is ``"darwin"``."""
    monkeypatch.setattr(_execution.sys, "platform", "darwin")
    resolved = CodexSandboxExecutionPolicy(platform="auto")._determine_platform()
    assert resolved == "macos"
def test_codex_policy_resolve_missing_binary(monkeypatch: pytest.MonkeyPatch) -> None:
    """``_resolve_binary`` raises ``RuntimeError`` when ``shutil.which`` finds nothing."""
    monkeypatch.setattr(_execution.shutil, "which", lambda _: None)
    # Build the policy outside the ``raises`` block so only resolution may fail.
    codex_policy = CodexSandboxExecutionPolicy(binary="codex")
    with pytest.raises(RuntimeError):
        codex_policy._resolve_binary()
def test_codex_policy_auto_platform_failure(monkeypatch: pytest.MonkeyPatch) -> None:
    """``platform="auto"`` raises ``RuntimeError`` when ``sys.platform`` is ``"win32"``."""
    monkeypatch.setattr(_execution.sys, "platform", "win32")
    # Build the policy outside the ``raises`` block so only detection may fail.
    codex_policy = CodexSandboxExecutionPolicy(platform="auto")
    with pytest.raises(RuntimeError):
        codex_policy._determine_platform()
def test_codex_policy_formats_override_values() -> None:
    """Check ``_format_override`` output for a dict and for an object with ``__str__``."""

    class Custom:
        def __str__(self) -> str:
            return "custom"

    codex_policy = CodexSandboxExecutionPolicy()

    formatted_mapping = codex_policy._format_override({"a": 1})
    assert formatted_mapping == '{"a": 1}'

    formatted_object = codex_policy._format_override(Custom())
    assert formatted_object == "custom"
def test_codex_policy_sorts_config_overrides(monkeypatch: pytest.MonkeyPatch) -> None:
    """Overrides given as ``b`` then ``a`` appear as ``-c a=1`` before ``-c b=2``."""
    monkeypatch.setattr(_execution.shutil, "which", lambda _: "/usr/bin/codex")
    codex_policy = CodexSandboxExecutionPolicy(
        platform="linux",
        config_overrides={"b": 2, "a": 1},
    )

    argv = codex_policy._build_command(("echo",))

    # Every "-c" flag is immediately followed by its "key=value" payload.
    override_values = [argv[pos + 1] for pos, part in enumerate(argv) if part == "-c"]
    assert override_values == ["a=1", "b=2"]
def test_docker_policy_spawns_docker_run(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """``spawn`` builds a ``docker run`` command line with mount, limits and env.

    The process launcher is replaced by a fake, so docker itself is never
    executed. Binary discovery is patched too (as the other docker tests in this
    module do) so the test runs on machines without the docker CLI instead of
    being skipped.
    """
    docker_path = "/usr/bin/docker"
    monkeypatch.setattr(_execution.shutil, "which", lambda _: docker_path)

    recorded: dict[str, list[str]] = {}

    class DummyProcess:
        pass

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        recorded["command"] = list(command)
        assert cwd == tmp_path
        assert "PATH" in env  # host environment should retain system PATH
        assert not start_new_session
        return DummyProcess()

    monkeypatch.setattr(
        "langchain.agents.middleware._execution._launch_subprocess",
        fake_launch,
    )

    policy = DockerExecutionPolicy(
        image="ubuntu:22.04",
        memory_bytes=4096,
        extra_run_args=("--ipc", "host"),
    )

    env = {"PATH": "/bin"}
    policy.spawn(workspace=tmp_path, env=env, command=("/bin/bash",))

    command = recorded["command"]
    assert command[0] == docker_path
    assert command[1:4] == ["run", "-i", "--rm"]
    assert "--memory" in command
    assert "4096" in command
    # One assertion per condition so a failure names the missing piece.
    assert "-v" in command
    assert any(str(tmp_path) in part for part in command)
    assert "-w" in command
    w_index = command.index("-w")
    assert command[w_index + 1] == str(tmp_path)
    assert "-e" in command
    assert "PATH=/bin" in command
    assert command[-2:] == ["ubuntu:22.04", "/bin/bash"]
def test_docker_policy_rejects_cpu_limit() -> None:
    """Constructing the docker policy with ``cpu_time_seconds`` raises ``RuntimeError``."""
    unsupported_options = {"cpu_time_seconds": 1}
    with pytest.raises(RuntimeError):
        DockerExecutionPolicy(**unsupported_options)
def test_docker_policy_validates_memory() -> None:
    """A ``memory_bytes`` of zero is rejected with ``ValueError``."""
    invalid_options = {"memory_bytes": 0}
    with pytest.raises(ValueError):
        DockerExecutionPolicy(**invalid_options)
def test_docker_policy_skips_mount_for_temp_workspace(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """A workspace whose name starts with ``SHELL_TEMP_PREFIX`` is not bind-mounted.

    The generated command must carry no ``-v`` flag, use ``/`` as the working
    directory, and still include the ``--cpus`` and ``--network none`` options.
    """
    scratch_dir = tmp_path / f"{_execution.SHELL_TEMP_PREFIX}case"
    scratch_dir.mkdir()
    launched: dict[str, list[str]] = {}

    class DummyProcess:
        pass

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        launched["command"] = list(command)
        assert cwd == scratch_dir
        return DummyProcess()

    monkeypatch.setattr(_execution.shutil, "which", lambda _: "/usr/bin/docker")
    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = DockerExecutionPolicy(cpus="1.5")
    policy.spawn(workspace=scratch_dir, env={"PATH": "/bin"}, command=("/bin/sh",))

    argv = launched["command"]
    assert "-v" not in argv
    assert "-w" in argv
    assert argv[argv.index("-w") + 1] == "/"
    assert "--cpus" in argv
    assert "--network" in argv
    assert "none" in argv
    assert argv[-2:] == [policy.image, "/bin/sh"]
def test_docker_policy_validates_cpus() -> None:
    """A whitespace-only ``cpus`` value is rejected with ``ValueError``."""
    blank_value = " "
    with pytest.raises(ValueError):
        DockerExecutionPolicy(cpus=blank_value)
def test_docker_policy_validates_user() -> None:
    """A whitespace-only ``user`` value is rejected with ``ValueError``."""
    blank_value = " "
    with pytest.raises(ValueError):
        DockerExecutionPolicy(user=blank_value)
def test_docker_policy_read_only_and_user(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """``read_only_rootfs`` and ``user`` surface as ``--read-only`` and ``--user <value>``."""
    launched: dict[str, list[str]] = {}

    class DummyProcess:
        pass

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        launched["command"] = list(command)
        return DummyProcess()

    monkeypatch.setattr(_execution.shutil, "which", lambda _: "/usr/bin/docker")
    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    docker_policy = DockerExecutionPolicy(read_only_rootfs=True, user="1000:1000")
    docker_policy.spawn(workspace=tmp_path, env={"PATH": "/bin"}, command=("/bin/sh",))

    argv = launched["command"]
    assert "--read-only" in argv
    assert "--user" in argv
    # The value must directly follow its flag.
    assert argv[argv.index("--user") + 1] == "1000:1000"
def test_docker_policy_resolve_missing_binary(monkeypatch: pytest.MonkeyPatch) -> None:
    """``_resolve_binary`` raises ``RuntimeError`` when ``shutil.which`` finds nothing."""
    monkeypatch.setattr(_execution.shutil, "which", lambda _: None)
    # Build the policy outside the ``raises`` block so only resolution may fail.
    docker_policy = DockerExecutionPolicy()
    with pytest.raises(RuntimeError):
        docker_policy._resolve_binary()

View File

@@ -0,0 +1,175 @@
from __future__ import annotations
import time
import gc
import tempfile
from pathlib import Path
import pytest
from langchain.agents.middleware.shell_tool import (
HostExecutionPolicy,
ShellToolMiddleware,
_SessionResources,
RedactionRule,
)
from langchain.agents.middleware.types import AgentState
def _empty_state() -> AgentState:
return {"messages": []} # type: ignore[return-value]
def test_executes_command_and_persists_state(tmp_path: Path) -> None:
    """A ``cd`` issued in one tool call is still in effect for the next call."""
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")

    def run(command: str):  # noqa: ANN202
        return middleware._run_shell_tool(resources, {"command": command}, tool_call_id=None)

    try:
        state: AgentState = _empty_state()
        state.update(middleware.before_agent(state, None) or {})
        resources = middleware._ensure_resources(state)  # type: ignore[attr-defined]

        run("cd /")
        pwd_output = run("pwd")
        assert isinstance(pwd_output, str)
        assert pwd_output.strip() == "/"

        echo_output = run("echo ready")
        assert "ready" in echo_output
    finally:
        # Always let the middleware run its end-of-agent hook.
        state.update(middleware.after_agent(state, None) or {})
def test_restart_resets_session_environment(tmp_path: Path) -> None:
    """A variable exported before ``restart`` is unset in the restarted session."""
    workspace_root = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace_root)
    try:
        state: AgentState = _empty_state()
        state.update(middleware.before_agent(state, None) or {})

        session = middleware._ensure_resources(state)  # type: ignore[attr-defined]
        middleware._run_shell_tool(session, {"command": "export FOO=bar"}, tool_call_id=None)

        restart_message = middleware._run_shell_tool(session, {"restart": True}, tool_call_id=None)
        assert "restarted" in restart_message.lower()

        session = middleware._ensure_resources(state)  # reacquire after restart
        probe_output = middleware._run_shell_tool(
            session, {"command": "echo ${FOO:-unset}"}, tool_call_id=None
        )
        assert "unset" in probe_output
    finally:
        # Always let the middleware run its end-of-agent hook.
        state.update(middleware.after_agent(state, None) or {})
def test_truncation_indicator_present(tmp_path: Path) -> None:
    """Output of 20 lines under a 5-line cap carries the ``Output truncated`` marker."""
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        execution_policy=HostExecutionPolicy(max_output_lines=5, command_timeout=5.0),
    )
    try:
        state: AgentState = _empty_state()
        state.update(middleware.before_agent(state, None) or {})
        session = middleware._ensure_resources(state)  # type: ignore[attr-defined]

        # ``seq 1 20`` prints 20 lines, above the configured limit of 5.
        output = middleware._run_shell_tool(session, {"command": "seq 1 20"}, tool_call_id=None)
        assert "Output truncated" in output
    finally:
        # Always let the middleware run its end-of-agent hook.
        state.update(middleware.after_agent(state, None) or {})
def test_timeout_returns_error(tmp_path: Path) -> None:
    """A command exceeding ``command_timeout`` is cut short and reported as timed out.

    The command sleeps for longer than the elapsed-time bound asserted below, so
    the timing assertion can only pass if the command was actually interrupted.
    (A ``sleep 2`` against a 2.5 second bound would pass even with no timeout
    enforcement at all.)
    """
    policy = HostExecutionPolicy(command_timeout=0.5)
    middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy)
    try:
        state: AgentState = _empty_state()
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)
        resources = middleware._ensure_resources(state)  # type: ignore[attr-defined]
        start = time.monotonic()
        # Must outlast ``command_timeout + 2.0``; it is never expected to finish.
        result = middleware._run_shell_tool(resources, {"command": "sleep 5"}, tool_call_id=None)
        elapsed = time.monotonic() - start
        assert elapsed < policy.command_timeout + 2.0
        assert "timed out" in result.lower()
    finally:
        updates = middleware.after_agent(state, None)
        if updates:
            state.update(updates)
def test_redaction_policy_applies(tmp_path: Path) -> None:
    """An email printed by the command is replaced by ``[REDACTED_EMAIL]`` in the result."""
    email_rule = RedactionRule(pii_type="email", strategy="redact")
    middleware = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        redaction_rules=(email_rule,),
    )
    try:
        state: AgentState = _empty_state()
        state.update(middleware.before_agent(state, None) or {})
        session = middleware._ensure_resources(state)  # type: ignore[attr-defined]

        tool_args = {"command": "printf 'Contact: user@example.com\\n'"}
        output = middleware._run_shell_tool(session, tool_args, tool_call_id=None)

        assert "[REDACTED_EMAIL]" in output
        assert "user@example.com" not in output
    finally:
        # Always let the middleware run its end-of-agent hook.
        state.update(middleware.after_agent(state, None) or {})
def test_startup_and_shutdown_commands(tmp_path: Path) -> None:
    """Startup commands have run after ``before_agent``; shutdown ones after ``after_agent``."""
    workspace = tmp_path / "workspace"
    startup_marker = workspace / "startup.txt"
    shutdown_marker = workspace / "shutdown.txt"
    middleware = ShellToolMiddleware(
        workspace_root=workspace,
        startup_commands=("touch startup.txt",),
        shutdown_commands=("touch shutdown.txt",),
    )
    try:
        state: AgentState = _empty_state()
        state.update(middleware.before_agent(state, None) or {})
        assert startup_marker.exists()
    finally:
        state.update(middleware.after_agent(state, None) or {})
    # The shutdown marker is only checked once the after-agent hook has returned.
    assert shutdown_marker.exists()
def test_session_resources_finalizer_cleans_up(tmp_path: Path) -> None:
    """Collecting ``_SessionResources`` stops the session and removes its temp directory."""

    class DummySession:
        def __init__(self) -> None:
            self.stopped: bool = False

        def stop(self, timeout: float) -> None:  # noqa: ARG002
            self.stopped = True

    fake_session = DummySession()
    scratch = tempfile.TemporaryDirectory(dir=tmp_path)
    scratch_path = Path(scratch.name)
    resources = _SessionResources(
        session=fake_session,  # type: ignore[arg-type]
        tempdir=scratch,
        policy=HostExecutionPolicy(termination_timeout=0.1),
    )
    # Keep the finalizer handle so its state can be inspected after collection.
    finalizer = resources._finalizer

    # Drop our last strong reference and force collection.
    del resources
    gc.collect()

    assert not finalizer.alive
    assert fake_session.stopped
    assert not scratch_path.exists()

View File

@@ -6,6 +6,7 @@ from langchain_anthropic.middleware.anthropic_tools import (
StateClaudeMemoryMiddleware, StateClaudeMemoryMiddleware,
StateClaudeTextEditorMiddleware, StateClaudeTextEditorMiddleware,
) )
from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware
from langchain_anthropic.middleware.file_search import ( from langchain_anthropic.middleware.file_search import (
StateFileSearchMiddleware, StateFileSearchMiddleware,
) )
@@ -15,6 +16,7 @@ from langchain_anthropic.middleware.prompt_caching import (
__all__ = [ __all__ = [
"AnthropicPromptCachingMiddleware", "AnthropicPromptCachingMiddleware",
"ClaudeBashToolMiddleware",
"FilesystemClaudeMemoryMiddleware", "FilesystemClaudeMemoryMiddleware",
"FilesystemClaudeTextEditorMiddleware", "FilesystemClaudeTextEditorMiddleware",
"StateClaudeMemoryMiddleware", "StateClaudeMemoryMiddleware",

View File

@@ -0,0 +1,92 @@
"""Anthropic-specific middleware for the Claude bash tool."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Any, Literal
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain.tools.tool_node import ToolCallRequest
from langchain_core.messages import ToolMessage
from langgraph.types import Command
# Tool descriptor appended to the model's tool list by the middleware below; its
# "name" value is also used to label the ToolMessage results this module builds.
_CLAUDE_BASH_DESCRIPTOR = {"type": "bash_20250124", "name": "bash"}
class ClaudeBashToolMiddleware(ShellToolMiddleware):
    """Middleware that exposes Anthropic's native bash tool to models.

    Instead of registering a client-side tool, the middleware adds
    ``_CLAUDE_BASH_DESCRIPTOR`` to every model request and intercepts tool calls
    named ``"bash"``, executing them through the ``_ensure_resources`` /
    ``_run_shell_tool`` machinery inherited from ``ShellToolMiddleware``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize middleware without registering a client-side tool.

        Args:
            *args: Positional arguments forwarded to ``ShellToolMiddleware``.
            **kwargs: Keyword arguments forwarded to ``ShellToolMiddleware``.
                Any ``shell_command`` supplied by the caller is replaced with
                ``("/bin/bash",)``.
        """
        kwargs["shell_command"] = ("/bin/bash",)
        super().__init__(*args, **kwargs)
        # Remove the base tool so Claude's native descriptor is the sole entry.
        self._tool = None  # type: ignore[assignment]
        self.tools = []

    @staticmethod
    def _with_bash_descriptor(request: ModelRequest) -> ModelRequest:
        """Return ``request`` with the Claude bash descriptor among its tools.

        The descriptor is matched by value rather than identity so that an equal
        dict already in the list (for example one supplied by the caller) is
        recognised and no second ``"bash"`` entry is appended.
        """
        tools = request.tools
        already_present = any(
            isinstance(tool, dict) and tool == _CLAUDE_BASH_DESCRIPTOR for tool in tools
        )
        if already_present:
            return request
        return request.override(tools=[*tools, _CLAUDE_BASH_DESCRIPTOR])

    def _execute_bash_call(self, request: ToolCallRequest) -> Command | ToolMessage:
        """Run an intercepted ``bash`` tool call in the shell session for its state."""
        tool_call = request.tool_call
        resources = self._ensure_resources(request.state)
        return self._run_shell_tool(
            resources,
            tool_call["args"],
            tool_call_id=tool_call.get("id"),
        )

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Ensure the Claude bash descriptor is available to the model.

        Args:
            request: The outgoing model request.
            handler: Callback that performs the model call.

        Returns:
            Whatever ``handler`` returns for the (possibly augmented) request.
        """
        return handler(self._with_bash_descriptor(request))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        """Async counterpart of ``wrap_model_call``.

        Without this hook the descriptor is only added on the synchronous path,
        even though ``awrap_tool_call`` is prepared to execute ``bash`` calls.

        Args:
            request: The outgoing model request.
            handler: Awaitable callback that performs the model call.

        Returns:
            Whatever ``handler`` returns for the (possibly augmented) request.
        """
        return await handler(self._with_bash_descriptor(request))

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Command | ToolMessage],
    ) -> Command | ToolMessage:
        """Intercept Claude bash tool calls and execute them locally.

        Calls for any tool not named ``"bash"`` are delegated to ``handler``
        unchanged.
        """
        if request.tool_call.get("name") != "bash":
            return handler(request)
        return self._execute_bash_call(request)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[Command | ToolMessage]],
    ) -> Command | ToolMessage:
        """Async interception mirroring the synchronous implementation.

        Calls for any tool not named ``"bash"`` are awaited through ``handler``.
        """
        if request.tool_call.get("name") != "bash":
            return await handler(request)
        # NOTE(review): the shell command is run with a plain synchronous call,
        # i.e. on the calling thread, exactly as in the sync variant. Confirm
        # whether it should be offloaded to a worker thread.
        return self._execute_bash_call(request)

    def _format_tool_message(
        self,
        content: str,
        tool_call_id: str | None,
        *,
        status: Literal["success", "error"],
        artifact: dict[str, Any] | None = None,
    ) -> ToolMessage | str:
        """Format tool responses using Claude's bash descriptor.

        Args:
            content: Text to return to the model.
            tool_call_id: Identifier of the originating tool call, if any.
            status: Outcome recorded on the message.
            artifact: Optional extra data attached to the message; a missing or
                empty value is stored as ``{}``.

        Returns:
            ``content`` itself when ``tool_call_id`` is ``None``; otherwise a
            ``ToolMessage`` named after the bash descriptor.
        """
        if tool_call_id is None:
            return content
        return ToolMessage(
            content=content,
            tool_call_id=tool_call_id,
            name=_CLAUDE_BASH_DESCRIPTOR["name"],
            status=status,
            artifact=artifact or {},
        )
__all__ = ["ClaudeBashToolMiddleware"]

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from langchain_core.messages.tool import ToolCall
pytest.importorskip(
"anthropic", reason="Anthropic SDK is required for Claude middleware tests"
)
from langchain.tools.tool_node import ToolCallRequest
from langchain_core.messages import ToolMessage
from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware
def test_wrap_tool_call_handles_claude_bash(monkeypatch: pytest.MonkeyPatch) -> None:
    """A ``bash`` tool call is served by ``_run_shell_tool`` and never reaches the handler."""
    middleware = ClaudeBashToolMiddleware()
    sentinel = ToolMessage(content="ok", tool_call_id="call-1", name="bash")

    # Stub out the shell machinery so no real session is touched.
    monkeypatch.setattr(middleware, "_run_shell_tool", MagicMock(return_value=sentinel))
    monkeypatch.setattr(middleware, "_ensure_resources", MagicMock(return_value=MagicMock()))

    bash_call: ToolCall = {
        "name": "bash",
        "args": {"command": "echo hi"},
        "id": "call-1",
    }
    request = ToolCallRequest(
        tool_call=bash_call,
        tool=MagicMock(),
        state={},
        runtime=None,  # type: ignore[arg-type]
    )

    handler_invocations: list[ToolCallRequest] = []

    def handler(forwarded: ToolCallRequest) -> ToolMessage:
        handler_invocations.append(forwarded)
        return ToolMessage(content="should not be used", tool_call_id="call-1")

    outcome = middleware.wrap_tool_call(request, handler)

    assert outcome is sentinel
    assert not handler_invocations
def test_wrap_tool_call_passes_through_other_tools(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A tool call not named ``bash`` is answered by the downstream handler."""
    middleware = ClaudeBashToolMiddleware()
    sentinel = ToolMessage(content="handled", tool_call_id="call-2", name="other")

    other_call: ToolCall = {"name": "other", "args": {}, "id": "call-2"}
    request = ToolCallRequest(
        tool_call=other_call,
        tool=MagicMock(),
        state={},
        runtime=None,  # type: ignore[arg-type]
    )

    def handler(_: ToolCallRequest) -> ToolMessage:
        return sentinel

    outcome = middleware.wrap_tool_call(request, handler)

    # The handler's own return object must come back untouched.
    assert outcome is sentinel