mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
feat(langchain_v1): Add ShellToolMiddleware and ClaudeBashToolMiddleware (#33527)
- Both middlewares share the same implementation; the only difference is that one uses Claude's server-side tool definition, whereas the other uses a generic tool definition compatible with all models - Implemented 3 execution policies (responsible for actually running the shell process) - HostExecutionPolicy runs the shell as a subprocess, appropriate for already-sandboxed environments, e.g., when run inside a dedicated Docker container - CodexSandboxExecutionPolicy runs the shell using the sandbox command from the Codex CLI, which implements sandboxing techniques for Linux and macOS. - DockerExecutionPolicy runs the shell inside a dedicated Docker container for isolation. - Implements all behaviours described in https://docs.claude.com/en/docs/agents-and-tools/tool-use/bash-tool#handle-large-outputs including timeouts, truncation, output redaction, etc. --------- Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Co-authored-by: Sydney Runkle <sydneymarierunkle@gmail.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -17,6 +17,13 @@ from .human_in_the_loop import (
|
|||||||
from .model_call_limit import ModelCallLimitMiddleware
|
from .model_call_limit import ModelCallLimitMiddleware
|
||||||
from .model_fallback import ModelFallbackMiddleware
|
from .model_fallback import ModelFallbackMiddleware
|
||||||
from .pii import PIIDetectionError, PIIMiddleware
|
from .pii import PIIDetectionError, PIIMiddleware
|
||||||
|
from .shell_tool import (
|
||||||
|
CodexSandboxExecutionPolicy,
|
||||||
|
DockerExecutionPolicy,
|
||||||
|
HostExecutionPolicy,
|
||||||
|
RedactionRule,
|
||||||
|
ShellToolMiddleware,
|
||||||
|
)
|
||||||
from .summarization import SummarizationMiddleware
|
from .summarization import SummarizationMiddleware
|
||||||
from .todo import TodoListMiddleware
|
from .todo import TodoListMiddleware
|
||||||
from .tool_call_limit import ToolCallLimitMiddleware
|
from .tool_call_limit import ToolCallLimitMiddleware
|
||||||
@@ -42,7 +49,10 @@ __all__ = [
|
|||||||
"AgentMiddleware",
|
"AgentMiddleware",
|
||||||
"AgentState",
|
"AgentState",
|
||||||
"ClearToolUsesEdit",
|
"ClearToolUsesEdit",
|
||||||
|
"CodexSandboxExecutionPolicy",
|
||||||
"ContextEditingMiddleware",
|
"ContextEditingMiddleware",
|
||||||
|
"DockerExecutionPolicy",
|
||||||
|
"HostExecutionPolicy",
|
||||||
"HumanInTheLoopMiddleware",
|
"HumanInTheLoopMiddleware",
|
||||||
"InterruptOnConfig",
|
"InterruptOnConfig",
|
||||||
"LLMToolEmulator",
|
"LLMToolEmulator",
|
||||||
@@ -53,6 +63,8 @@ __all__ = [
|
|||||||
"ModelResponse",
|
"ModelResponse",
|
||||||
"PIIDetectionError",
|
"PIIDetectionError",
|
||||||
"PIIMiddleware",
|
"PIIMiddleware",
|
||||||
|
"RedactionRule",
|
||||||
|
"ShellToolMiddleware",
|
||||||
"SummarizationMiddleware",
|
"SummarizationMiddleware",
|
||||||
"TodoListMiddleware",
|
"TodoListMiddleware",
|
||||||
"ToolCallLimitMiddleware",
|
"ToolCallLimitMiddleware",
|
||||||
|
|||||||
388
libs/langchain_v1/langchain/agents/middleware/_execution.py
Normal file
388
libs/langchain_v1/langchain/agents/middleware/_execution.py
Normal file
@@ -0,0 +1,388 @@
|
|||||||
|
"""Execution policies for the persistent shell middleware."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import typing
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
try: # pragma: no cover - optional dependency on POSIX platforms
|
||||||
|
import resource
|
||||||
|
except ImportError: # pragma: no cover - non-POSIX systems
|
||||||
|
resource = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
|
SHELL_TEMP_PREFIX = "langchain-shell-"
|
||||||
|
|
||||||
|
|
||||||
|
def _launch_subprocess(
|
||||||
|
command: Sequence[str],
|
||||||
|
*,
|
||||||
|
env: Mapping[str, str],
|
||||||
|
cwd: Path,
|
||||||
|
preexec_fn: typing.Callable[[], None] | None,
|
||||||
|
start_new_session: bool,
|
||||||
|
) -> subprocess.Popen[str]:
|
||||||
|
return subprocess.Popen( # noqa: S603
|
||||||
|
list(command),
|
||||||
|
stdin=subprocess.PIPE,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
cwd=cwd,
|
||||||
|
text=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
errors="replace",
|
||||||
|
bufsize=1,
|
||||||
|
env=env,
|
||||||
|
preexec_fn=preexec_fn, # noqa: PLW1509
|
||||||
|
start_new_session=start_new_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BaseExecutionPolicy(abc.ABC):
    """Abstract base for policies that launch a persistent shell process.

    A policy carries the settings shared by all sessions (timeouts and output
    caps) and exposes :meth:`spawn`, which each subclass implements to start
    the shell. Pick :class:`HostExecutionPolicy` for trusted, same-host
    execution; :class:`CodexSandboxExecutionPolicy` when the Codex CLI sandbox
    is available and extra syscall restrictions are wanted; and
    :class:`DockerExecutionPolicy` for container-level isolation with Docker.
    Each subclass documents its own security guarantees.
    """

    # Timeouts; presumably seconds - the consumers are outside this module.
    command_timeout: float = 30.0
    startup_timeout: float = 30.0
    termination_timeout: float = 10.0
    # Output caps. ``max_output_lines`` is validated below; ``None`` for
    # ``max_output_bytes`` presumably means "no byte cap".
    max_output_lines: int = 100
    max_output_bytes: int | None = None

    def __post_init__(self) -> None:
        """Validate the configuration.

        Raises:
            ValueError: If ``max_output_lines`` is zero or negative.
        """
        non_positive = self.max_output_lines <= 0
        if non_positive:
            msg = "max_output_lines must be positive."
            raise ValueError(msg)

    @abc.abstractmethod
    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Launch the persistent shell process."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class HostExecutionPolicy(BaseExecutionPolicy):
    """Run the shell directly on the host as a child process.

    This policy is best suited for trusted or single-tenant environments (CI jobs,
    developer workstations, pre-sandboxed containers) where the agent must access the
    host filesystem and tooling without additional isolation. It enforces optional CPU
    and memory limits to prevent runaway commands but offers **no** filesystem or network
    sandboxing; commands can modify anything the process user can reach.

    On Linux, when ``resource.prlimit`` is available, limits are applied to the child
    right after it starts. Everywhere else they are set by a ``preexec_fn`` that runs in
    the child before ``exec``. With ``create_process_group`` enabled (the default) the
    shell is started in a new session.
    """

    # CPU limit applied as RLIMIT_CPU (soft and hard); ``None`` disables it.
    cpu_time_seconds: int | None = None
    # Memory limit in bytes, applied as RLIMIT_AS, or RLIMIT_DATA when
    # RLIMIT_AS is not exposed by ``resource``; ``None`` disables it.
    memory_bytes: int | None = None
    # Forwarded to ``Popen`` as ``start_new_session``.
    create_process_group: bool = True

    # Derived in ``__post_init__``: True when at least one limit is set.
    _limits_requested: bool = field(init=False, repr=False, default=False)

    def __post_init__(self) -> None:
        """Validate the limits and record whether any were requested.

        Raises:
            ValueError: If a provided limit is zero or negative (or the base
                class validation fails).
            RuntimeError: If limits are requested but the ``resource`` module
                could not be imported on this platform.
        """
        super().__post_init__()
        if self.cpu_time_seconds is not None and self.cpu_time_seconds <= 0:
            msg = "cpu_time_seconds must be positive if provided."
            raise ValueError(msg)
        if self.memory_bytes is not None and self.memory_bytes <= 0:
            msg = "memory_bytes must be positive if provided."
            raise ValueError(msg)
        self._limits_requested = any(
            value is not None for value in (self.cpu_time_seconds, self.memory_bytes)
        )
        if self._limits_requested and resource is None:
            msg = (
                "HostExecutionPolicy cpu/memory limits require the Python 'resource' module. "
                "Either remove the limits or run on a POSIX platform."
            )
            raise RuntimeError(msg)

    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Start the shell in ``workspace`` and apply the configured limits.

        Args:
            workspace: Working directory for the shell.
            env: Environment for the shell process.
            command: Shell program and arguments.

        Returns:
            The running shell process.

        Raises:
            RuntimeError: If the limits cannot be applied with ``prlimit``. The
                shell that was just started is killed and reaped before the
                error propagates.
        """
        process = _launch_subprocess(
            list(command),
            env=env,
            cwd=workspace,
            preexec_fn=self._create_preexec_fn(),
            start_new_session=self.create_process_group,
        )
        try:
            self._apply_post_spawn_limits(process)
        except Exception:
            # The caller never receives ``process`` on failure, so clean it up
            # here rather than leaving an unconstrained shell with open pipes.
            process.kill()
            for stream in (process.stdin, process.stdout, process.stderr):
                if stream is not None:
                    stream.close()
            process.wait()
            raise
        return process

    def _create_preexec_fn(self) -> typing.Callable[[], None] | None:
        """Return a pre-``exec`` hook that sets the limits, or ``None``.

        ``None`` is returned when no limits were requested or when they will
        be applied after spawn via ``prlimit`` instead.
        """
        if not self._limits_requested or self._can_use_prlimit():
            return None

        def _configure() -> None:  # pragma: no cover - depends on OS
            if self.cpu_time_seconds is not None:
                limit = (self.cpu_time_seconds, self.cpu_time_seconds)
                resource.setrlimit(resource.RLIMIT_CPU, limit)
            if self.memory_bytes is not None:
                limit = (self.memory_bytes, self.memory_bytes)
                # Prefer the address-space limit; fall back to the data segment.
                if hasattr(resource, "RLIMIT_AS"):
                    resource.setrlimit(resource.RLIMIT_AS, limit)
                elif hasattr(resource, "RLIMIT_DATA"):
                    resource.setrlimit(resource.RLIMIT_DATA, limit)

        return _configure

    def _apply_post_spawn_limits(self, process: subprocess.Popen[str]) -> None:
        """Apply the limits to a running ``process`` with ``resource.prlimit``.

        Does nothing when no limits were requested or ``prlimit`` is unusable.

        Raises:
            RuntimeError: If ``prlimit`` fails with ``OSError``.
        """
        if not self._limits_requested or not self._can_use_prlimit():
            return
        if resource is None:  # pragma: no cover - defensive
            return
        pid = process.pid
        if pid is None:
            return
        try:
            # ``prlimit`` is not declared on every platform's stubs, hence the cast.
            prlimit = typing.cast("typing.Any", resource).prlimit
            if self.cpu_time_seconds is not None:
                prlimit(pid, resource.RLIMIT_CPU, (self.cpu_time_seconds, self.cpu_time_seconds))
            if self.memory_bytes is not None:
                limit = (self.memory_bytes, self.memory_bytes)
                if hasattr(resource, "RLIMIT_AS"):
                    prlimit(pid, resource.RLIMIT_AS, limit)
                elif hasattr(resource, "RLIMIT_DATA"):
                    prlimit(pid, resource.RLIMIT_DATA, limit)
        except OSError as exc:  # pragma: no cover - depends on platform support
            msg = "Failed to apply resource limits via prlimit."
            raise RuntimeError(msg) from exc

    @staticmethod
    def _can_use_prlimit() -> bool:
        """Return True on Linux when ``resource`` imported and exposes ``prlimit``."""
        return (
            resource is not None
            and hasattr(resource, "prlimit")
            and sys.platform.startswith("linux")
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CodexSandboxExecutionPolicy(BaseExecutionPolicy):
    """Launch the shell through the Codex CLI sandbox.

    Ideal when you have the Codex CLI installed and want the additional syscall and
    filesystem restrictions provided by its Seatbelt (macOS) or Landlock/seccomp
    (Linux) profiles. Commands still run on the host, but within the sandbox requested by
    the CLI. If the Codex binary cannot be found on ``PATH``, or the platform cannot be
    determined, :meth:`spawn` raises :class:`RuntimeError`.

    Configure sandbox behaviour via ``config_overrides`` to align with your Codex CLI
    profile. This policy does not add its own resource limits; combine it with
    host-level guards (cgroups, container resource limits) as needed.
    """

    # Name or path of the Codex executable, resolved with ``shutil.which``.
    binary: str = "codex"
    # Sandbox flavour passed to ``codex sandbox``; "auto" derives it from ``sys.platform``.
    platform: typing.Literal["auto", "macos", "linux"] = "auto"
    # Extra ``-c key=value`` options; values are JSON-encoded when possible.
    config_overrides: Mapping[str, typing.Any] = field(default_factory=dict)

    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Start ``command`` wrapped in ``codex sandbox`` with ``workspace`` as cwd.

        Raises:
            RuntimeError: If the Codex binary is not on ``PATH`` or no supported
                platform can be determined.
        """
        full_command = self._build_command(command)
        return _launch_subprocess(
            full_command,
            env=env,
            cwd=workspace,
            preexec_fn=None,
            start_new_session=False,
        )

    def _build_command(self, command: Sequence[str]) -> list[str]:
        """Return ``[codex, "sandbox", platform, -c k=v ..., "--", *command]``."""
        binary = self._resolve_binary()
        platform_arg = self._determine_platform()
        full_command: list[str] = [binary, "sandbox", platform_arg]
        # Sorted by key so the generated command line is deterministic.
        for key, value in sorted(dict(self.config_overrides).items()):
            full_command.extend(["-c", f"{key}={self._format_override(value)}"])
        full_command.append("--")
        full_command.extend(command)
        return full_command

    def _resolve_binary(self) -> str:
        """Return the absolute path of the Codex executable.

        Raises:
            RuntimeError: If ``shutil.which`` cannot find ``self.binary``.
        """
        path = shutil.which(self.binary)
        if path is None:
            msg = (
                "Codex sandbox policy requires the '%s' CLI to be installed and available on PATH."
            )
            raise RuntimeError(msg % self.binary)
        return path

    def _determine_platform(self) -> str:
        """Return the sandbox platform argument (``"linux"`` or ``"macos"``).

        An explicit ``platform`` setting wins; ``"auto"`` inspects ``sys.platform``.

        Raises:
            RuntimeError: If ``platform`` is ``"auto"`` and the host is neither
                Linux nor macOS.
        """
        if self.platform != "auto":
            return self.platform
        if sys.platform.startswith("linux"):
            return "linux"
        if sys.platform == "darwin":
            return "macos"
        msg = (
            "Codex sandbox policy could not determine a supported platform; "
            "set 'platform' explicitly."
        )
        raise RuntimeError(msg)

    @staticmethod
    def _format_override(value: typing.Any) -> str:
        """Render an override value as JSON, falling back to ``str(value)``."""
        try:
            return json.dumps(value)
        except TypeError:
            # Not JSON-serialisable: use the plain string form instead.
            return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DockerExecutionPolicy(BaseExecutionPolicy):
    """Run the shell inside its own Docker container.

    Intended for commands that come from untrusted users, or when sessions must be
    strongly isolated from each other. The workspace is bind-mounted into the
    container (at the same path) unless its directory name starts with
    ``SHELL_TEMP_PREFIX``; in that case nothing is mounted and the container starts
    in ``/``. Networking is off by default (``--network none``); ``read_only_rootfs``
    and ``user`` allow further hardening.

    The actual security guarantees depend on how the Docker daemon is configured.
    Run the agent on a host where Docker is locked down (rootless mode,
    AppArmor/SELinux, etc.) and review any volumes or capabilities passed through
    ``extra_run_args``. The default image is ``python:3.12-alpine3.19``; supply a
    custom image if you need preinstalled tooling.
    """

    # Name or path of the Docker CLI, resolved with ``shutil.which``.
    binary: str = "docker"
    image: str = "python:3.12-alpine3.19"
    # Adds ``--rm`` to ``docker run``.
    remove_container_on_exit: bool = True
    # When False, ``--network none`` is added.
    network_enabled: bool = False
    # Extra ``docker run`` arguments; normalised to a tuple in ``__post_init__``.
    extra_run_args: Sequence[str] | None = None
    # Passed to ``--memory`` as a byte count.
    memory_bytes: int | None = None
    # Not supported by this policy: any non-None value is rejected.
    cpu_time_seconds: typing.Any | None = None
    # Passed verbatim to ``--cpus``.
    cpus: str | None = None
    # Adds ``--read-only``.
    read_only_rootfs: bool = False
    # Passed verbatim to ``--user``.
    user: str | None = None

    def __post_init__(self) -> None:
        """Validate the options and freeze ``extra_run_args`` into a tuple.

        Raises:
            ValueError: If ``memory_bytes`` is not positive, or ``cpus`` / ``user``
                is an empty or whitespace-only string.
            RuntimeError: If ``cpu_time_seconds`` is provided.
        """
        super().__post_init__()

        memory = self.memory_bytes
        if memory is not None and memory <= 0:
            msg = "memory_bytes must be positive if provided."
            raise ValueError(msg)

        if self.cpu_time_seconds is not None:
            msg = (
                "DockerExecutionPolicy does not support cpu_time_seconds; configure CPU limits "
                "using Docker run options such as '--cpus'."
            )
            raise RuntimeError(msg)

        if self.cpus is not None and not self.cpus.strip():
            msg = "cpus must be a non-empty string when provided."
            raise ValueError(msg)

        if self.user is not None and not self.user.strip():
            msg = "user must be a non-empty string when provided."
            raise ValueError(msg)

        self.extra_run_args = tuple(self.extra_run_args or ())

    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Start ``docker run`` for ``command``.

        ``env`` is forwarded into the container with ``-e`` flags; the Docker
        CLI itself runs with a copy of the current process environment.

        Raises:
            RuntimeError: If the Docker CLI is not on ``PATH``.
        """
        docker_argv = self._build_command(workspace, env, command)
        cli_env = os.environ.copy()
        return _launch_subprocess(
            docker_argv,
            env=cli_env,
            cwd=workspace,
            preexec_fn=None,
            start_new_session=False,
        )

    def _build_command(
        self,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> list[str]:
        """Assemble the ``docker run -i ...`` argument list for ``command``."""
        argv: list[str] = [self._resolve_binary(), "run", "-i"]

        if self.remove_container_on_exit:
            argv += ["--rm"]
        if not self.network_enabled:
            argv += ["--network", "none"]
        if self.memory_bytes is not None:
            argv += ["--memory", str(self.memory_bytes)]

        if not self._should_mount_workspace(workspace):
            # Nothing mounted: start in the container root.
            argv += ["-w", "/"]
        else:
            # Mount the workspace at the same path inside the container.
            mount_point = str(workspace)
            argv += ["-v", f"{mount_point}:{mount_point}", "-w", mount_point]

        if self.read_only_rootfs:
            argv += ["--read-only"]

        for name, setting in env.items():
            argv += ["-e", f"{name}={setting}"]

        if self.cpus is not None:
            argv += ["--cpus", self.cpus]
        if self.user is not None:
            argv += ["--user", self.user]
        if self.extra_run_args:
            argv.extend(self.extra_run_args)

        argv.append(self.image)
        argv.extend(command)
        return argv

    @staticmethod
    def _should_mount_workspace(workspace: Path) -> bool:
        """Return False when the workspace name starts with ``SHELL_TEMP_PREFIX``."""
        is_temporary = workspace.name.startswith(SHELL_TEMP_PREFIX)
        return not is_temporary

    def _resolve_binary(self) -> str:
        """Return the absolute path of the Docker CLI.

        Raises:
            RuntimeError: If ``shutil.which`` cannot find ``self.binary``.
        """
        located = shutil.which(self.binary)
        if located is not None:
            return located
        msg = (
            "Docker execution policy requires the '%s' CLI to be installed"
            " and available on PATH."
        )
        raise RuntimeError(msg % self.binary)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseExecutionPolicy",
|
||||||
|
"CodexSandboxExecutionPolicy",
|
||||||
|
"DockerExecutionPolicy",
|
||||||
|
"HostExecutionPolicy",
|
||||||
|
]
|
||||||
350
libs/langchain_v1/langchain/agents/middleware/_redaction.py
Normal file
350
libs/langchain_v1/langchain/agents/middleware/_redaction.py
Normal file
@@ -0,0 +1,350 @@
|
|||||||
|
"""Shared redaction utilities for middleware components."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import ipaddress
|
||||||
|
import re
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
RedactionStrategy = Literal["block", "redact", "mask", "hash"]
|
||||||
|
"""Supported strategies for handling detected sensitive values."""
|
||||||
|
|
||||||
|
|
||||||
|
class PIIMatch(TypedDict):
    """One occurrence of sensitive data located in a piece of text."""

    # Category label of the detected value.
    type: str
    # The matched text itself.
    value: str
    # Offsets of the match within the scanned content.
    start: int
    end: int
|
||||||
|
|
||||||
|
|
||||||
|
class PIIDetectionError(Exception):
    """Error signalling that sensitive values were found while blocking is configured."""

    def __init__(self, pii_type: str, matches: Sequence[PIIMatch]) -> None:
        """Store the match context and build the error message.

        Args:
            pii_type: Name of the sensitive type that was detected.
            matches: Every match found for that type; kept as a list copy
                on ``self.matches``.
        """
        self.pii_type = pii_type
        self.matches = list(matches)
        super().__init__(
            f"Detected {len(matches)} instance(s) of {pii_type} in text content"
        )
|
||||||
|
|
||||||
|
|
||||||
|
Detector = Callable[[str], list[PIIMatch]]
|
||||||
|
"""Callable signature for detectors that locate sensitive values."""
|
||||||
|
|
||||||
|
|
||||||
|
def detect_email(content: str) -> list[PIIMatch]:
    """Detect email addresses in content.

    Args:
        content: Text to scan.

    Returns:
        One ``PIIMatch`` of type ``"email"`` per address found, in order of
        appearance, with ``start``/``end`` taken from the regex match.
    """
    # The top-level domain is letters only. The class must be ``[A-Za-z]``:
    # inside brackets ``|`` is a literal pipe, not an alternation, so the
    # former ``[A-Z|a-z]`` also accepted "|" as part of the address.
    pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    return [
        PIIMatch(
            type="email",
            value=match.group(),
            start=match.start(),
            end=match.end(),
        )
        for match in re.finditer(pattern, content)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def detect_credit_card(content: str) -> list[PIIMatch]:
    """Find 16-digit card numbers in content that pass the Luhn checksum.

    Candidates are four groups of four digits, optionally separated by a
    whitespace character or a hyphen; candidates failing the checksum are dropped.
    """
    pattern = r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b"
    return [
        PIIMatch(
            type="credit_card",
            value=candidate.group(),
            start=candidate.start(),
            end=candidate.end(),
        )
        for candidate in re.finditer(pattern, content)
        if _passes_luhn(candidate.group())
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def detect_ip(content: str) -> list[PIIMatch]:
    """Detect IPv4 addresses in content.

    Only dotted-quad IPv4 candidates are searched for; IPv6 addresses are not
    detected. Each candidate is validated with ``ipaddress.ip_address`` so
    that look-alikes such as ``999.1.1.1`` are ignored.

    Args:
        content: Text to scan.

    Returns:
        One ``PIIMatch`` of type ``"ip"`` per valid address, in order of appearance.
    """
    matches: list[PIIMatch] = []
    ipv4_pattern = r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b"

    for match in re.finditer(ipv4_pattern, content):
        ip_candidate = match.group()
        try:
            ipaddress.ip_address(ip_candidate)
        except ValueError:
            # Looks like a dotted quad but is not a valid address.
            continue
        matches.append(
            PIIMatch(
                type="ip",
                value=ip_candidate,
                start=match.start(),
                end=match.end(),
            )
        )

    return matches
|
||||||
|
|
||||||
|
|
||||||
|
def detect_mac_address(content: str) -> list[PIIMatch]:
    """Find MAC addresses (six hex pairs joined by ``:`` or ``-``) in content."""
    pattern = r"\b([0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b"
    found: list[PIIMatch] = []
    for hit in re.finditer(pattern, content):
        found.append(
            PIIMatch(
                type="mac_address",
                value=hit.group(),
                start=hit.start(),
                end=hit.end(),
            )
        )
    return found
|
||||||
|
|
||||||
|
|
||||||
|
def detect_url(content: str) -> list[PIIMatch]:
    """Detect URLs in content using regex and stdlib validation.

    Two passes are made: first URLs with an explicit ``http://`` or
    ``https://`` scheme, then scheme-less candidates, which are accepted only
    when they contain a path or start with ``www.``. A scheme-less candidate
    overlapping any already-recorded match is skipped so that the returned
    matches never overlap.

    Args:
        content: Text to scan.

    Returns:
        ``PIIMatch`` entries of type ``"url"``: scheme URLs first, then bare ones.
    """
    matches: list[PIIMatch] = []

    # Pattern 1: URLs with scheme (http:// or https://)
    scheme_pattern = r"https?://[^\s<>\"{}|\\^`\[\]]+"

    for match in re.finditer(scheme_pattern, content):
        url = match.group()
        result = urlparse(url)
        if result.scheme in ("http", "https") and result.netloc:
            matches.append(
                PIIMatch(
                    type="url",
                    value=url,
                    start=match.start(),
                    end=match.end(),
                )
            )

    # Pattern 2: URLs without scheme (www.example.com or example.com/path)
    # More conservative to avoid false positives
    bare_pattern = (
        r"\b(?:www\.)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?"
        r"(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+(?:/[^\s]*)?"
    )

    for match in re.finditer(bare_pattern, content):
        start, end = match.start(), match.end()
        # Skip anything overlapping a recorded match. This is the full interval
        # test: it also catches a bare candidate that completely encloses a
        # scheme URL, which an endpoint-only check lets through and which would
        # yield overlapping matches that corrupt later replacements.
        if any(start < m["end"] and m["start"] < end for m in matches):
            continue

        url = match.group()
        # Only accept if it has a path or starts with www
        # This reduces false positives like "example.com" in prose
        if "/" in url or url.startswith("www."):
            # Add scheme for validation (required for urlparse to work correctly)
            test_url = f"http://{url}"
            result = urlparse(test_url)
            if result.netloc and "." in result.netloc:
                matches.append(
                    PIIMatch(
                        type="url",
                        value=url,
                        start=start,
                        end=end,
                    )
                )

    return matches
|
||||||
|
|
||||||
|
|
||||||
|
BUILTIN_DETECTORS: dict[str, Detector] = {
|
||||||
|
"email": detect_email,
|
||||||
|
"credit_card": detect_credit_card,
|
||||||
|
"ip": detect_ip,
|
||||||
|
"mac_address": detect_mac_address,
|
||||||
|
"url": detect_url,
|
||||||
|
}
|
||||||
|
"""Registry of built-in detectors keyed by type name."""
|
||||||
|
|
||||||
|
|
||||||
|
def _passes_luhn(card_number: str) -> bool:
    """Return True when the digits of ``card_number`` satisfy the Luhn checksum.

    Non-digit characters (separators) are ignored. Numbers with fewer than 13
    or more than 19 digits are rejected outright.
    """
    digits = [int(ch) for ch in card_number if ch.isdigit()]
    if len(digits) < 13 or len(digits) > 19:
        return False

    total = 0
    # Walk from the rightmost digit; every second digit is doubled, and a
    # doubled value above 9 has 9 subtracted (same as summing its digits).
    for position, digit in enumerate(digits[::-1]):
        if position % 2 == 1:
            doubled = digit * 2
            total += doubled - 9 if doubled > 9 else doubled
        else:
            total += digit
    return total % 10 == 0
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_redact_strategy(content: str, matches: list[PIIMatch]) -> str:
    """Replace each match in ``content`` with ``[REDACTED_<TYPE>]``."""
    # Work from the last match backwards so earlier offsets stay valid.
    ordered = list(matches)
    ordered.sort(key=lambda found: found["start"], reverse=True)

    redacted = content
    for found in ordered:
        label = found["type"].upper()
        head = redacted[: found["start"]]
        tail = redacted[found["end"] :]
        redacted = f"{head}[REDACTED_{label}]{tail}"
    return redacted
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_value(pii_type: str, value: str) -> str:
    """Return the partially hidden form of ``value`` for the given type."""
    if pii_type == "email":
        # Keep the local part and the last domain label, hide the rest.
        pieces = value.split("@")
        if len(pieces) != 2:
            return "****"
        local, domain = pieces
        labels = domain.split(".")
        if len(labels) >= 2:
            return f"{local}@****.{labels[-1]}"
        return f"{local}@****"

    if pii_type == "credit_card":
        # Keep the last four digits and the original separator style.
        last_four = "".join(ch for ch in value if ch.isdigit())[-4:]
        if "-" in value:
            separator = "-"
        elif " " in value:
            separator = " "
        else:
            separator = ""
        if separator:
            return separator.join(["****", "****", "****", last_four])
        return "************" + last_four

    if pii_type == "ip":
        octets = value.split(".")
        if len(octets) == 4:
            return f"*.*.*.{octets[-1]}"
        return "****"

    if pii_type == "mac_address":
        # Keep only the final pair.
        separator = ":" if ":" in value else "-"
        return separator.join(["**"] * 5 + [value[-2:]])

    if pii_type == "url":
        return "[MASKED_URL]"

    # Unknown types: show the last four characters of longer values only.
    if len(value) > 4:
        return f"****{value[-4:]}"
    return "****"


def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
    """Replace each match in ``content`` with a partially masked version."""
    masked_content = content
    # Work from the last match backwards so earlier offsets stay valid.
    for found in sorted(matches, key=lambda entry: entry["start"], reverse=True):
        replacement = _mask_value(found["type"], found["value"])
        masked_content = (
            masked_content[: found["start"]] + replacement + masked_content[found["end"] :]
        )
    return masked_content
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_hash_strategy(content: str, matches: list[PIIMatch]) -> str:
    """Replace each match with ``<type_hash:xxxxxxxx>``.

    The tag carries the first eight hex characters of the SHA-256 digest of
    the matched value.
    """
    # Work from the last match backwards so earlier offsets stay valid.
    ordered = list(matches)
    ordered.sort(key=lambda found: found["start"], reverse=True)

    hashed = content
    for found in ordered:
        fingerprint = hashlib.sha256(found["value"].encode()).hexdigest()
        tag = f"<{found['type']}_hash:{fingerprint[:8]}>"
        hashed = hashed[: found["start"]] + tag + hashed[found["end"] :]
    return hashed
|
||||||
|
|
||||||
|
|
||||||
|
def apply_strategy(
    content: str,
    matches: list[PIIMatch],
    strategy: RedactionStrategy,
) -> str:
    """Transform ``content`` according to ``strategy`` for the given matches.

    Args:
        content: The original text.
        matches: Matches previously detected in ``content``.
        strategy: One of ``"block"``, ``"redact"``, ``"mask"`` or ``"hash"``.

    Returns:
        The rewritten text, or ``content`` unchanged when ``matches`` is empty
        (the strategy is not inspected in that case).

    Raises:
        PIIDetectionError: If the strategy is ``"block"`` and there are matches.
        ValueError: If the strategy is not recognised.
    """
    if not matches:
        return content

    if strategy == "block":
        raise PIIDetectionError(matches[0]["type"], matches)
    if strategy == "hash":
        return _apply_hash_strategy(content, matches)
    if strategy == "mask":
        return _apply_mask_strategy(content, matches)
    if strategy == "redact":
        return _apply_redact_strategy(content, matches)

    raise ValueError(f"Unknown redaction strategy: {strategy}")
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_detector(pii_type: str, detector: Detector | str | None) -> Detector:
    """Turn a detector setting into a callable detector.

    Args:
        pii_type: Type name; selects the built-in detector when ``detector``
            is ``None`` and labels matches produced by a regex detector.
        detector: ``None`` for the built-in detector, a regular-expression
            string, or an already-callable detector (returned unchanged).

    Raises:
        ValueError: If ``detector`` is ``None`` and ``pii_type`` has no
            built-in detector.
    """
    if isinstance(detector, str):
        compiled = re.compile(detector)

        def regex_detector(content: str) -> list[PIIMatch]:
            found: list[PIIMatch] = []
            for hit in compiled.finditer(content):
                found.append(
                    PIIMatch(
                        type=pii_type,
                        value=hit.group(),
                        start=hit.start(),
                        end=hit.end(),
                    )
                )
            return found

        return regex_detector

    if detector is not None:
        # Custom callable supplied by the caller.
        return detector

    if pii_type in BUILTIN_DETECTORS:
        return BUILTIN_DETECTORS[pii_type]

    msg = (
        f"Unknown PII type: {pii_type}. "
        f"Must be one of {list(BUILTIN_DETECTORS.keys())} or provide a custom detector."
    )
    raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class RedactionRule:
    """User-facing configuration describing how one PII type is handled."""

    # Name of the PII type this rule covers.
    pii_type: str
    # What to do with detected values.
    strategy: RedactionStrategy = "redact"
    # Detector setting, turned into a callable by ``resolve_detector``.
    detector: Detector | str | None = None

    def resolve(self) -> ResolvedRedactionRule:
        """Return an immutable rule whose detector has been resolved to a callable."""
        return ResolvedRedactionRule(
            pii_type=self.pii_type,
            strategy=self.strategy,
            detector=resolve_detector(self.pii_type, self.detector),
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ResolvedRedactionRule:
    """A redaction rule whose detector is already a callable."""

    pii_type: str
    strategy: RedactionStrategy
    detector: Detector

    def apply(self, content: str) -> tuple[str, list[PIIMatch]]:
        """Run the detector on ``content`` and apply the rule's strategy.

        Returns:
            ``(new_content, matches)``. When nothing is detected the content
            is returned untouched together with an empty list.
        """
        found = self.detector(content)
        if found:
            return apply_strategy(content, found, self.strategy), found
        return content, []
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PIIDetectionError",
|
||||||
|
"PIIMatch",
|
||||||
|
"RedactionRule",
|
||||||
|
"ResolvedRedactionRule",
|
||||||
|
"apply_strategy",
|
||||||
|
"detect_credit_card",
|
||||||
|
"detect_email",
|
||||||
|
"detect_ip",
|
||||||
|
"detect_mac_address",
|
||||||
|
"detect_url",
|
||||||
|
]
|
||||||
@@ -2,15 +2,22 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import ipaddress
|
|
||||||
import re
|
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
|
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
|
from langchain.agents.middleware._redaction import (
|
||||||
|
PIIDetectionError,
|
||||||
|
PIIMatch,
|
||||||
|
RedactionRule,
|
||||||
|
ResolvedRedactionRule,
|
||||||
|
apply_strategy,
|
||||||
|
detect_credit_card,
|
||||||
|
detect_email,
|
||||||
|
detect_ip,
|
||||||
|
detect_mac_address,
|
||||||
|
detect_url,
|
||||||
|
)
|
||||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
|
from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -19,396 +26,6 @@ if TYPE_CHECKING:
|
|||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
|
||||||
class PIIMatch(TypedDict):
    """Represents a detected PII match in text.

    ``start`` and ``end`` are used as slice bounds into the scanned text
    (``text[start:end]``) by the strategy helpers in this module.
    """

    type: str
    """The type of PII detected (e.g., 'email', 'ssn', 'credit_card')."""
    value: str
    """The actual matched text."""
    start: int
    """Starting position of the match in the text."""
    end: int
    """Ending position of the match in the text."""
|
|
||||||
|
|
||||||
|
|
||||||
class PIIDetectionError(Exception):
    """Raised to signal that PII was found while the ``'block'`` strategy is active.

    The offending PII type and the individual matches are kept on the
    instance as ``pii_type`` and ``matches``.
    """

    def __init__(self, pii_type: str, matches: list[PIIMatch]) -> None:
        """Store the detection details and compose the error message.

        Args:
            pii_type: The type of PII that was detected.
            matches: The matches that triggered the error.
        """
        super().__init__(
            f"Detected {len(matches)} instance(s) of {pii_type} in message content"
        )
        self.pii_type = pii_type
        self.matches = matches
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# PII Detection Functions
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def _luhn_checksum(card_number: str) -> bool:
    """Check a candidate card number against the Luhn algorithm.

    Non-digit characters (spaces, dashes, ...) are ignored before validation.

    Args:
        card_number: Candidate card number, optionally containing separators.

    Returns:
        ``True`` when the number has between 13 and 19 digits and its Luhn
        sum is a multiple of ten, ``False`` otherwise.
    """
    digits = [int(ch) for ch in card_number if ch.isdigit()]
    if not 13 <= len(digits) <= 19:
        return False

    total = 0
    # Walk from the rightmost digit; every second digit is doubled and, when
    # the doubled value has two digits, reduced by nine (i.e. digits summed).
    for position, digit in enumerate(reversed(digits)):
        if position % 2:
            doubled = digit * 2
            total += doubled - 9 if doubled > 9 else doubled
        else:
            total += digit

    return total % 10 == 0
|
|
||||||
|
|
||||||
|
|
||||||
def detect_email(content: str) -> list[PIIMatch]:
    """Detect email addresses in content.

    Args:
        content: Text content to scan.

    Returns:
        List of detected email matches, in order of appearance.
    """
    # The top-level-domain class is ``[A-Za-z]``. It must not contain ``|``:
    # inside a character class the pipe is a literal character, not an
    # alternation, so ``[A-Z|a-z]`` would accept "TLDs" such as ``c|m``.
    pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    return [
        PIIMatch(
            type="email",
            value=match.group(),
            start=match.start(),
            end=match.end(),
        )
        for match in re.finditer(pattern, content)
    ]
|
|
||||||
|
|
||||||
|
|
||||||
def detect_credit_card(content: str) -> list[PIIMatch]:
    """Find 16-digit card numbers in ``content`` that pass Luhn validation.

    Recognised layouts are four groups of four digits, optionally separated
    by a single whitespace character or dash:
        - 1234567890123456
        - 1234 5678 9012 3456
        - 1234-5678-9012-3456

    Args:
        content: Text content to scan.

    Returns:
        Matches for every candidate accepted by ``_luhn_checksum``, in order
        of appearance.
    """
    card_pattern = r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b"
    return [
        PIIMatch(
            type="credit_card",
            value=hit.group(),
            start=hit.start(),
            end=hit.end(),
        )
        for hit in re.finditer(card_pattern, content)
        # Discard digit runs that merely look like a card number.
        if _luhn_checksum(hit.group())
    ]
|
|
||||||
|
|
||||||
|
|
||||||
def detect_ip(content: str) -> list[PIIMatch]:
    """Detect IP addresses in content using stdlib validation.

    Candidates are located with a dotted-quad (IPv4-shaped) regular expression
    and then validated with ``ipaddress.ip_address``; candidates the stdlib
    rejects (for example octets above 255) are dropped. No IPv6-shaped
    pattern is searched for.

    Args:
        content: Text content to scan.

    Returns:
        List of detected IP address matches, in order of appearance.
    """
    dotted_quad = r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b"
    found = []

    for hit in re.finditer(dotted_quad, content):
        candidate = hit.group()
        try:
            ipaddress.ip_address(candidate)
        except ValueError:
            # Shaped like an address but not a valid one.
            continue
        found.append(
            PIIMatch(
                type="ip",
                value=candidate,
                start=hit.start(),
                end=hit.end(),
            )
        )

    return found
|
|
||||||
|
|
||||||
|
|
||||||
def detect_mac_address(content: str) -> list[PIIMatch]:
    """Detect MAC addresses in content.

    A match is six two-character hexadecimal groups, each of the first five
    followed by ``:`` or ``-``, for example:
        - 00:1A:2B:3C:4D:5E
        - 00-1A-2B-3C-4D-5E

    Args:
        content: Text content to scan.

    Returns:
        List of detected MAC address matches, in order of appearance.
    """
    mac_pattern = r"\b([0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b"
    found = []
    for hit in re.finditer(mac_pattern, content):
        found.append(
            PIIMatch(
                type="mac_address",
                value=hit.group(),
                start=hit.start(),
                end=hit.end(),
            )
        )
    return found
|
|
||||||
|
|
||||||
|
|
||||||
def detect_url(content: str) -> list[PIIMatch]:
    """Detect URLs in content using regex and stdlib validation.

    Two passes are made over the text:

    1. URLs with an explicit ``http://`` or ``https://`` scheme.
    2. Scheme-less hosts, accepted only when they start with ``www.`` or
       contain a ``/`` and do not overlap a match recorded earlier.

    Examples of detected forms:
        - http://example.com
        - https://example.com/path
        - www.example.com
        - example.com/path

    Args:
        content: Text content to scan.

    Returns:
        List of detected URL matches: all scheme matches first, then the
        scheme-less ones, each group in order of appearance.
    """
    found = []

    def overlaps_recorded(begin: int, finish: int) -> bool:
        # True when either end of [begin, finish) falls inside a prior match.
        for prior in found:
            if prior["start"] <= begin < prior["end"]:
                return True
            if prior["start"] < finish <= prior["end"]:
                return True
        return False

    # Pass 1: URLs with scheme (http:// or https://)
    scheme_pattern = r"https?://[^\s<>\"{}|\\^`\[\]]+"
    for hit in re.finditer(scheme_pattern, content):
        candidate = hit.group()
        try:
            parsed = urlparse(candidate)
            if parsed.scheme in ("http", "https") and parsed.netloc:
                found.append(
                    PIIMatch(
                        type="url",
                        value=candidate,
                        start=hit.start(),
                        end=hit.end(),
                    )
                )
        except Exception:  # noqa: BLE001
            # Invalid URL, skip
            continue

    # Pass 2: URLs without scheme (www.example.com or example.com/path).
    # Kept conservative to avoid false positives.
    bare_pattern = r"\b(?:www\.)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+(?:/[^\s]*)?"  # noqa: E501
    for hit in re.finditer(bare_pattern, content):
        if overlaps_recorded(hit.start(), hit.end()):
            continue

        candidate = hit.group()
        # A bare "example.com" in prose is not treated as a URL; require a
        # path or a leading "www.".
        if "/" not in candidate and not candidate.startswith("www."):
            continue

        try:
            # urlparse only fills in netloc when a scheme is present.
            parsed = urlparse(f"http://{candidate}")
            if parsed.netloc and "." in parsed.netloc:
                found.append(
                    PIIMatch(
                        type="url",
                        value=candidate,
                        start=hit.start(),
                        end=hit.end(),
                    )
                )
        except Exception:  # noqa: BLE001
            # Invalid URL, skip
            continue

    return found
|
|
||||||
|
|
||||||
|
|
||||||
# Built-in detector registry.
# Maps a PII type name to the module-level detector function for that type.
# Every detector takes the text to scan and returns a list of ``PIIMatch``.
_BUILTIN_DETECTORS: dict[str, Callable[[str], list[PIIMatch]]] = {
    "email": detect_email,
    "credit_card": detect_credit_card,
    "ip": detect_ip,
    "mac_address": detect_mac_address,
    "url": detect_url,
}
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Strategy Implementations
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_redact_strategy(content: str, matches: list[PIIMatch]) -> str:
    """Swap every match for a ``[REDACTED_<TYPE>]`` placeholder.

    Args:
        content: Original content.
        matches: PII matches to redact; ``start``/``end`` index into ``content``.

    Returns:
        ``content`` with each matched span replaced by its placeholder, where
        ``<TYPE>`` is the upper-cased match type.
    """
    if not matches:
        return content

    redacted = content
    # Replace from the end of the text backwards so that earlier offsets stay
    # valid while later spans change length.
    for hit in sorted(matches, key=lambda m: m["start"], reverse=True):
        placeholder = f"[REDACTED_{hit['type'].upper()}]"
        redacted = redacted[: hit["start"]] + placeholder + redacted[hit["end"] :]
    return redacted
|
|
||||||
|
|
||||||
|
|
||||||
def _mask_pii_value(pii_type: str, value: str) -> str:
    """Return the masked form of one matched value for the given PII type."""
    if pii_type == "email":
        # Keep the local part and the last domain label: user@****.com
        pieces = value.split("@")
        if len(pieces) != 2:
            return "****"
        local_part, domain = pieces
        labels = domain.split(".")
        if len(labels) >= 2:
            return f"{local_part}@****.{labels[-1]}"
        return f"{local_part}@****"

    if pii_type == "credit_card":
        # Keep the last four digits and the original separator style.
        last_four = "".join(ch for ch in value if ch.isdigit())[-4:]
        if "-" in value:
            sep = "-"
        elif " " in value:
            sep = " "
        else:
            sep = ""
        if sep:
            return f"****{sep}****{sep}****{sep}{last_four}"
        return f"************{last_four}"

    if pii_type == "ip":
        # Keep the last octet: *.*.*.123
        octets = value.split(".")
        return f"*.*.*.{octets[-1]}" if len(octets) == 4 else "****"

    if pii_type == "mac_address":
        # Keep the last byte: **:**:**:**:**:5E
        sep = ":" if ":" in value else "-"
        return sep.join(["**"] * 5 + [value[-2:]])

    if pii_type == "url":
        return "[MASKED_URL]"

    # Any other type: keep the last four characters when there are more than four.
    return f"****{value[-4:]}" if len(value) > 4 else "****"


def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
    """Partially mask each match, keeping only a type-specific fragment visible.

    Args:
        content: Original content.
        matches: PII matches to mask; ``start``/``end`` index into ``content``.

    Returns:
        ``content`` with every matched span replaced by its masked form.
    """
    if not matches:
        return content

    masked_text = content
    # Work backwards through the text so earlier offsets remain valid.
    for hit in sorted(matches, key=lambda m: m["start"], reverse=True):
        replacement = _mask_pii_value(hit["type"], hit["value"])
        masked_text = masked_text[: hit["start"]] + replacement + masked_text[hit["end"] :]
    return masked_text
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_hash_strategy(content: str, matches: list[PIIMatch]) -> str:
    """Replace every match with a short, deterministic hash token.

    Args:
        content: Original content.
        matches: PII matches to hash; ``start``/``end`` index into ``content``.

    Returns:
        ``content`` with each matched span replaced by ``<type_hash:digest>``,
        where ``digest`` is the first eight hex characters of the SHA-256 of
        the matched value.
    """
    if not matches:
        return content

    hashed_text = content
    # Work backwards through the text so earlier offsets remain valid.
    for hit in sorted(matches, key=lambda m: m["start"], reverse=True):
        short_digest = hashlib.sha256(hit["value"].encode()).hexdigest()[:8]
        token = f"<{hit['type']}_hash:{short_digest}>"
        hashed_text = hashed_text[: hit["start"]] + token + hashed_text[hit["end"] :]
    return hashed_text
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# PIIMiddleware
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class PIIMiddleware(AgentMiddleware):
|
class PIIMiddleware(AgentMiddleware):
|
||||||
"""Detect and handle Personally Identifiable Information (PII) in agent conversations.
|
"""Detect and handle Personally Identifiable Information (PII) in agent conversations.
|
||||||
|
|
||||||
@@ -510,50 +127,34 @@ class PIIMiddleware(AgentMiddleware):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.pii_type = pii_type
|
|
||||||
self.strategy = strategy
|
|
||||||
self.apply_to_input = apply_to_input
|
self.apply_to_input = apply_to_input
|
||||||
self.apply_to_output = apply_to_output
|
self.apply_to_output = apply_to_output
|
||||||
self.apply_to_tool_results = apply_to_tool_results
|
self.apply_to_tool_results = apply_to_tool_results
|
||||||
|
|
||||||
# Resolve detector
|
self._resolved_rule: ResolvedRedactionRule = RedactionRule(
|
||||||
if detector is None:
|
pii_type=pii_type,
|
||||||
# Use built-in detector
|
strategy=strategy,
|
||||||
if pii_type not in _BUILTIN_DETECTORS:
|
detector=detector,
|
||||||
msg = (
|
).resolve()
|
||||||
f"Unknown PII type: {pii_type}. "
|
self.pii_type = self._resolved_rule.pii_type
|
||||||
f"Must be one of {list(_BUILTIN_DETECTORS.keys())} "
|
self.strategy = self._resolved_rule.strategy
|
||||||
"or provide a custom detector."
|
self.detector = self._resolved_rule.detector
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
self.detector = _BUILTIN_DETECTORS[pii_type]
|
|
||||||
elif isinstance(detector, str):
|
|
||||||
# Custom regex pattern
|
|
||||||
pattern = detector
|
|
||||||
|
|
||||||
def regex_detector(content: str) -> list[PIIMatch]:
|
|
||||||
return [
|
|
||||||
PIIMatch(
|
|
||||||
type=pii_type,
|
|
||||||
value=match.group(),
|
|
||||||
start=match.start(),
|
|
||||||
end=match.end(),
|
|
||||||
)
|
|
||||||
for match in re.finditer(pattern, content)
|
|
||||||
]
|
|
||||||
|
|
||||||
self.detector = regex_detector
|
|
||||||
else:
|
|
||||||
# Custom callable detector
|
|
||||||
self.detector = detector
|
|
||||||
|
|
||||||
@property
def name(self) -> str:
    """Middleware name: the class name followed by the PII type in brackets."""
    class_name = self.__class__.__name__
    return f"{class_name}[{self.pii_type}]"
|
||||||
|
|
||||||
|
def _process_content(self, content: str) -> tuple[str, list[PIIMatch]]:
    """Detect PII in ``content`` and apply the configured strategy to it.

    Args:
        content: Text to scan.

    Returns:
        A ``(text, matches)`` pair. When ``self.detector`` reports nothing
        the input is returned unchanged with an empty list; otherwise
        ``text`` is the result of ``apply_strategy`` called with
        ``self.strategy``.
    """
    found = self.detector(content)
    if found:
        return apply_strategy(content, found, self.strategy), found
    return content, []
|
||||||
|
|
||||||
@hook_config(can_jump_to=["end"])
|
@hook_config(can_jump_to=["end"])
|
||||||
def before_model( # noqa: PLR0915
|
def before_model(
|
||||||
self,
|
self,
|
||||||
state: AgentState,
|
state: AgentState,
|
||||||
runtime: Runtime, # noqa: ARG002
|
runtime: Runtime, # noqa: ARG002
|
||||||
@@ -594,25 +195,9 @@ class PIIMiddleware(AgentMiddleware):
|
|||||||
if last_user_idx is not None and last_user_msg and last_user_msg.content:
|
if last_user_idx is not None and last_user_msg and last_user_msg.content:
|
||||||
# Detect PII in message content
|
# Detect PII in message content
|
||||||
content = str(last_user_msg.content)
|
content = str(last_user_msg.content)
|
||||||
matches = self.detector(content)
|
new_content, matches = self._process_content(content)
|
||||||
|
|
||||||
if matches:
|
if matches:
|
||||||
# Apply strategy
|
|
||||||
if self.strategy == "block":
|
|
||||||
raise PIIDetectionError(self.pii_type, matches)
|
|
||||||
|
|
||||||
if self.strategy == "redact":
|
|
||||||
new_content = _apply_redact_strategy(content, matches)
|
|
||||||
elif self.strategy == "mask":
|
|
||||||
new_content = _apply_mask_strategy(content, matches)
|
|
||||||
elif self.strategy == "hash":
|
|
||||||
new_content = _apply_hash_strategy(content, matches)
|
|
||||||
else:
|
|
||||||
# Should not reach here due to type hints
|
|
||||||
msg = f"Unknown strategy: {self.strategy}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
# Create updated message
|
|
||||||
updated_message: AnyMessage = HumanMessage(
|
updated_message: AnyMessage = HumanMessage(
|
||||||
content=new_content,
|
content=new_content,
|
||||||
id=last_user_msg.id,
|
id=last_user_msg.id,
|
||||||
@@ -641,26 +226,11 @@ class PIIMiddleware(AgentMiddleware):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
content = str(tool_msg.content)
|
content = str(tool_msg.content)
|
||||||
matches = self.detector(content)
|
new_content, matches = self._process_content(content)
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Apply strategy
|
|
||||||
if self.strategy == "block":
|
|
||||||
raise PIIDetectionError(self.pii_type, matches)
|
|
||||||
|
|
||||||
if self.strategy == "redact":
|
|
||||||
new_content = _apply_redact_strategy(content, matches)
|
|
||||||
elif self.strategy == "mask":
|
|
||||||
new_content = _apply_mask_strategy(content, matches)
|
|
||||||
elif self.strategy == "hash":
|
|
||||||
new_content = _apply_hash_strategy(content, matches)
|
|
||||||
else:
|
|
||||||
# Should not reach here due to type hints
|
|
||||||
msg = f"Unknown strategy: {self.strategy}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
# Create updated tool message
|
# Create updated tool message
|
||||||
updated_message = ToolMessage(
|
updated_message = ToolMessage(
|
||||||
content=new_content,
|
content=new_content,
|
||||||
@@ -716,26 +286,11 @@ class PIIMiddleware(AgentMiddleware):
|
|||||||
|
|
||||||
# Detect PII in message content
|
# Detect PII in message content
|
||||||
content = str(last_ai_msg.content)
|
content = str(last_ai_msg.content)
|
||||||
matches = self.detector(content)
|
new_content, matches = self._process_content(content)
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Apply strategy
|
|
||||||
if self.strategy == "block":
|
|
||||||
raise PIIDetectionError(self.pii_type, matches)
|
|
||||||
|
|
||||||
if self.strategy == "redact":
|
|
||||||
new_content = _apply_redact_strategy(content, matches)
|
|
||||||
elif self.strategy == "mask":
|
|
||||||
new_content = _apply_mask_strategy(content, matches)
|
|
||||||
elif self.strategy == "hash":
|
|
||||||
new_content = _apply_hash_strategy(content, matches)
|
|
||||||
else:
|
|
||||||
# Should not reach here due to type hints
|
|
||||||
msg = f"Unknown strategy: {self.strategy}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
# Create updated message
|
# Create updated message
|
||||||
updated_message = AIMessage(
|
updated_message = AIMessage(
|
||||||
content=new_content,
|
content=new_content,
|
||||||
@@ -749,3 +304,14 @@ class PIIMiddleware(AgentMiddleware):
|
|||||||
new_messages[last_ai_idx] = updated_message
|
new_messages[last_ai_idx] = updated_message
|
||||||
|
|
||||||
return {"messages": new_messages}
|
return {"messages": new_messages}
|
||||||
|
|
||||||
|
|
||||||
|
# Names exported by this module via ``from ... import *``.
__all__ = [
    "PIIDetectionError",
    "PIIMiddleware",
    "detect_credit_card",
    "detect_email",
    "detect_ip",
    "detect_mac_address",
    "detect_url",
]
|
||||||
|
|||||||
715
libs/langchain_v1/langchain/agents/middleware/shell_tool.py
Normal file
715
libs/langchain_v1/langchain/agents/middleware/shell_tool.py
Normal file
@@ -0,0 +1,715 @@
|
|||||||
|
"""Middleware that exposes a persistent shell tool to agents."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import typing
|
||||||
|
import uuid
|
||||||
|
import weakref
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||||
|
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langchain_core.tools.base import BaseTool, ToolException
|
||||||
|
from langgraph.channels.untracked_value import UntrackedValue
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
|
from langchain.agents.middleware._execution import (
|
||||||
|
SHELL_TEMP_PREFIX,
|
||||||
|
BaseExecutionPolicy,
|
||||||
|
CodexSandboxExecutionPolicy,
|
||||||
|
DockerExecutionPolicy,
|
||||||
|
HostExecutionPolicy,
|
||||||
|
)
|
||||||
|
from langchain.agents.middleware._redaction import (
|
||||||
|
PIIDetectionError,
|
||||||
|
PIIMatch,
|
||||||
|
RedactionRule,
|
||||||
|
ResolvedRedactionRule,
|
||||||
|
)
|
||||||
|
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
from langchain.tools.tool_node import ToolCallRequest
|
||||||
|
|
||||||
|
# Module-level logger.
LOGGER = logging.getLogger(__name__)
# Prefix of the per-command sentinel that the shell is asked to print after
# each command; a random suffix is appended for every execution.
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"

# Default description attached to the shell tool.
DEFAULT_TOOL_DESCRIPTION = (
    "Execute a shell command inside a persistent session. Before running a command, "
    "confirm the working directory is correct (e.g., inspect with `ls` or `pwd`) and ensure "
    "any parent directories exist. Prefer absolute paths and quote paths containing spaces, "
    'such as `cd "/path/with spaces"`. Chain multiple commands with `&&` or `;` instead of '
    "embedding newlines. Avoid unnecessary `cd` usage unless explicitly required so the "
    "session remains stable. Outputs may be truncated when they become very large, and long "
    "running commands will be terminated once their configured timeout elapses."
)
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_resources(
    session: ShellSession, tempdir: tempfile.TemporaryDirectory[str] | None, timeout: float
) -> None:
    """Best-effort teardown of a shell session and its temporary directory.

    Each step runs independently and any ``Exception`` it raises is swallowed,
    so a failure to stop the session does not prevent the directory cleanup.

    Args:
        session: Session to stop.
        tempdir: Temporary directory to remove, or ``None`` when there is none.
        timeout: Value passed to ``session.stop``.
    """
    teardown_steps = [lambda: session.stop(timeout)]
    if tempdir is not None:
        teardown_steps.append(tempdir.cleanup)

    for step in teardown_steps:
        with contextlib.suppress(Exception):
            step()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _SessionResources:
    """Bundle of the shell resources that belong to one run.

    On construction a ``weakref.finalize`` callback is registered that calls
    ``_cleanup_resources`` with the session, the temporary directory and the
    policy's ``termination_timeout``.
    """

    session: ShellSession
    tempdir: tempfile.TemporaryDirectory[str] | None
    policy: BaseExecutionPolicy
    _finalizer: weakref.finalize = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Register the finalizer that releases the bundled resources."""
        # The callback receives the resources themselves, not ``self``, so it
        # holds no reference to this container.
        cleanup_args = (self.session, self.tempdir, self.policy.termination_timeout)
        self._finalizer = weakref.finalize(self, _cleanup_resources, *cleanup_args)
|
||||||
|
|
||||||
|
|
||||||
|
class ShellToolState(AgentState):
    """Agent state extension for tracking shell session resources."""

    # Optional slot holding the live ``_SessionResources`` for the run, or
    # ``None``. NOTE(review): the ``UntrackedValue`` / ``PrivateStateAttr``
    # annotations presumably keep this value out of checkpoints and out of the
    # public state schema - confirm against their definitions.
    shell_session_resources: NotRequired[
        Annotated[_SessionResources | None, UntrackedValue, PrivateStateAttr]
    ]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class CommandExecutionResult:
    """Structured result from command execution."""

    # Collected command output; stderr lines carry a "[stderr] " prefix.
    # Empty when the command timed out.
    output: str
    # Exit status parsed from the completion marker line; None on timeout or
    # when the status could not be parsed as an integer.
    exit_code: int | None
    # True when the deadline passed before the completion marker was seen.
    timed_out: bool
    # True when lines were dropped because the policy's line limit was exceeded.
    truncated_by_lines: bool
    # True when lines were dropped because the policy's byte limit was exceeded.
    truncated_by_bytes: bool
    # Number of output lines seen, including lines dropped by truncation.
    total_lines: int
    # UTF-8 size of all output lines seen, including lines dropped by truncation.
    total_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
class ShellSession:
    """Persistent shell session that supports sequential command execution.

    One long-lived shell process is spawned through the execution policy.
    Commands are written to its stdin, while stdout and stderr are read line by
    line on daemon threads and funnelled into a single queue. The end of a
    command is detected through a unique marker that the shell is asked to
    print, together with ``$?``, right after the command.
    """

    def __init__(
        self,
        workspace: Path,
        policy: BaseExecutionPolicy,
        command: tuple[str, ...],
        environment: Mapping[str, str],
    ) -> None:
        """Store the session configuration; no process is started yet.

        Args:
            workspace: Passed to ``policy.spawn`` as ``workspace``.
            policy: Execution policy used to spawn the shell and to read the
                output limits and the termination timeout.
            command: Passed to ``policy.spawn`` as ``command``.
            environment: Passed to ``policy.spawn`` as ``env``; copied into a
                new dict.
        """
        self._workspace = workspace
        self._policy = policy
        self._command = command
        self._environment = dict(environment)
        self._process: subprocess.Popen[str] | None = None
        self._stdin: Any = None
        self._queue: queue.Queue[tuple[str, str | None]] = queue.Queue()
        self._lock = threading.Lock()
        self._stdout_thread: threading.Thread | None = None
        self._stderr_thread: threading.Thread | None = None
        self._terminated = False

    def start(self) -> None:
        """Start the shell subprocess and reader threads.

        Does nothing when a shell process is already running.

        Raises:
            RuntimeError: If the spawned process lacks a stdin, stdout or
                stderr pipe.
        """
        if self._process and self._process.poll() is None:
            return

        self._process = self._policy.spawn(
            workspace=self._workspace,
            env=self._environment,
            command=self._command,
        )
        if (
            self._process.stdin is None
            or self._process.stdout is None
            or self._process.stderr is None
        ):
            msg = "Failed to initialize shell session pipes."
            raise RuntimeError(msg)

        self._stdin = self._process.stdin
        self._terminated = False
        # Fresh queue per process so output of a previous shell is not replayed.
        self._queue = queue.Queue()

        self._stdout_thread = threading.Thread(
            target=self._enqueue_stream,
            args=(self._process.stdout, "stdout"),
            daemon=True,
        )
        self._stderr_thread = threading.Thread(
            target=self._enqueue_stream,
            args=(self._process.stderr, "stderr"),
            daemon=True,
        )
        self._stdout_thread.start()
        self._stderr_thread.start()

    def restart(self) -> None:
        """Restart the shell process."""
        self.stop(self._policy.termination_timeout)
        self.start()

    def stop(self, timeout: float) -> None:
        """Stop the shell subprocess.

        The shell is first asked to ``exit``; if it has not terminated within
        ``timeout`` seconds it is killed.

        Args:
            timeout: Seconds to wait for the process to exit on its own.
        """
        if not self._process:
            return

        if self._process.poll() is None and not self._terminated:
            try:
                self._stdin.write("exit\n")
                self._stdin.flush()
            except (BrokenPipeError, OSError):
                LOGGER.debug(
                    "Failed to write exit command; terminating shell session.",
                    exc_info=True,
                )

        try:
            # ``wait`` either returns the exit status or raises on timeout.
            self._process.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            self._kill_process()
        finally:
            self._terminated = True
            with contextlib.suppress(Exception):
                self._stdin.close()
            self._process = None

    def execute(self, command: str, *, timeout: float) -> CommandExecutionResult:
        """Execute a command in the persistent shell.

        Args:
            command: Command text; a trailing newline is appended when missing.
            timeout: Seconds allowed for the command, counted from this call.

        Returns:
            The collected output, exit status and truncation information.

        Raises:
            RuntimeError: If the shell process is not running.
        """
        if not self._process or self._process.poll() is not None:
            msg = "Shell session is not running."
            raise RuntimeError(msg)

        marker = f"{_DONE_MARKER_PREFIX}{uuid.uuid4().hex}"
        deadline = time.monotonic() + timeout

        with self._lock:
            # Discard leftovers from earlier commands before sending a new one.
            self._drain_queue()
            payload = command if command.endswith("\n") else f"{command}\n"
            self._stdin.write(payload)
            # Ask the shell to report completion and the command's exit status.
            self._stdin.write(f"printf '{marker} %s\\n' $?\n")
            self._stdin.flush()

        return self._collect_output(marker, deadline, timeout)

    def _collect_output(
        self,
        marker: str,
        deadline: float,
        timeout: float,
    ) -> CommandExecutionResult:
        """Read queued output until the completion marker or the deadline.

        Args:
            marker: Completion marker written after the command.
            deadline: ``time.monotonic()`` value after which the command is
                considered timed out.
            timeout: The original timeout, used only for the log message.

        Returns:
            The execution result. On timeout the shell is restarted and the
            result has empty output and no exit code.
        """
        collected: list[str] = []
        total_lines = 0
        total_bytes = 0
        truncated_by_lines = False
        truncated_by_bytes = False
        exit_code: int | None = None
        timed_out = False

        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                timed_out = True
                break
            try:
                source, data = self._queue.get(timeout=remaining)
            except queue.Empty:
                timed_out = True
                break

            if data is None:
                # End-of-stream sentinel from a reader thread.
                continue

            finished = False
            if source == "stdout" and marker in data:
                # The marker is not necessarily at the start of the line: when
                # the command's last output has no trailing newline, the
                # marker is appended to that same line. Whatever precedes it
                # is genuine command output.
                data, _, status = data.partition(marker)
                exit_code = self._safe_int(status.strip())
                finished = True
                if not data:
                    break

            total_lines += 1
            total_bytes += len(data.encode("utf-8", "replace"))

            if total_lines > self._policy.max_output_lines:
                truncated_by_lines = True
            elif (
                self._policy.max_output_bytes is not None
                and total_bytes > self._policy.max_output_bytes
            ):
                truncated_by_bytes = True
            elif source == "stderr":
                stripped = data.rstrip("\n")
                collected.append(f"[stderr] {stripped}")
                if data.endswith("\n"):
                    collected.append("\n")
            else:
                collected.append(data)

            if finished:
                break

        if timed_out:
            LOGGER.warning(
                "Command timed out after %.2f seconds; restarting shell session.",
                timeout,
            )
            self.restart()
            return CommandExecutionResult(
                output="",
                exit_code=None,
                timed_out=True,
                truncated_by_lines=truncated_by_lines,
                truncated_by_bytes=truncated_by_bytes,
                total_lines=total_lines,
                total_bytes=total_bytes,
            )

        output = "".join(collected)
        return CommandExecutionResult(
            output=output,
            exit_code=exit_code,
            timed_out=False,
            truncated_by_lines=truncated_by_lines,
            truncated_by_bytes=truncated_by_bytes,
            total_lines=total_lines,
            total_bytes=total_bytes,
        )

    def _kill_process(self) -> None:
        """Forcefully kill the shell process and reap it."""
        if not self._process:
            return

        if hasattr(os, "killpg"):
            with contextlib.suppress(ProcessLookupError):
                os.killpg(os.getpgid(self._process.pid), signal.SIGKILL)
        else:  # pragma: no cover
            with contextlib.suppress(ProcessLookupError):
                self._process.kill()

        # Collect the exit status so the killed child does not linger as a
        # zombie; bounded in case the process cannot be killed promptly.
        with contextlib.suppress(subprocess.TimeoutExpired):
            self._process.wait(timeout=self._policy.termination_timeout)

    def _enqueue_stream(self, stream: Any, label: str) -> None:
        """Forward lines of ``stream`` to the queue, then an end sentinel.

        Args:
            stream: Text stream to read with ``readline`` until it returns "".
            label: Source tag stored with every line ("stdout" or "stderr").
        """
        # Bind the queue once: ``start`` installs a new queue for every new
        # process, and a reader of an old process must not write into it.
        sink = self._queue
        for line in iter(stream.readline, ""):
            sink.put((label, line))
        sink.put((label, None))

    def _drain_queue(self) -> None:
        """Discard everything currently waiting in the output queue."""
        while True:
            try:
                self._queue.get_nowait()
            except queue.Empty:
                break

    @staticmethod
    def _safe_int(value: str) -> int | None:
        """Return ``int(value)``, or ``None`` when ``value`` is not an integer."""
        with contextlib.suppress(ValueError):
            return int(value)
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class _ShellToolInput(BaseModel):
    """Arguments accepted by the persistent shell tool.

    A payload must request either a ``command`` to run or a ``restart`` of the
    session, and not both.
    """

    command: str | None = None
    restart: bool | None = None

    @model_validator(mode="after")
    def validate_payload(self) -> _ShellToolInput:
        """Reject payloads that request no action, or both actions at once."""
        has_command = self.command is not None
        wants_restart = bool(self.restart)
        if not has_command and not wants_restart:
            msg = "Shell tool requires either 'command' or 'restart'."
            raise ValueError(msg)
        if has_command and wants_restart:
            msg = "Specify only one of 'command' or 'restart'."
            raise ValueError(msg)
        return self
|
||||||
|
|
||||||
|
|
||||||
|
class _PersistentShellTool(BaseTool):
    """Placeholder tool whose calls are expected to be handled by the middleware.

    Invoking ``_run`` directly raises ``RuntimeError``.
    """

    name: str = "shell"
    description: str = DEFAULT_TOOL_DESCRIPTION
    args_schema: type[BaseModel] = _ShellToolInput

    def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
        """Create the tool and remember the owning middleware.

        Args:
            middleware: The middleware instance stored on the tool.
            description: Replacement for the default tool description; ``None``
                keeps the class default.
        """
        super().__init__()
        self._middleware = middleware
        if description is None:
            return
        self.description = description

    def _run(self, **_: Any) -> Any:  # pragma: no cover - executed via middleware wrapper
        """Always raise: execution is meant to be intercepted before reaching the tool."""
        message = "Persistent shell tool execution should be intercepted via middleware wrappers."
        raise RuntimeError(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||||
|
"""Middleware that registers a persistent shell tool for agents.
|
||||||
|
|
||||||
|
The middleware exposes a single long-lived shell session. Use the execution policy to
|
||||||
|
match your deployment's security posture:
|
||||||
|
|
||||||
|
* ``HostExecutionPolicy`` - full host access; best for trusted environments where the
|
||||||
|
agent already runs inside a container or VM that provides isolation.
|
||||||
|
* ``CodexSandboxExecutionPolicy`` - reuses the Codex CLI sandbox for additional
|
||||||
|
syscall/filesystem restrictions when the CLI is available.
|
||||||
|
* ``DockerExecutionPolicy`` - launches a separate Docker container for each agent run,
|
||||||
|
providing harder isolation, optional read-only root filesystems, and user remapping.
|
||||||
|
|
||||||
|
When no policy is provided the middleware defaults to ``HostExecutionPolicy``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    workspace_root: str | Path | None = None,
    *,
    startup_commands: tuple[str, ...] | list[str] | str | None = None,
    shutdown_commands: tuple[str, ...] | list[str] | str | None = None,
    execution_policy: BaseExecutionPolicy | None = None,
    redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
    tool_description: str | None = None,
    shell_command: Sequence[str] | str | None = None,
    env: Mapping[str, Any] | None = None,
) -> None:
    """Configure the middleware and register its shell tool.

    Args:
        workspace_root: Base directory for the shell session. When falsy, no
            workspace is stored and a temporary directory is created when the
            session resources are built.
        startup_commands: A single command string or a sequence of commands that
            are executed, in order, after the session starts.
        shutdown_commands: A single command string or a sequence of commands that
            are executed before the session is released.
        execution_policy: Policy supplying timeouts and output limits. Defaults
            to a new ``HostExecutionPolicy``.
        redaction_rules: Rules resolved via ``rule.resolve()`` and applied to
            command output before it is returned.
        tool_description: Overrides the registered tool description; a falsy
            value selects ``DEFAULT_TOOL_DESCRIPTION``.
        shell_command: Shell executable (string) or argument sequence used to
            launch the session. Defaults to ``/bin/bash``.
        env: Environment variables for the session. Keys must be strings and
            values are converted with ``str``.
    """
    super().__init__()
    self._workspace_root = None if not workspace_root else Path(workspace_root)
    self._shell_command = self._normalize_shell_command(shell_command)
    self._environment = self._normalize_env(env)
    self._execution_policy = (
        HostExecutionPolicy() if execution_policy is None else execution_policy
    )
    self._redaction_rules: tuple[ResolvedRedactionRule, ...] = tuple(
        rule.resolve() for rule in (redaction_rules or ())
    )
    self._startup_commands = self._normalize_commands(startup_commands)
    self._shutdown_commands = self._normalize_commands(shutdown_commands)

    self._tool = _PersistentShellTool(
        self,
        description=tool_description or DEFAULT_TOOL_DESCRIPTION,
    )
    self.tools = [self._tool]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_commands(
|
||||||
|
commands: tuple[str, ...] | list[str] | str | None,
|
||||||
|
) -> tuple[str, ...]:
|
||||||
|
if commands is None:
|
||||||
|
return ()
|
||||||
|
if isinstance(commands, str):
|
||||||
|
return (commands,)
|
||||||
|
return tuple(commands)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_shell_command(
|
||||||
|
shell_command: Sequence[str] | str | None,
|
||||||
|
) -> tuple[str, ...]:
|
||||||
|
if shell_command is None:
|
||||||
|
return ("/bin/bash",)
|
||||||
|
normalized = (shell_command,) if isinstance(shell_command, str) else tuple(shell_command)
|
||||||
|
if not normalized:
|
||||||
|
msg = "Shell command must contain at least one argument."
|
||||||
|
raise ValueError(msg)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_env(env: Mapping[str, Any] | None) -> dict[str, str] | None:
|
||||||
|
if env is None:
|
||||||
|
return None
|
||||||
|
normalized: dict[str, str] = {}
|
||||||
|
for key, value in env.items():
|
||||||
|
if not isinstance(key, str):
|
||||||
|
msg = "Environment variable names must be strings."
|
||||||
|
raise TypeError(msg)
|
||||||
|
normalized[key] = str(value)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
def before_agent(self, _state: ShellToolState, _runtime: Any) -> dict[str, Any] | None:
    """Create the session resources and return them as a state update.

    The resources built by ``_create_resources`` are published under the
    ``shell_session_resources`` key.
    """
    return {"shell_session_resources": self._create_resources()}
|
||||||
|
|
||||||
|
async def abefore_agent(self, state: ShellToolState, _runtime: Any) -> dict[str, Any] | None:
    """Async counterpart to `before_agent`.

    Delegates directly to the synchronous implementation; nothing is awaited.
    """
    update = self.before_agent(state, _runtime)
    return update
|
||||||
|
|
||||||
|
def after_agent(self, state: ShellToolState, _runtime: Any) -> None:
    """Run the shutdown commands, then finalize the session resources.

    The resources' ``_finalizer`` is invoked even when running the shutdown
    commands raises.
    """
    session_resources = self._ensure_resources(state)
    try:
        self._run_shutdown_commands(session_resources.session)
    finally:
        session_resources._finalizer()
|
||||||
|
|
||||||
|
async def aafter_agent(self, state: ShellToolState, _runtime: Any) -> None:
    """Async counterpart to `after_agent`.

    Delegates directly to the synchronous implementation; nothing is awaited.
    """
    outcome = self.after_agent(state, _runtime)
    return outcome
|
||||||
|
|
||||||
|
def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
    """Fetch the session resources stored in ``state``.

    Raises:
        ToolException: If ``shell_session_resources`` is missing from ``state`` or
            is not a ``_SessionResources`` instance.
    """
    candidate = state.get("shell_session_resources")
    # ``None`` is not a ``_SessionResources`` instance, so one check covers both cases.
    if isinstance(candidate, _SessionResources):
        return candidate
    msg = (
        "Shell session resources are unavailable. Ensure `before_agent` ran successfully "
        "before invoking the shell tool."
    )
    raise ToolException(msg)
|
||||||
|
|
||||||
|
def _create_resources(self) -> _SessionResources:
    """Start a shell session in the configured workspace, or in a temporary one.

    Without a configured workspace root, a ``tempfile.TemporaryDirectory`` is created
    and used. Otherwise the configured directory is created if it does not exist.
    The session is then started and the startup commands are run.

    If starting the session or running a startup command fails, the session is
    stopped and the temporary directory (if any) is removed before the original
    error is re-raised. The directory is removed even when stopping the session
    itself raises.

    Returns:
        The resources bundling the session, the optional temporary directory and
        the execution policy.
    """
    tempdir: tempfile.TemporaryDirectory[str] | None = None
    workspace = self._workspace_root
    if workspace is None:
        tempdir = tempfile.TemporaryDirectory(prefix=SHELL_TEMP_PREFIX)
        workspace_path = Path(tempdir.name)
    else:
        workspace_path = workspace
        workspace_path.mkdir(parents=True, exist_ok=True)

    session = ShellSession(
        workspace_path,
        self._execution_policy,
        self._shell_command,
        self._environment or {},
    )
    try:
        session.start()
        LOGGER.info("Started shell session in %s", workspace_path)
        self._run_startup_commands(session)
    except BaseException:
        LOGGER.exception("Starting shell session failed; cleaning up resources.")
        try:
            session.stop(self._execution_policy.termination_timeout)
        finally:
            # Remove the temporary workspace even if stopping the session fails.
            if tempdir is not None:
                tempdir.cleanup()
        raise

    return _SessionResources(session=session, tempdir=tempdir, policy=self._execution_policy)
|
||||||
|
|
||||||
|
def _run_startup_commands(self, session: ShellSession) -> None:
|
||||||
|
if not self._startup_commands:
|
||||||
|
return
|
||||||
|
for command in self._startup_commands:
|
||||||
|
result = session.execute(command, timeout=self._execution_policy.startup_timeout)
|
||||||
|
if result.timed_out or (result.exit_code not in (0, None)):
|
||||||
|
msg = f"Startup command '{command}' failed with exit code {result.exit_code}"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
def _run_shutdown_commands(self, session: ShellSession) -> None:
|
||||||
|
if not self._shutdown_commands:
|
||||||
|
return
|
||||||
|
for command in self._shutdown_commands:
|
||||||
|
try:
|
||||||
|
result = session.execute(command, timeout=self._execution_policy.command_timeout)
|
||||||
|
if result.timed_out:
|
||||||
|
LOGGER.warning("Shutdown command '%s' timed out.", command)
|
||||||
|
elif result.exit_code not in (0, None):
|
||||||
|
LOGGER.warning(
|
||||||
|
"Shutdown command '%s' exited with %s.", command, result.exit_code
|
||||||
|
)
|
||||||
|
except (RuntimeError, ToolException, OSError) as exc:
|
||||||
|
LOGGER.warning(
|
||||||
|
"Failed to run shutdown command '%s': %s", command, exc, exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_redactions(self, content: str) -> tuple[str, dict[str, list[PIIMatch]]]:
|
||||||
|
"""Apply configured redaction rules to command output."""
|
||||||
|
matches_by_type: dict[str, list[PIIMatch]] = {}
|
||||||
|
updated = content
|
||||||
|
for rule in self._redaction_rules:
|
||||||
|
updated, matches = rule.apply(updated)
|
||||||
|
if matches:
|
||||||
|
matches_by_type.setdefault(rule.pii_type, []).extend(matches)
|
||||||
|
return updated, matches_by_type
|
||||||
|
|
||||||
|
def _run_shell_tool(
    self,
    resources: _SessionResources,
    payload: dict[str, Any],
    *,
    tool_call_id: str | None,
) -> Any:
    """Execute a shell tool request against the managed session.

    Args:
        resources: Session resources whose ``session`` runs the request.
        payload: Tool arguments; either a truthy ``restart`` flag or a non-empty
            ``command`` string.
        tool_call_id: Identifier forwarded to ``_format_tool_message``.

    Returns:
        The value produced by ``_format_tool_message`` for the restart
        confirmation, the timeout error, the blocked-output error, or the
        (possibly redacted and truncation-annotated) command output.

    Raises:
        ToolException: If restarting the session fails, or if no valid ``command``
            string is supplied when ``restart`` is not requested.
    """
    session = resources.session

    if payload.get("restart"):
        LOGGER.info("Restarting shell session on request.")
        try:
            session.restart()
            self._run_startup_commands(session)
        except Exception as err:
            # Only ordinary errors become tool errors; KeyboardInterrupt and
            # SystemExit propagate so they are not reported as a failed tool call.
            LOGGER.exception("Restarting shell session failed; session remains unavailable.")
            msg = "Failed to restart shell session."
            raise ToolException(msg) from err
        message = "Shell session restarted."
        return self._format_tool_message(message, tool_call_id, status="success")

    command = payload.get("command")
    if not command or not isinstance(command, str):
        msg = "Shell tool expects a 'command' string when restart is not requested."
        raise ToolException(msg)

    LOGGER.info("Executing shell command: %s", command)
    result = session.execute(command, timeout=self._execution_policy.command_timeout)

    if result.timed_out:
        timeout_seconds = self._execution_policy.command_timeout
        message = f"Error: Command timed out after {timeout_seconds:.1f} seconds."
        return self._format_tool_message(
            message,
            tool_call_id,
            status="error",
            artifact={
                "timed_out": True,
                "exit_code": None,
            },
        )

    try:
        sanitized_output, matches = self._apply_redactions(result.output)
    except PIIDetectionError as error:
        # A rule refused to let the output through: return only the PII type.
        LOGGER.warning("Blocking command output due to detected %s.", error.pii_type)
        message = f"Output blocked: detected {error.pii_type}."
        return self._format_tool_message(
            message,
            tool_call_id,
            status="error",
            artifact={
                "timed_out": False,
                "exit_code": result.exit_code,
                "matches": {error.pii_type: error.matches},
            },
        )

    sanitized_output = sanitized_output or "<no output>"
    if result.truncated_by_lines:
        sanitized_output = (
            f"{sanitized_output.rstrip()}\n\n"
            f"... Output truncated at {self._execution_policy.max_output_lines} lines "
            f"(observed {result.total_lines})."
        )
    if result.truncated_by_bytes and self._execution_policy.max_output_bytes is not None:
        sanitized_output = (
            f"{sanitized_output.rstrip()}\n\n"
            f"... Output truncated at {self._execution_policy.max_output_bytes} bytes "
            f"(observed {result.total_bytes})."
        )

    if result.exit_code not in (0, None):
        sanitized_output = f"{sanitized_output.rstrip()}\n\nExit code: {result.exit_code}"
        final_status: Literal["success", "error"] = "error"
    else:
        final_status = "success"

    artifact = {
        "timed_out": False,
        "exit_code": result.exit_code,
        "truncated_by_lines": result.truncated_by_lines,
        "truncated_by_bytes": result.truncated_by_bytes,
        "total_lines": result.total_lines,
        "total_bytes": result.total_bytes,
        "redaction_matches": matches,
    }

    return self._format_tool_message(
        sanitized_output,
        tool_call_id,
        status=final_status,
        artifact=artifact,
    )
|
||||||
|
|
||||||
|
def wrap_tool_call(
    self,
    request: ToolCallRequest,
    handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
    """Run shell tool calls through the managed session; pass others to ``handler``.

    A request is treated as a shell call when ``request.tool`` is a
    ``_PersistentShellTool``.
    """
    if not isinstance(request.tool, _PersistentShellTool):
        return handler(request)

    session_resources = self._ensure_resources(request.state)
    return self._run_shell_tool(
        session_resources,
        request.tool_call["args"],
        tool_call_id=request.tool_call.get("id"),
    )
|
||||||
|
|
||||||
|
async def awrap_tool_call(
    self,
    request: ToolCallRequest,
    handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
    """Async variant of the tool-call interception.

    Shell tool calls are executed through the same synchronous ``_run_shell_tool``
    path (nothing is awaited for them); every other request is awaited via
    ``handler``.
    """
    if not isinstance(request.tool, _PersistentShellTool):
        return await handler(request)

    session_resources = self._ensure_resources(request.state)
    return self._run_shell_tool(
        session_resources,
        request.tool_call["args"],
        tool_call_id=request.tool_call.get("id"),
    )
|
||||||
|
|
||||||
|
def _format_tool_message(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
tool_call_id: str | None,
|
||||||
|
*,
|
||||||
|
status: Literal["success", "error"],
|
||||||
|
artifact: dict[str, Any] | None = None,
|
||||||
|
) -> ToolMessage | str:
|
||||||
|
artifact = artifact or {}
|
||||||
|
if tool_call_id is None:
|
||||||
|
return content
|
||||||
|
return ToolMessage(
|
||||||
|
content=content,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
name=self._tool.name,
|
||||||
|
status=status,
|
||||||
|
artifact=artifact,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CodexSandboxExecutionPolicy",
|
||||||
|
"DockerExecutionPolicy",
|
||||||
|
"HostExecutionPolicy",
|
||||||
|
"RedactionRule",
|
||||||
|
"ShellToolMiddleware",
|
||||||
|
]
|
||||||
@@ -0,0 +1,404 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.agents.middleware.shell_tool import (
|
||||||
|
HostExecutionPolicy,
|
||||||
|
CodexSandboxExecutionPolicy,
|
||||||
|
DockerExecutionPolicy,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.agents.middleware import _execution
|
||||||
|
|
||||||
|
|
||||||
|
def _make_resource(
|
||||||
|
*,
|
||||||
|
with_prlimit: bool,
|
||||||
|
has_rlimit_as: bool = True,
|
||||||
|
) -> Any:
|
||||||
|
"""Create a fake ``resource`` module for testing."""
|
||||||
|
|
||||||
|
class _BaseResource:
|
||||||
|
RLIMIT_CPU = 0
|
||||||
|
RLIMIT_DATA = 2
|
||||||
|
|
||||||
|
if has_rlimit_as:
|
||||||
|
RLIMIT_AS = 1
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.prlimit_calls: list[tuple[int, int, tuple[int, int]]] = []
|
||||||
|
self.setrlimit_calls: list[tuple[int, tuple[int, int]]] = []
|
||||||
|
|
||||||
|
def setrlimit(self, resource_name: int, limits: tuple[int, int]) -> None:
|
||||||
|
self.setrlimit_calls.append((resource_name, limits))
|
||||||
|
|
||||||
|
if with_prlimit:
|
||||||
|
|
||||||
|
class _Resource(_BaseResource):
|
||||||
|
def prlimit(self, pid: int, resource_name: int, limits: tuple[int, int]) -> None:
|
||||||
|
self.prlimit_calls.append((pid, resource_name, limits))
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
class _Resource(_BaseResource):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return _Resource()
|
||||||
|
|
||||||
|
|
||||||
|
def test_host_policy_validations() -> None:
    """Each invalid limit value is expected to make the constructor raise ``ValueError``."""
    invalid_kwargs = (
        {"max_output_lines": 0},
        {"cpu_time_seconds": 0},
        {"memory_bytes": -1},
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(ValueError):
            HostExecutionPolicy(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_host_policy_requires_resource_for_limits(monkeypatch: pytest.MonkeyPatch) -> None:
    """With ``_execution.resource`` patched to ``None``, a CPU limit should raise."""
    monkeypatch.setattr(_execution, "resource", None, raising=False)
    expectation = pytest.raises(RuntimeError)
    with expectation:
        HostExecutionPolicy(cpu_time_seconds=1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_host_policy_applies_prlimit(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """On a patched Linux platform with ``prlimit``, limits go through ``prlimit`` only.

    The launch is expected to receive no ``preexec_fn`` and a new session, and
    ``setrlimit`` must not be called.
    """
    fake_resource = _make_resource(with_prlimit=True)
    monkeypatch.setattr(_execution, "resource", fake_resource, raising=False)
    monkeypatch.setattr(_execution.sys, "platform", "linux")

    recorded: dict[str, Any] = {}

    class DummyProcess:
        pid = 1234

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        recorded.update(
            command=list(command),
            env=dict(env),
            cwd=cwd,
            preexec_fn=preexec_fn,
            start_new_session=start_new_session,
        )
        return DummyProcess()

    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = HostExecutionPolicy(cpu_time_seconds=2, memory_bytes=4096)
    process = policy.spawn(
        workspace=tmp_path,
        env={"PATH": os.environ.get("PATH", ""), "VAR": "1"},
        command=("/bin/sh",),
    )

    assert process is not None
    assert recorded["preexec_fn"] is None
    assert recorded["start_new_session"] is True

    expected_prlimit_calls = [
        (1234, fake_resource.RLIMIT_CPU, (2, 2)),
        (1234, fake_resource.RLIMIT_AS, (4096, 4096)),
    ]
    assert fake_resource.prlimit_calls == expected_prlimit_calls
    assert fake_resource.setrlimit_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_host_policy_uses_preexec_on_macos(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """On a patched ``darwin`` platform without ``prlimit``, limits use a ``preexec_fn``.

    Calling the captured ``preexec_fn`` is expected to apply both limits via
    ``setrlimit``.
    """
    fake_resource = _make_resource(with_prlimit=False)
    monkeypatch.setattr(_execution, "resource", fake_resource, raising=False)
    monkeypatch.setattr(_execution.sys, "platform", "darwin")

    captured: dict[str, Any] = {}

    class DummyProcess:
        pid = 4321

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        captured.update(preexec_fn=preexec_fn, start_new_session=start_new_session)
        return DummyProcess()

    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = HostExecutionPolicy(cpu_time_seconds=5, memory_bytes=8192)
    policy.spawn(
        workspace=tmp_path,
        env={"PATH": os.environ.get("PATH", "")},
        command=("/bin/sh",),
    )

    apply_limits = captured["preexec_fn"]
    assert callable(apply_limits)
    assert captured["start_new_session"] is True

    apply_limits()
    # macOS fallback should use setrlimit
    expected_setrlimit_calls = [
        (fake_resource.RLIMIT_CPU, (5, 5)),
        (fake_resource.RLIMIT_AS, (8192, 8192)),
    ]
    assert fake_resource.setrlimit_calls == expected_setrlimit_calls
|
||||||
|
|
||||||
|
|
||||||
|
def test_host_policy_respects_process_group_flag(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """``create_process_group=False`` should launch with ``start_new_session=False``."""
    monkeypatch.setattr(
        _execution, "resource", _make_resource(with_prlimit=True), raising=False
    )
    monkeypatch.setattr(_execution.sys, "platform", "linux")

    recorded: dict[str, Any] = {}

    class DummyProcess:
        pid = 1111

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        recorded["start_new_session"] = start_new_session
        return DummyProcess()

    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = HostExecutionPolicy(create_process_group=False)
    policy.spawn(
        workspace=tmp_path,
        env={"PATH": os.environ.get("PATH", "")},
        command=("/bin/sh",),
    )

    assert recorded["start_new_session"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_host_policy_falls_back_to_rlimit_data(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """Without ``RLIMIT_AS`` the memory limit is expected to use ``RLIMIT_DATA``."""
    fake_resource = _make_resource(with_prlimit=True, has_rlimit_as=False)
    monkeypatch.setattr(_execution, "resource", fake_resource, raising=False)
    monkeypatch.setattr(_execution.sys, "platform", "linux")

    class DummyProcess:
        pid = 2222

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        return DummyProcess()

    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = HostExecutionPolicy(cpu_time_seconds=7, memory_bytes=2048)
    policy.spawn(
        workspace=tmp_path,
        env={"PATH": os.environ.get("PATH", "")},
        command=("/bin/sh",),
    )

    expected_prlimit_calls = [
        (2222, fake_resource.RLIMIT_CPU, (7, 7)),
        (2222, fake_resource.RLIMIT_DATA, (2048, 2048)),
    ]
    assert fake_resource.prlimit_calls == expected_prlimit_calls
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    shutil.which("codex") is None,
    reason="codex CLI not available on PATH",
)
def test_codex_policy_spawns_codex_cli(monkeypatch, tmp_path: Path) -> None:
    """The Codex policy should launch ``codex sandbox <platform>`` around the shell.

    Config overrides are expected as ``-c key=value`` pairs before a ``--``
    separator, and the launch must not use a ``preexec_fn`` or a new session.
    """
    recorded: dict[str, list[str]] = {}

    class DummyProcess:
        pass

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        recorded["command"] = list(command)
        assert cwd == tmp_path
        assert env["TEST_VAR"] == "1"
        assert preexec_fn is None
        assert not start_new_session
        return DummyProcess()

    monkeypatch.setattr(
        "langchain.agents.middleware._execution._launch_subprocess",
        fake_launch,
    )
    policy = CodexSandboxExecutionPolicy(
        platform="linux",
        config_overrides={"sandbox_permissions": ["disk-full-read-access"]},
    )

    policy.spawn(workspace=tmp_path, env={"TEST_VAR": "1"}, command=("/bin/bash",))

    codex_binary = shutil.which("codex")
    assert recorded["command"] == [
        codex_binary,
        "sandbox",
        "linux",
        "-c",
        'sandbox_permissions=["disk-full-read-access"]',
        "--",
        "/bin/bash",
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_policy_auto_platform_linux(monkeypatch: pytest.MonkeyPatch) -> None:
    """``platform="auto"`` should resolve to ``"linux"`` when ``sys.platform`` is ``linux``."""
    monkeypatch.setattr(_execution.sys, "platform", "linux")
    detected = CodexSandboxExecutionPolicy(platform="auto")._determine_platform()
    assert detected == "linux"
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_policy_auto_platform_macos(monkeypatch: pytest.MonkeyPatch) -> None:
    """``platform="auto"`` should resolve to ``"macos"`` when ``sys.platform`` is ``darwin``."""
    monkeypatch.setattr(_execution.sys, "platform", "darwin")
    detected = CodexSandboxExecutionPolicy(platform="auto")._determine_platform()
    assert detected == "macos"
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_policy_resolve_missing_binary(monkeypatch: pytest.MonkeyPatch) -> None:
    """Resolving the binary should raise when ``shutil.which`` finds nothing."""

    def which_nothing(_: str) -> None:
        return None

    monkeypatch.setattr(_execution.shutil, "which", which_nothing)
    policy = CodexSandboxExecutionPolicy(binary="codex")
    with pytest.raises(RuntimeError):
        policy._resolve_binary()
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_policy_auto_platform_failure(monkeypatch: pytest.MonkeyPatch) -> None:
    """``platform="auto"`` should raise when ``sys.platform`` is patched to ``win32``."""
    monkeypatch.setattr(_execution.sys, "platform", "win32")
    policy = CodexSandboxExecutionPolicy(platform="auto")
    expectation = pytest.raises(RuntimeError)
    with expectation:
        policy._determine_platform()
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_policy_formats_override_values() -> None:
    """Dict overrides should render as JSON text; other objects via ``str``."""

    class Custom:
        def __str__(self) -> str:
            return "custom"

    policy = CodexSandboxExecutionPolicy()

    rendered_mapping = policy._format_override({"a": 1})
    assert rendered_mapping == '{"a": 1}'

    rendered_custom = policy._format_override(Custom())
    assert rendered_custom == "custom"
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_policy_sorts_config_overrides(monkeypatch: pytest.MonkeyPatch) -> None:
    """Overrides given as ``b`` then ``a`` should appear in the command as ``a`` then ``b``."""
    monkeypatch.setattr(_execution.shutil, "which", lambda _: "/usr/bin/codex")
    policy = CodexSandboxExecutionPolicy(
        config_overrides={"b": 2, "a": 1},
        platform="linux",
    )
    command = policy._build_command(("echo",))
    # Each override value is the argument that follows a "-c" flag.
    override_values = [
        command[position + 1] for position, part in enumerate(command) if part == "-c"
    ]
    assert override_values == ["a=1", "b=2"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    shutil.which("docker") is None,
    reason="docker CLI not available on PATH",
)
def test_docker_policy_spawns_docker_run(monkeypatch, tmp_path: Path) -> None:
    """The Docker policy should build a ``docker run -i --rm`` command for the shell.

    The command is expected to carry the memory limit, a workspace mount, the
    working directory, the forwarded environment, and end with image and shell.
    """
    recorded: dict[str, list[str]] = {}

    class DummyProcess:
        pass

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        recorded["command"] = list(command)
        assert cwd == tmp_path
        assert "PATH" in env  # host environment should retain system PATH
        assert not start_new_session
        return DummyProcess()

    monkeypatch.setattr(
        "langchain.agents.middleware._execution._launch_subprocess",
        fake_launch,
    )
    policy = DockerExecutionPolicy(
        image="ubuntu:22.04",
        memory_bytes=4096,
        extra_run_args=("--ipc", "host"),
    )

    policy.spawn(workspace=tmp_path, env={"PATH": "/bin"}, command=("/bin/bash",))

    command = recorded["command"]
    assert command[0] == shutil.which("docker")
    assert command[1:4] == ["run", "-i", "--rm"]
    assert "--memory" in command
    assert "4096" in command
    assert "-v" in command
    assert any(str(tmp_path) in part for part in command)
    assert "-w" in command
    workdir_value = command[command.index("-w") + 1]
    assert workdir_value == str(tmp_path)
    assert "-e" in command
    assert "PATH=/bin" in command
    assert command[-2:] == ["ubuntu:22.04", "/bin/bash"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_docker_policy_rejects_cpu_limit() -> None:
    """Passing ``cpu_time_seconds`` to the Docker policy is expected to raise ``RuntimeError``."""
    rejected_kwargs = {"cpu_time_seconds": 1}
    with pytest.raises(RuntimeError):
        DockerExecutionPolicy(**rejected_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_docker_policy_validates_memory() -> None:
    """A ``memory_bytes`` of zero is expected to raise ``ValueError``."""
    rejected_kwargs = {"memory_bytes": 0}
    with pytest.raises(ValueError):
        DockerExecutionPolicy(**rejected_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_docker_policy_skips_mount_for_temp_workspace(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """A workspace named with ``SHELL_TEMP_PREFIX`` should not be bind-mounted.

    The command is expected to omit ``-v``, use ``/`` as working directory, and
    include the CPU and network options.
    """
    monkeypatch.setattr(_execution.shutil, "which", lambda _: "/usr/bin/docker")

    workspace = tmp_path / f"{_execution.SHELL_TEMP_PREFIX}case"
    workspace.mkdir()

    recorded: dict[str, list[str]] = {}

    class DummyProcess:
        pass

    def fake_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        recorded["command"] = list(command)
        assert cwd == workspace
        return DummyProcess()

    monkeypatch.setattr(_execution, "_launch_subprocess", fake_launch)

    policy = DockerExecutionPolicy(cpus="1.5")
    policy.spawn(workspace=workspace, env={"PATH": "/bin"}, command=("/bin/sh",))

    command = recorded["command"]
    assert "-v" not in command
    assert "-w" in command
    workdir_value = command[command.index("-w") + 1]
    assert workdir_value == "/"
    assert "--cpus" in command
    assert "--network" in command
    assert "none" in command
    assert command[-2:] == [policy.image, "/bin/sh"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_docker_policy_validates_cpus() -> None:
    """A whitespace-only ``cpus`` value is rejected with ``ValueError``."""
    blank_cpus = {"cpus": "  "}
    with pytest.raises(ValueError):
        DockerExecutionPolicy(**blank_cpus)
|
||||||
|
|
||||||
|
|
||||||
|
def test_docker_policy_validates_user() -> None:
    """A whitespace-only ``user`` value is rejected with ``ValueError``."""
    blank_user = {"user": "  "}
    with pytest.raises(ValueError):
        DockerExecutionPolicy(**blank_user)
|
||||||
|
|
||||||
|
|
||||||
|
def test_docker_policy_read_only_and_user(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """``read_only_rootfs`` and ``user`` show up as flags in the docker argv.

    The subprocess launcher is swapped for a recorder; the recorded command
    must contain ``--read-only`` and ``--user`` immediately followed by the
    configured ``1000:1000`` value.
    """
    launches: list[list[str]] = []

    class _FakeProcess:
        """Placeholder returned in place of a real subprocess handle."""

    def _capture_launch(command, *, env, cwd, preexec_fn, start_new_session):  # noqa: ANN001
        launches.append(list(command))
        return _FakeProcess()

    monkeypatch.setattr(_execution.shutil, "which", lambda _: "/usr/bin/docker")
    monkeypatch.setattr(_execution, "_launch_subprocess", _capture_launch)

    policy = DockerExecutionPolicy(read_only_rootfs=True, user="1000:1000")
    policy.spawn(workspace=tmp_path, env={"PATH": "/bin"}, command=("/bin/sh",))

    argv = launches[-1]
    assert "--read-only" in argv
    assert "--user" in argv
    assert argv[argv.index("--user") + 1] == "1000:1000"
|
||||||
|
|
||||||
|
|
||||||
|
def test_docker_policy_resolve_missing_binary(monkeypatch: pytest.MonkeyPatch) -> None:
    """``_resolve_binary`` raises ``RuntimeError`` when ``shutil.which`` finds nothing."""

    def _which_finds_nothing(_):  # noqa: ANN001, ANN202
        return None

    monkeypatch.setattr(_execution.shutil, "which", _which_finds_nothing)
    docker_policy = DockerExecutionPolicy()
    with pytest.raises(RuntimeError):
        docker_policy._resolve_binary()
|
||||||
@@ -0,0 +1,175 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import gc
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.agents.middleware.shell_tool import (
|
||||||
|
HostExecutionPolicy,
|
||||||
|
ShellToolMiddleware,
|
||||||
|
_SessionResources,
|
||||||
|
RedactionRule,
|
||||||
|
)
|
||||||
|
from langchain.agents.middleware.types import AgentState
|
||||||
|
|
||||||
|
|
||||||
|
def _empty_state() -> AgentState:
|
||||||
|
return {"messages": []} # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
|
def test_executes_command_and_persists_state(tmp_path: Path) -> None:
    """Shell state such as the working directory carries over between tool calls."""
    shell = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    state: AgentState = _empty_state()

    def run(command: str):  # noqa: ANN202
        # ``resources`` is bound below, before the first call to this helper.
        return shell._run_shell_tool(resources, {"command": command}, tool_call_id=None)

    try:
        state.update(shell.before_agent(state, None) or {})
        resources = shell._ensure_resources(state)  # type: ignore[attr-defined]

        # Change directory in one call and observe it in the next.
        run("cd /")
        cwd_output = run("pwd")
        assert isinstance(cwd_output, str)
        assert cwd_output.strip() == "/"

        assert "ready" in run("echo ready")
    finally:
        state.update(shell.after_agent(state, None) or {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_restart_resets_session_environment(tmp_path: Path) -> None:
    """A variable exported before a restart is no longer set afterwards."""
    shell = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
    state: AgentState = _empty_state()
    try:
        state.update(shell.before_agent(state, None) or {})
        session = shell._ensure_resources(state)  # type: ignore[attr-defined]

        shell._run_shell_tool(session, {"command": "export FOO=bar"}, tool_call_id=None)
        restart_reply = shell._run_shell_tool(session, {"restart": True}, tool_call_id=None)
        assert "restarted" in restart_reply.lower()

        # Look the resources up again now that the shell has been restarted.
        session = shell._ensure_resources(state)
        probe_output = shell._run_shell_tool(
            session, {"command": "echo ${FOO:-unset}"}, tool_call_id=None
        )
        assert "unset" in probe_output
    finally:
        state.update(shell.after_agent(state, None) or {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_truncation_indicator_present(tmp_path: Path) -> None:
    """Output longer than ``max_output_lines`` is reported as truncated."""
    line_capped_policy = HostExecutionPolicy(max_output_lines=5, command_timeout=5.0)
    shell = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        execution_policy=line_capped_policy,
    )
    state: AgentState = _empty_state()
    try:
        state.update(shell.before_agent(state, None) or {})
        session = shell._ensure_resources(state)  # type: ignore[attr-defined]

        # Twenty lines of output against a five-line cap.
        output = shell._run_shell_tool(session, {"command": "seq 1 20"}, tool_call_id=None)
        assert "Output truncated" in output
    finally:
        state.update(shell.after_agent(state, None) or {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_timeout_returns_error(tmp_path: Path) -> None:
    """A command that outlives ``command_timeout`` produces a "timed out" message."""
    short_timeout_policy = HostExecutionPolicy(command_timeout=0.5)
    shell = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        execution_policy=short_timeout_policy,
    )
    state: AgentState = _empty_state()
    try:
        state.update(shell.before_agent(state, None) or {})
        session = shell._ensure_resources(state)  # type: ignore[attr-defined]

        began = time.monotonic()
        output = shell._run_shell_tool(session, {"command": "sleep 2"}, tool_call_id=None)
        duration = time.monotonic() - began

        # Allow two seconds of slack on top of the configured timeout.
        assert duration < short_timeout_policy.command_timeout + 2.0
        assert "timed out" in output.lower()
    finally:
        state.update(shell.after_agent(state, None) or {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_redaction_policy_applies(tmp_path: Path) -> None:
    """An email printed by a command is replaced by the redaction placeholder."""
    email_rule = RedactionRule(pii_type="email", strategy="redact")
    shell = ShellToolMiddleware(
        workspace_root=tmp_path / "workspace",
        redaction_rules=(email_rule,),
    )
    state: AgentState = _empty_state()
    try:
        state.update(shell.before_agent(state, None) or {})
        session = shell._ensure_resources(state)  # type: ignore[attr-defined]

        print_email = {"command": "printf 'Contact: user@example.com\\n'"}
        output = shell._run_shell_tool(session, print_email, tool_call_id=None)

        assert "[REDACTED_EMAIL]" in output
        assert "user@example.com" not in output
    finally:
        state.update(shell.after_agent(state, None) or {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_startup_and_shutdown_commands(tmp_path: Path) -> None:
    """Startup and shutdown commands leave their marker files in the workspace.

    ``startup.txt`` must exist once ``before_agent`` has returned and
    ``shutdown.txt`` once ``after_agent`` has returned.
    """
    workspace = tmp_path / "workspace"
    startup_marker = workspace / "startup.txt"
    shutdown_marker = workspace / "shutdown.txt"

    shell = ShellToolMiddleware(
        workspace_root=workspace,
        startup_commands=("touch startup.txt",),
        shutdown_commands=("touch shutdown.txt",),
    )
    state: AgentState = _empty_state()
    try:
        state.update(shell.before_agent(state, None) or {})
        assert startup_marker.exists()
    finally:
        state.update(shell.after_agent(state, None) or {})

    assert shutdown_marker.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_session_resources_finalizer_cleans_up(tmp_path: Path) -> None:
    """Collecting ``_SessionResources`` stops its session and removes its tempdir."""

    class _RecordingSession:
        """Session double that only remembers whether ``stop`` was invoked."""

        stopped: bool = False

        def stop(self, timeout: float) -> None:  # noqa: ARG002
            self.stopped = True

    fake_session = _RecordingSession()
    scratch = tempfile.TemporaryDirectory(dir=tmp_path)
    scratch_path = Path(scratch.name)
    resources = _SessionResources(  # type: ignore[arg-type]
        session=fake_session,
        tempdir=scratch,
        policy=HostExecutionPolicy(termination_timeout=0.1),
    )
    finalizer = resources._finalizer

    # Release the only strong reference, then force a collection pass.
    del resources
    gc.collect()

    assert not finalizer.alive
    assert fake_session.stopped
    assert not scratch_path.exists()
|
||||||
@@ -6,6 +6,7 @@ from langchain_anthropic.middleware.anthropic_tools import (
|
|||||||
StateClaudeMemoryMiddleware,
|
StateClaudeMemoryMiddleware,
|
||||||
StateClaudeTextEditorMiddleware,
|
StateClaudeTextEditorMiddleware,
|
||||||
)
|
)
|
||||||
|
from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware
|
||||||
from langchain_anthropic.middleware.file_search import (
|
from langchain_anthropic.middleware.file_search import (
|
||||||
StateFileSearchMiddleware,
|
StateFileSearchMiddleware,
|
||||||
)
|
)
|
||||||
@@ -15,6 +16,7 @@ from langchain_anthropic.middleware.prompt_caching import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnthropicPromptCachingMiddleware",
|
"AnthropicPromptCachingMiddleware",
|
||||||
|
"ClaudeBashToolMiddleware",
|
||||||
"FilesystemClaudeMemoryMiddleware",
|
"FilesystemClaudeMemoryMiddleware",
|
||||||
"FilesystemClaudeTextEditorMiddleware",
|
"FilesystemClaudeTextEditorMiddleware",
|
||||||
"StateClaudeMemoryMiddleware",
|
"StateClaudeMemoryMiddleware",
|
||||||
|
|||||||
@@ -0,0 +1,92 @@
|
|||||||
|
"""Anthropic-specific middleware for the Claude bash tool."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
|
||||||
|
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||||
|
from langchain.tools.tool_node import ToolCallRequest
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
# Tool descriptor that ClaudeBashToolMiddleware appends to model requests; its
# "name" value is also used to label the ToolMessage results it produces.
_CLAUDE_BASH_DESCRIPTOR = {"type": "bash_20250124", "name": "bash"}
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeBashToolMiddleware(ShellToolMiddleware):
    """Middleware that exposes Anthropic's native bash tool to models.

    No client-side tool is registered. Instead the Claude bash descriptor is
    appended to every model request (sync and async), and tool calls named
    ``bash`` are intercepted and executed through the shell machinery
    inherited from ``ShellToolMiddleware``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize middleware without registering a client-side tool.

        Args:
            *args: Positional arguments forwarded to ``ShellToolMiddleware``.
            **kwargs: Keyword arguments forwarded to ``ShellToolMiddleware``.
                Any ``shell_command`` supplied by the caller is replaced with
                ``("/bin/bash",)``.
        """
        kwargs["shell_command"] = ("/bin/bash",)
        super().__init__(*args, **kwargs)
        # Remove the base tool so Claude's native descriptor is the sole entry.
        self._tool = None  # type: ignore[assignment]
        self.tools = []

    @staticmethod
    def _is_bash_descriptor(tool: Any) -> bool:
        """Return whether ``tool`` is a dict carrying the Claude bash type and name."""
        # Compare by content rather than identity so that a copied or
        # deserialized descriptor is still recognised and not added twice.
        return isinstance(tool, dict) and all(
            tool.get(key) == value for key, value in _CLAUDE_BASH_DESCRIPTOR.items()
        )

    def _with_bash_descriptor(self, request: ModelRequest) -> ModelRequest:
        """Return ``request`` with the bash descriptor present exactly once."""
        tools = request.tools
        if any(self._is_bash_descriptor(tool) for tool in tools):
            return request
        return request.override(tools=[*tools, _CLAUDE_BASH_DESCRIPTOR])

    def _execute_bash_call(self, request: ToolCallRequest) -> Command | ToolMessage:
        """Run the ``bash`` tool call carried by ``request`` in the local shell."""
        tool_call = request.tool_call
        resources = self._ensure_resources(request.state)
        return self._run_shell_tool(
            resources,
            tool_call["args"],
            tool_call_id=tool_call.get("id"),
        )

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Ensure the Claude bash descriptor is available to the model.

        Args:
            request: The outgoing model request.
            handler: Callback that performs the model call.

        Returns:
            Whatever ``handler`` returns for the (possibly amended) request.
        """
        return handler(self._with_bash_descriptor(request))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        """Async counterpart of ``wrap_model_call``.

        Args:
            request: The outgoing model request.
            handler: Awaitable callback that performs the model call.

        Returns:
            Whatever ``handler`` returns for the (possibly amended) request.
        """
        return await handler(self._with_bash_descriptor(request))

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Command | ToolMessage],
    ) -> Command | ToolMessage:
        """Intercept Claude bash tool calls and execute them locally.

        Args:
            request: The tool call request being dispatched.
            handler: Callback that executes any tool other than ``bash``.

        Returns:
            The shell tool result for ``bash`` calls, otherwise the result
            of ``handler``.
        """
        if request.tool_call.get("name") != "bash":
            return handler(request)
        return self._execute_bash_call(request)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[Command | ToolMessage]],
    ) -> Command | ToolMessage:
        """Async interception mirroring the synchronous implementation.

        Args:
            request: The tool call request being dispatched.
            handler: Awaitable callback for any tool other than ``bash``.

        Returns:
            The shell tool result for ``bash`` calls, otherwise the awaited
            result of ``handler``.
        """
        if request.tool_call.get("name") != "bash":
            return await handler(request)
        # NOTE(review): the shell command is executed with a plain synchronous
        # call here, i.e. it is not awaited or offloaded to another thread.
        return self._execute_bash_call(request)

    def _format_tool_message(
        self,
        content: str,
        tool_call_id: str | None,
        *,
        status: Literal["success", "error"],
        artifact: dict[str, Any] | None = None,
    ) -> ToolMessage | str:
        """Format tool responses using Claude's bash descriptor.

        Args:
            content: Text to return to the model.
            tool_call_id: Identifier of the originating tool call, or ``None``.
            status: Outcome recorded on the message.
            artifact: Optional artifact payload; an empty dict is used when
                omitted or empty.

        Returns:
            ``content`` unchanged when ``tool_call_id`` is ``None``, otherwise
            a ``ToolMessage`` named after the bash descriptor.
        """
        if tool_call_id is None:
            return content
        return ToolMessage(
            content=content,
            tool_call_id=tool_call_id,
            name=_CLAUDE_BASH_DESCRIPTOR["name"],
            status=status,
            artifact=artifact or {},
        )
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ClaudeBashToolMiddleware"]
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages.tool import ToolCall
|
||||||
|
|
||||||
|
pytest.importorskip(
|
||||||
|
"anthropic", reason="Anthropic SDK is required for Claude middleware tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.tools.tool_node import ToolCallRequest
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
|
||||||
|
from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrap_tool_call_handles_claude_bash(monkeypatch: pytest.MonkeyPatch) -> None:
    """``bash`` tool calls are executed locally and never reach the handler.

    Besides checking the returned message, the test verifies what was
    forwarded: the request state to ``_ensure_resources`` and the resolved
    resources, the call's args and its id to ``_run_shell_tool``.
    """
    middleware = ClaudeBashToolMiddleware()
    sentinel = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
    resources = MagicMock()
    run_shell_tool = MagicMock(return_value=sentinel)
    ensure_resources = MagicMock(return_value=resources)

    monkeypatch.setattr(middleware, "_run_shell_tool", run_shell_tool)
    monkeypatch.setattr(middleware, "_ensure_resources", ensure_resources)

    tool_call: ToolCall = {
        "name": "bash",
        "args": {"command": "echo hi"},
        "id": "call-1",
    }
    request = ToolCallRequest(
        tool_call=tool_call,
        tool=MagicMock(),
        state={},
        runtime=None,  # type: ignore[arg-type]
    )

    handler_called = False

    def handler(_: ToolCallRequest) -> ToolMessage:
        nonlocal handler_called
        handler_called = True
        return ToolMessage(content="should not be used", tool_call_id="call-1")

    result = middleware.wrap_tool_call(request, handler)

    assert result is sentinel
    assert handler_called is False
    # Pin the dispatch itself so a call with the wrong arguments cannot pass.
    ensure_resources.assert_called_once_with(request.state)
    run_shell_tool.assert_called_once_with(
        resources,
        {"command": "echo hi"},
        tool_call_id="call-1",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrap_tool_call_passes_through_other_tools(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Tool calls not named ``bash`` are answered by the wrapped handler."""
    handler_reply = ToolMessage(content="handled", tool_call_id="call-2", name="other")

    def delegate(_: ToolCallRequest) -> ToolMessage:
        return handler_reply

    other_call: ToolCall = {"name": "other", "args": {}, "id": "call-2"}
    passthrough_request = ToolCallRequest(
        tool_call=other_call,
        tool=MagicMock(),
        state={},
        runtime=None,  # type: ignore[arg-type]
    )

    outcome = ClaudeBashToolMiddleware().wrap_tool_call(passthrough_request, delegate)

    assert outcome is handler_reply
|
||||||
Reference in New Issue
Block a user