mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-04 08:10:25 +00:00
Compare commits
1 Commits
sr/typing-
...
langchain=
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee3fc91e7a |
@@ -4,6 +4,7 @@ from .context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
ContextEditingMiddleware,
|
||||
)
|
||||
from .file_search import FilesystemFileSearchMiddleware
|
||||
from .human_in_the_loop import (
|
||||
HumanInTheLoopMiddleware,
|
||||
InterruptOnConfig,
|
||||
@@ -46,6 +47,7 @@ __all__ = [
|
||||
"CodexSandboxExecutionPolicy",
|
||||
"ContextEditingMiddleware",
|
||||
"DockerExecutionPolicy",
|
||||
"FilesystemFileSearchMiddleware",
|
||||
"HostExecutionPolicy",
|
||||
"HumanInTheLoopMiddleware",
|
||||
"InterruptOnConfig",
|
||||
|
||||
@@ -353,3 +353,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
last_ai_msg.tool_calls = revised_tool_calls
|
||||
|
||||
return {"messages": [last_ai_msg, *artificial_tool_messages]}
|
||||
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@@ -198,6 +198,29 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
||||
|
||||
return None
|
||||
|
||||
@hook_config(can_jump_to=["end"])
|
||||
async def abefore_model(
|
||||
self,
|
||||
state: ModelCallLimitState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async check model call limits before making a model call.
|
||||
|
||||
Args:
|
||||
state: The current agent state containing call counts.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
If limits are exceeded and exit_behavior is `'end'`, returns
|
||||
a `Command` to jump to the end with a limit exceeded message. Otherwise
|
||||
returns `None`.
|
||||
|
||||
Raises:
|
||||
ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
|
||||
is `'error'`.
|
||||
"""
|
||||
return self.before_model(state, runtime)
|
||||
|
||||
def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Increment model call counts after a model call.
|
||||
|
||||
@@ -212,3 +235,19 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
|
||||
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
|
||||
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
|
||||
}
|
||||
|
||||
async def aafter_model(
|
||||
self,
|
||||
state: ModelCallLimitState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async increment model call counts after a model call.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
State updates with incremented call counts.
|
||||
"""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@@ -252,6 +252,27 @@ class PIIMiddleware(AgentMiddleware):
|
||||
|
||||
return None
|
||||
|
||||
@hook_config(can_jump_to=["end"])
|
||||
async def abefore_model(
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async check user messages and tool results for PII before model invocation.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
Updated state with PII handled according to strategy, or `None` if no PII
|
||||
detected.
|
||||
|
||||
Raises:
|
||||
PIIDetectionError: If PII is detected and strategy is `'block'`.
|
||||
"""
|
||||
return self.before_model(state, runtime)
|
||||
|
||||
def after_model(
|
||||
self,
|
||||
state: AgentState,
|
||||
@@ -311,6 +332,26 @@ class PIIMiddleware(AgentMiddleware):
|
||||
|
||||
return {"messages": new_messages}
|
||||
|
||||
async def aafter_model(
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async check AI messages for PII after model invocation.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
Updated state with PII handled according to strategy, or None if no PII
|
||||
detected.
|
||||
|
||||
Raises:
|
||||
PIIDetectionError: If PII is detected and strategy is `'block'`.
|
||||
"""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PIIDetectionError",
|
||||
|
||||
@@ -11,7 +11,6 @@ import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import typing
|
||||
import uuid
|
||||
import weakref
|
||||
from dataclasses import dataclass, field
|
||||
@@ -19,9 +18,10 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools.base import BaseTool, ToolException
|
||||
from langchain_core.tools.base import ToolException
|
||||
from langgraph.channels.untracked_value import UntrackedValue
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic.json_schema import SkipJsonSchema
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from langchain.agents.middleware._execution import (
|
||||
@@ -38,14 +38,13 @@ from langchain.agents.middleware._redaction import (
|
||||
ResolvedRedactionRule,
|
||||
)
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
|
||||
@@ -59,6 +58,7 @@ DEFAULT_TOOL_DESCRIPTION = (
|
||||
"session remains stable. Outputs may be truncated when they become very large, and long "
|
||||
"running commands will be terminated once their configured timeout elapses."
|
||||
)
|
||||
SHELL_TOOL_NAME = "shell"
|
||||
|
||||
|
||||
def _cleanup_resources(
|
||||
@@ -334,7 +334,17 @@ class _ShellToolInput(BaseModel):
|
||||
"""Input schema for the persistent shell tool."""
|
||||
|
||||
command: str | None = None
|
||||
"""The shell command to execute."""
|
||||
|
||||
restart: bool | None = None
|
||||
"""Whether to restart the shell session."""
|
||||
|
||||
runtime: Annotated[Any, SkipJsonSchema] = None
|
||||
"""The runtime for the shell tool.
|
||||
|
||||
Included as a workaround at the moment bc args_schema doesn't work with
|
||||
injected ToolRuntime.
|
||||
"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_payload(self) -> _ShellToolInput:
|
||||
@@ -347,24 +357,6 @@ class _ShellToolInput(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class _PersistentShellTool(BaseTool):
|
||||
"""Tool wrapper that relies on middleware interception for execution."""
|
||||
|
||||
name: str = "shell"
|
||||
description: str = DEFAULT_TOOL_DESCRIPTION
|
||||
args_schema: type[BaseModel] = _ShellToolInput
|
||||
|
||||
def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
|
||||
super().__init__()
|
||||
self._middleware = middleware
|
||||
if description is not None:
|
||||
self.description = description
|
||||
|
||||
def _run(self, **_: Any) -> Any: # pragma: no cover - executed via middleware wrapper
|
||||
msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
"""Middleware that registers a persistent shell tool for agents.
|
||||
|
||||
@@ -393,6 +385,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
execution_policy: BaseExecutionPolicy | None = None,
|
||||
redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
|
||||
tool_description: str | None = None,
|
||||
tool_name: str = SHELL_TOOL_NAME,
|
||||
shell_command: Sequence[str] | str | None = None,
|
||||
env: Mapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
@@ -414,6 +407,9 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
returning it to the model.
|
||||
tool_description: Optional override for the registered shell tool
|
||||
description.
|
||||
tool_name: Name for the registered shell tool.
|
||||
|
||||
Defaults to `"shell"`.
|
||||
shell_command: Optional shell executable (string) or argument sequence used
|
||||
to launch the persistent session.
|
||||
|
||||
@@ -425,6 +421,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
"""
|
||||
super().__init__()
|
||||
self._workspace_root = Path(workspace_root) if workspace_root else None
|
||||
self._tool_name = tool_name
|
||||
self._shell_command = self._normalize_shell_command(shell_command)
|
||||
self._environment = self._normalize_env(env)
|
||||
if execution_policy is not None:
|
||||
@@ -438,9 +435,25 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
self._startup_commands = self._normalize_commands(startup_commands)
|
||||
self._shutdown_commands = self._normalize_commands(shutdown_commands)
|
||||
|
||||
# Create a proper tool that executes directly (no interception needed)
|
||||
description = tool_description or DEFAULT_TOOL_DESCRIPTION
|
||||
self._tool = _PersistentShellTool(self, description=description)
|
||||
self.tools = [self._tool]
|
||||
|
||||
@tool(self._tool_name, args_schema=_ShellToolInput, description=description)
|
||||
def shell_tool(
|
||||
*,
|
||||
runtime: ToolRuntime[None, ShellToolState],
|
||||
command: str | None = None,
|
||||
restart: bool = False,
|
||||
) -> ToolMessage | str:
|
||||
resources = self._ensure_resources(runtime.state)
|
||||
return self._run_shell_tool(
|
||||
resources,
|
||||
{"command": command, "restart": restart},
|
||||
tool_call_id=runtime.tool_call_id,
|
||||
)
|
||||
|
||||
self._shell_tool = shell_tool
|
||||
self.tools = [self._shell_tool]
|
||||
|
||||
@staticmethod
|
||||
def _normalize_commands(
|
||||
@@ -482,7 +495,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
return {"shell_session_resources": resources}
|
||||
|
||||
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Async counterpart to `before_agent`."""
|
||||
"""Async start the shell session and run startup commands."""
|
||||
return self.before_agent(state, runtime)
|
||||
|
||||
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
|
||||
@@ -494,7 +507,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
resources._finalizer()
|
||||
|
||||
async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
|
||||
"""Async counterpart to `after_agent`."""
|
||||
"""Async run shutdown commands and release resources when an agent completes."""
|
||||
return self.after_agent(state, runtime)
|
||||
|
||||
def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
|
||||
@@ -669,36 +682,6 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
artifact=artifact,
|
||||
)
|
||||
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept local shell tool calls and execute them via the managed session."""
|
||||
if isinstance(request.tool, _PersistentShellTool):
|
||||
resources = self._ensure_resources(request.state)
|
||||
return self._run_shell_tool(
|
||||
resources,
|
||||
request.tool_call["args"],
|
||||
tool_call_id=request.tool_call.get("id"),
|
||||
)
|
||||
return handler(request)
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
"""Async interception mirroring the synchronous tool handler."""
|
||||
if isinstance(request.tool, _PersistentShellTool):
|
||||
resources = self._ensure_resources(request.state)
|
||||
return self._run_shell_tool(
|
||||
resources,
|
||||
request.tool_call["args"],
|
||||
tool_call_id=request.tool_call.get("id"),
|
||||
)
|
||||
return await handler(request)
|
||||
|
||||
def _format_tool_message(
|
||||
self,
|
||||
content: str,
|
||||
@@ -713,7 +696,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
return ToolMessage(
|
||||
content=content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=self._tool.name,
|
||||
name=self._tool_name,
|
||||
status=status,
|
||||
artifact=artifact,
|
||||
)
|
||||
|
||||
@@ -451,3 +451,28 @@ class ToolCallLimitMiddleware(
|
||||
"run_tool_call_count": run_counts,
|
||||
"messages": artificial_messages,
|
||||
}
|
||||
|
||||
@hook_config(can_jump_to=["end"])
|
||||
async def aafter_model(
|
||||
self,
|
||||
state: ToolCallLimitState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async increment tool call counts after a model call and check limits.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The langgraph runtime.
|
||||
|
||||
Returns:
|
||||
State updates with incremented tool call counts. If limits are exceeded
|
||||
and exit_behavior is `'end'`, also includes a jump to end with a
|
||||
`ToolMessage` and AI message for the single exceeded tool call.
|
||||
|
||||
Raises:
|
||||
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
|
||||
is `'error'`.
|
||||
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
|
||||
and there are multiple tool calls.
|
||||
"""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@@ -9,7 +9,7 @@ license = { text = "MIT" }
|
||||
readme = "README.md"
|
||||
authors = []
|
||||
|
||||
version = "1.0.5"
|
||||
version = "1.0.6"
|
||||
requires-python = ">=3.10.0,<4.0.0"
|
||||
dependencies = [
|
||||
"langchain-core>=1.0.4,<2.0.0",
|
||||
|
||||
2
libs/langchain_v1/uv.lock
generated
2
libs/langchain_v1/uv.lock
generated
@@ -1788,7 +1788,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain"
|
||||
version = "1.0.5"
|
||||
version = "1.0.6"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
|
||||
@@ -17,7 +17,9 @@ from langchain.agents.middleware.types import (
|
||||
AgentState,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
_ModelRequestOverrides,
|
||||
)
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
@@ -25,7 +27,6 @@ from typing_extensions import NotRequired, TypedDict
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
|
||||
# Tool type constants
|
||||
TEXT_EDITOR_TOOL_TYPE = "text_editor_20250728"
|
||||
@@ -184,149 +185,127 @@ class _StateClaudeFileToolMiddleware(AgentMiddleware):
|
||||
self.allowed_prefixes = allowed_path_prefixes
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
# Create tool that will be executed by the tool node
|
||||
@tool(tool_name)
|
||||
def file_tool(
|
||||
runtime: ToolRuntime[None, AnthropicToolsState],
|
||||
command: str,
|
||||
path: str,
|
||||
file_text: str | None = None,
|
||||
old_str: str | None = None,
|
||||
new_str: str | None = None,
|
||||
insert_line: int | None = None,
|
||||
new_path: str | None = None,
|
||||
view_range: list[int] | None = None,
|
||||
) -> Command | str:
|
||||
"""Execute file operations on virtual file system.
|
||||
|
||||
Args:
|
||||
runtime: Tool runtime providing access to state.
|
||||
command: Operation to perform.
|
||||
path: File path to operate on.
|
||||
file_text: Full file content for create command.
|
||||
old_str: String to replace for str_replace command.
|
||||
new_str: Replacement string for str_replace command.
|
||||
insert_line: Line number for insert command.
|
||||
new_path: New path for rename command.
|
||||
view_range: Line range [start, end] for view command.
|
||||
|
||||
Returns:
|
||||
Command for state update or string result.
|
||||
"""
|
||||
# Build args dict for handler methods
|
||||
args: dict[str, Any] = {"path": path}
|
||||
if file_text is not None:
|
||||
args["file_text"] = file_text
|
||||
if old_str is not None:
|
||||
args["old_str"] = old_str
|
||||
if new_str is not None:
|
||||
args["new_str"] = new_str
|
||||
if insert_line is not None:
|
||||
args["insert_line"] = insert_line
|
||||
if new_path is not None:
|
||||
args["new_path"] = new_path
|
||||
if view_range is not None:
|
||||
args["view_range"] = view_range
|
||||
|
||||
# Route to appropriate handler based on command
|
||||
try:
|
||||
if command == "view":
|
||||
return self._handle_view(args, runtime.state, runtime.tool_call_id)
|
||||
if command == "create":
|
||||
return self._handle_create(
|
||||
args, runtime.state, runtime.tool_call_id
|
||||
)
|
||||
if command == "str_replace":
|
||||
return self._handle_str_replace(
|
||||
args, runtime.state, runtime.tool_call_id
|
||||
)
|
||||
if command == "insert":
|
||||
return self._handle_insert(
|
||||
args, runtime.state, runtime.tool_call_id
|
||||
)
|
||||
if command == "delete":
|
||||
return self._handle_delete(
|
||||
args, runtime.state, runtime.tool_call_id
|
||||
)
|
||||
if command == "rename":
|
||||
return self._handle_rename(
|
||||
args, runtime.state, runtime.tool_call_id
|
||||
)
|
||||
return f"Unknown command: {command}"
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
return str(e)
|
||||
|
||||
self.tools = [file_tool]
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
"""Inject tool and optional system prompt."""
|
||||
# Add tool
|
||||
tools = list(request.tools or [])
|
||||
tools.append(
|
||||
{
|
||||
"type": self.tool_type,
|
||||
"name": self.tool_name,
|
||||
}
|
||||
)
|
||||
request.tools = tools
|
||||
"""Inject Anthropic tool descriptor and optional system prompt."""
|
||||
# Replace our BaseTool with Anthropic's native tool descriptor
|
||||
tools = [
|
||||
t
|
||||
for t in (request.tools or [])
|
||||
if getattr(t, "name", None) != self.tool_name
|
||||
] + [{"type": self.tool_type, "name": self.tool_name}]
|
||||
|
||||
# Inject system prompt if provided
|
||||
overrides: _ModelRequestOverrides = {"tools": tools}
|
||||
if self.system_prompt:
|
||||
request.system_prompt = (
|
||||
overrides["system_prompt"] = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
|
||||
return handler(request)
|
||||
return handler(request.override(**overrides))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
"""Inject tool and optional system prompt (async version)."""
|
||||
# Add tool
|
||||
tools = list(request.tools or [])
|
||||
tools.append(
|
||||
{
|
||||
"type": self.tool_type,
|
||||
"name": self.tool_name,
|
||||
}
|
||||
)
|
||||
request.tools = tools
|
||||
"""Inject Anthropic tool descriptor and optional system prompt."""
|
||||
# Replace our BaseTool with Anthropic's native tool descriptor
|
||||
tools = [
|
||||
t
|
||||
for t in (request.tools or [])
|
||||
if getattr(t, "name", None) != self.tool_name
|
||||
] + [{"type": self.tool_type, "name": self.tool_name}]
|
||||
|
||||
# Inject system prompt if provided
|
||||
overrides: _ModelRequestOverrides = {"tools": tools}
|
||||
if self.system_prompt:
|
||||
request.system_prompt = (
|
||||
overrides["system_prompt"] = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
|
||||
return await handler(request)
|
||||
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool calls."""
|
||||
tool_call = request.tool_call
|
||||
tool_name = tool_call.get("name")
|
||||
|
||||
if tool_name != self.tool_name:
|
||||
return handler(request)
|
||||
|
||||
# Handle tool call
|
||||
try:
|
||||
args = tool_call.get("args", {})
|
||||
command = args.get("command")
|
||||
state = request.state
|
||||
|
||||
if command == "view":
|
||||
return self._handle_view(args, state, tool_call["id"])
|
||||
if command == "create":
|
||||
return self._handle_create(args, state, tool_call["id"])
|
||||
if command == "str_replace":
|
||||
return self._handle_str_replace(args, state, tool_call["id"])
|
||||
if command == "insert":
|
||||
return self._handle_insert(args, state, tool_call["id"])
|
||||
if command == "delete":
|
||||
return self._handle_delete(args, state, tool_call["id"])
|
||||
if command == "rename":
|
||||
return self._handle_rename(args, state, tool_call["id"])
|
||||
|
||||
msg = f"Unknown command: {command}"
|
||||
return ToolMessage(
|
||||
content=msg,
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
return ToolMessage(
|
||||
content=str(e),
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool calls (async version)."""
|
||||
tool_call = request.tool_call
|
||||
tool_name = tool_call.get("name")
|
||||
|
||||
if tool_name != self.tool_name:
|
||||
return await handler(request)
|
||||
|
||||
# Handle tool call
|
||||
try:
|
||||
args = tool_call.get("args", {})
|
||||
command = args.get("command")
|
||||
state = request.state
|
||||
|
||||
if command == "view":
|
||||
return self._handle_view(args, state, tool_call["id"])
|
||||
if command == "create":
|
||||
return self._handle_create(args, state, tool_call["id"])
|
||||
if command == "str_replace":
|
||||
return self._handle_str_replace(args, state, tool_call["id"])
|
||||
if command == "insert":
|
||||
return self._handle_insert(args, state, tool_call["id"])
|
||||
if command == "delete":
|
||||
return self._handle_delete(args, state, tool_call["id"])
|
||||
if command == "rename":
|
||||
return self._handle_rename(args, state, tool_call["id"])
|
||||
|
||||
msg = f"Unknown command: {command}"
|
||||
return ToolMessage(
|
||||
content=msg,
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
return ToolMessage(
|
||||
content=str(e),
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
return await handler(request.override(**overrides))
|
||||
|
||||
def _handle_view(
|
||||
self, args: dict, state: AnthropicToolsState, tool_call_id: str | None
|
||||
@@ -692,146 +671,117 @@ class _FilesystemClaudeFileToolMiddleware(AgentMiddleware):
|
||||
# Create root directory if it doesn't exist
|
||||
self.root_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create tool that will be executed by the tool node
|
||||
@tool(tool_name)
|
||||
def file_tool(
|
||||
runtime: ToolRuntime,
|
||||
command: str,
|
||||
path: str,
|
||||
file_text: str | None = None,
|
||||
old_str: str | None = None,
|
||||
new_str: str | None = None,
|
||||
insert_line: int | None = None,
|
||||
new_path: str | None = None,
|
||||
view_range: list[int] | None = None,
|
||||
) -> Command | str:
|
||||
"""Execute file operations on filesystem.
|
||||
|
||||
Args:
|
||||
runtime: Tool runtime providing tool_call_id.
|
||||
command: Operation to perform.
|
||||
path: File path to operate on.
|
||||
file_text: Full file content for create command.
|
||||
old_str: String to replace for str_replace command.
|
||||
new_str: Replacement string for str_replace command.
|
||||
insert_line: Line number for insert command.
|
||||
new_path: New path for rename command.
|
||||
view_range: Line range [start, end] for view command.
|
||||
|
||||
Returns:
|
||||
Command for message update or string result.
|
||||
"""
|
||||
# Build args dict for handler methods
|
||||
args: dict[str, Any] = {"path": path}
|
||||
if file_text is not None:
|
||||
args["file_text"] = file_text
|
||||
if old_str is not None:
|
||||
args["old_str"] = old_str
|
||||
if new_str is not None:
|
||||
args["new_str"] = new_str
|
||||
if insert_line is not None:
|
||||
args["insert_line"] = insert_line
|
||||
if new_path is not None:
|
||||
args["new_path"] = new_path
|
||||
if view_range is not None:
|
||||
args["view_range"] = view_range
|
||||
|
||||
# Route to appropriate handler based on command
|
||||
try:
|
||||
if command == "view":
|
||||
return self._handle_view(args, runtime.tool_call_id)
|
||||
if command == "create":
|
||||
return self._handle_create(args, runtime.tool_call_id)
|
||||
if command == "str_replace":
|
||||
return self._handle_str_replace(args, runtime.tool_call_id)
|
||||
if command == "insert":
|
||||
return self._handle_insert(args, runtime.tool_call_id)
|
||||
if command == "delete":
|
||||
return self._handle_delete(args, runtime.tool_call_id)
|
||||
if command == "rename":
|
||||
return self._handle_rename(args, runtime.tool_call_id)
|
||||
return f"Unknown command: {command}"
|
||||
except (ValueError, FileNotFoundError, PermissionError) as e:
|
||||
return str(e)
|
||||
|
||||
self.tools = [file_tool]
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
"""Inject tool and optional system prompt."""
|
||||
# Add tool
|
||||
tools = list(request.tools or [])
|
||||
tools.append(
|
||||
{
|
||||
"type": self.tool_type,
|
||||
"name": self.tool_name,
|
||||
}
|
||||
)
|
||||
request.tools = tools
|
||||
"""Inject Anthropic tool descriptor and optional system prompt."""
|
||||
# Replace our BaseTool with Anthropic's native tool descriptor
|
||||
tools = [
|
||||
t
|
||||
for t in (request.tools or [])
|
||||
if getattr(t, "name", None) != self.tool_name
|
||||
] + [{"type": self.tool_type, "name": self.tool_name}]
|
||||
|
||||
# Inject system prompt if provided
|
||||
overrides: _ModelRequestOverrides = {"tools": tools}
|
||||
if self.system_prompt:
|
||||
request.system_prompt = (
|
||||
overrides["system_prompt"] = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
return handler(request)
|
||||
|
||||
return handler(request.override(**overrides))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
"""Inject tool and optional system prompt (async version)."""
|
||||
# Add tool
|
||||
tools = list(request.tools or [])
|
||||
tools.append(
|
||||
{
|
||||
"type": self.tool_type,
|
||||
"name": self.tool_name,
|
||||
}
|
||||
)
|
||||
request.tools = tools
|
||||
"""Inject Anthropic tool descriptor and optional system prompt."""
|
||||
# Replace our BaseTool with Anthropic's native tool descriptor
|
||||
tools = [
|
||||
t
|
||||
for t in (request.tools or [])
|
||||
if getattr(t, "name", None) != self.tool_name
|
||||
] + [{"type": self.tool_type, "name": self.tool_name}]
|
||||
|
||||
# Inject system prompt if provided
|
||||
overrides: _ModelRequestOverrides = {"tools": tools}
|
||||
if self.system_prompt:
|
||||
request.system_prompt = (
|
||||
overrides["system_prompt"] = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
|
||||
return await handler(request)
|
||||
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool calls."""
|
||||
tool_call = request.tool_call
|
||||
tool_name = tool_call.get("name")
|
||||
|
||||
if tool_name != self.tool_name:
|
||||
return handler(request)
|
||||
|
||||
# Handle tool call
|
||||
try:
|
||||
args = tool_call.get("args", {})
|
||||
command = args.get("command")
|
||||
|
||||
if command == "view":
|
||||
return self._handle_view(args, tool_call["id"])
|
||||
if command == "create":
|
||||
return self._handle_create(args, tool_call["id"])
|
||||
if command == "str_replace":
|
||||
return self._handle_str_replace(args, tool_call["id"])
|
||||
if command == "insert":
|
||||
return self._handle_insert(args, tool_call["id"])
|
||||
if command == "delete":
|
||||
return self._handle_delete(args, tool_call["id"])
|
||||
if command == "rename":
|
||||
return self._handle_rename(args, tool_call["id"])
|
||||
|
||||
msg = f"Unknown command: {command}"
|
||||
return ToolMessage(
|
||||
content=msg,
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
return ToolMessage(
|
||||
content=str(e),
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool calls (async version)."""
|
||||
tool_call = request.tool_call
|
||||
tool_name = tool_call.get("name")
|
||||
|
||||
if tool_name != self.tool_name:
|
||||
return await handler(request)
|
||||
|
||||
# Handle tool call
|
||||
try:
|
||||
args = tool_call.get("args", {})
|
||||
command = args.get("command")
|
||||
|
||||
if command == "view":
|
||||
return self._handle_view(args, tool_call["id"])
|
||||
if command == "create":
|
||||
return self._handle_create(args, tool_call["id"])
|
||||
if command == "str_replace":
|
||||
return self._handle_str_replace(args, tool_call["id"])
|
||||
if command == "insert":
|
||||
return self._handle_insert(args, tool_call["id"])
|
||||
if command == "delete":
|
||||
return self._handle_delete(args, tool_call["id"])
|
||||
if command == "rename":
|
||||
return self._handle_rename(args, tool_call["id"])
|
||||
|
||||
msg = f"Unknown command: {command}"
|
||||
return ToolMessage(
|
||||
content=msg,
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
return ToolMessage(
|
||||
content=str(e),
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
return await handler(request.override(**overrides))
|
||||
|
||||
def _validate_and_resolve_path(self, path: str) -> Path:
|
||||
"""Validate and resolve a virtual path to filesystem path.
|
||||
|
||||
@@ -3,93 +3,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ToolCallRequest,
|
||||
)
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
_CLAUDE_BASH_DESCRIPTOR = {"type": "bash_20250124", "name": "bash"}
|
||||
# Tool type constants for Anthropic
|
||||
BASH_TOOL_TYPE = "bash_20250124"
|
||||
BASH_TOOL_NAME = "bash"
|
||||
|
||||
|
||||
class ClaudeBashToolMiddleware(ShellToolMiddleware):
|
||||
"""Middleware that exposes Anthropic's native bash tool to models."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Initialize middleware without registering a client-side tool."""
|
||||
kwargs["shell_command"] = ("/bin/bash",)
|
||||
super().__init__(*args, **kwargs)
|
||||
# Remove the base tool so Claude's native descriptor is the sole entry.
|
||||
self._tool = None # type: ignore[assignment]
|
||||
self.tools = []
|
||||
def __init__(
|
||||
self,
|
||||
workspace_root: str | None = None,
|
||||
*,
|
||||
startup_commands: tuple[str, ...] | list[str] | str | None = None,
|
||||
shutdown_commands: tuple[str, ...] | list[str] | str | None = None,
|
||||
execution_policy: Any | None = None,
|
||||
redaction_rules: tuple[Any, ...] | list[Any] | None = None,
|
||||
tool_description: str | None = None,
|
||||
env: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize middleware for Claude's native bash tool.
|
||||
|
||||
Args:
|
||||
workspace_root: Base directory for the shell session.
|
||||
If omitted, a temporary directory is created.
|
||||
startup_commands: Optional commands executed after the session starts.
|
||||
shutdown_commands: Optional commands executed before session shutdown.
|
||||
execution_policy: Execution policy controlling timeouts and limits.
|
||||
redaction_rules: Optional redaction rules to sanitize output.
|
||||
tool_description: Optional override for tool description.
|
||||
env: Optional environment variables for the shell session.
|
||||
"""
|
||||
super().__init__(
|
||||
workspace_root=workspace_root,
|
||||
startup_commands=startup_commands,
|
||||
shutdown_commands=shutdown_commands,
|
||||
execution_policy=execution_policy,
|
||||
redaction_rules=redaction_rules,
|
||||
tool_description=tool_description,
|
||||
tool_name=BASH_TOOL_NAME,
|
||||
shell_command=("/bin/bash",),
|
||||
env=env,
|
||||
)
|
||||
# Parent class now creates the tool with name "bash" via tool_name parameter
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
"""Ensure the Claude bash descriptor is available to the model."""
|
||||
tools = request.tools
|
||||
if all(tool is not _CLAUDE_BASH_DESCRIPTOR for tool in tools):
|
||||
tools = [*tools, _CLAUDE_BASH_DESCRIPTOR]
|
||||
request = request.override(tools=tools)
|
||||
return handler(request)
|
||||
"""Replace parent's shell tool with Claude's bash descriptor."""
|
||||
filtered = [
|
||||
t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME
|
||||
]
|
||||
tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}]
|
||||
return handler(request.override(tools=tools))
|
||||
|
||||
def wrap_tool_call(
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Command | ToolMessage],
|
||||
) -> Command | ToolMessage:
|
||||
"""Intercept Claude bash tool calls and execute them locally."""
|
||||
tool_call = request.tool_call
|
||||
if tool_call.get("name") != "bash":
|
||||
return handler(request)
|
||||
resources = self._ensure_resources(request.state)
|
||||
return self._run_shell_tool(
|
||||
resources,
|
||||
tool_call["args"],
|
||||
tool_call_id=tool_call.get("id"),
|
||||
)
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[Command | ToolMessage]],
|
||||
) -> Command | ToolMessage:
|
||||
"""Async interception mirroring the synchronous implementation."""
|
||||
tool_call = request.tool_call
|
||||
if tool_call.get("name") != "bash":
|
||||
return await handler(request)
|
||||
resources = self._ensure_resources(request.state)
|
||||
return self._run_shell_tool(
|
||||
resources,
|
||||
tool_call["args"],
|
||||
tool_call_id=tool_call.get("id"),
|
||||
)
|
||||
|
||||
def _format_tool_message(
|
||||
self,
|
||||
content: str,
|
||||
tool_call_id: str | None,
|
||||
*,
|
||||
status: Literal["success", "error"],
|
||||
artifact: dict[str, Any] | None = None,
|
||||
) -> ToolMessage | str:
|
||||
"""Format tool responses using Claude's bash descriptor."""
|
||||
if tool_call_id is None:
|
||||
return content
|
||||
return ToolMessage(
|
||||
content=content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=_CLAUDE_BASH_DESCRIPTOR["name"],
|
||||
status=status,
|
||||
artifact=artifact or {},
|
||||
)
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
"""Async: replace parent's shell tool with Claude's bash descriptor."""
|
||||
filtered = [
|
||||
t for t in request.tools if getattr(t, "name", None) != BASH_TOOL_NAME
|
||||
]
|
||||
tools = [*filtered, {"type": BASH_TOOL_TYPE, "name": BASH_TOOL_NAME}]
|
||||
return await handler(request.override(tools=tools))
|
||||
|
||||
|
||||
__all__ = ["ClaudeBashToolMiddleware"]
|
||||
|
||||
@@ -9,13 +9,13 @@ from __future__ import annotations
|
||||
import fnmatch
|
||||
import re
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import TYPE_CHECKING, Annotated, Literal, cast
|
||||
from typing import TYPE_CHECKING, Literal, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware
|
||||
from langchain_core.tools import InjectedToolArg, tool
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
|
||||
from langchain_anthropic.middleware.anthropic_tools import AnthropicToolsState
|
||||
|
||||
@@ -128,9 +128,9 @@ class StateFileSearchMiddleware(AgentMiddleware):
|
||||
# Create tool instances
|
||||
@tool
|
||||
def glob_search( # noqa: D417
|
||||
runtime: ToolRuntime[None, AnthropicToolsState],
|
||||
pattern: str,
|
||||
path: str = "/",
|
||||
state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment]
|
||||
) -> str:
|
||||
"""Fast file pattern matching tool that works with any codebase size.
|
||||
|
||||
@@ -147,56 +147,17 @@ class StateFileSearchMiddleware(AgentMiddleware):
|
||||
time (most recently modified first). Returns "No files found" if no
|
||||
matches.
|
||||
"""
|
||||
# Normalize base path
|
||||
base_path = path if path.startswith("/") else "/" + path
|
||||
|
||||
# Get files from state
|
||||
files = cast("dict[str, Any]", state.get(self.state_key, {}))
|
||||
|
||||
# Match files
|
||||
matches = []
|
||||
for file_path, file_data in files.items():
|
||||
if file_path.startswith(base_path):
|
||||
# Get relative path from base
|
||||
if base_path == "/":
|
||||
relative = file_path[1:] # Remove leading /
|
||||
elif file_path == base_path:
|
||||
relative = Path(file_path).name
|
||||
elif file_path.startswith(base_path + "/"):
|
||||
relative = file_path[len(base_path) + 1 :]
|
||||
else:
|
||||
continue
|
||||
|
||||
# Match against pattern
|
||||
# Handle ** pattern which requires special care
|
||||
# PurePosixPath.match doesn't match single-level paths
|
||||
# against **/pattern
|
||||
is_match = PurePosixPath(relative).match(pattern)
|
||||
if not is_match and pattern.startswith("**/"):
|
||||
# Also try matching without the **/ prefix for files in base dir
|
||||
is_match = PurePosixPath(relative).match(pattern[3:])
|
||||
|
||||
if is_match:
|
||||
matches.append((file_path, file_data["modified_at"]))
|
||||
|
||||
if not matches:
|
||||
return "No files found"
|
||||
|
||||
# Sort by modification time
|
||||
matches.sort(key=lambda x: x[1], reverse=True)
|
||||
file_paths = [path for path, _ in matches]
|
||||
|
||||
return "\n".join(file_paths)
|
||||
return self._handle_glob_search(pattern, path, runtime.state)
|
||||
|
||||
@tool
|
||||
def grep_search( # noqa: D417
|
||||
runtime: ToolRuntime[None, AnthropicToolsState],
|
||||
pattern: str,
|
||||
path: str = "/",
|
||||
include: str | None = None,
|
||||
output_mode: Literal[
|
||||
"files_with_matches", "content", "count"
|
||||
] = "files_with_matches",
|
||||
state: Annotated[AnthropicToolsState, InjectedToolArg] = None, # type: ignore[assignment]
|
||||
) -> str:
|
||||
"""Fast content search tool that works with any codebase size.
|
||||
|
||||
@@ -216,49 +177,133 @@ class StateFileSearchMiddleware(AgentMiddleware):
|
||||
Search results formatted according to output_mode. Returns "No matches
|
||||
found" if no results.
|
||||
"""
|
||||
# Normalize base path
|
||||
base_path = path if path.startswith("/") else "/" + path
|
||||
|
||||
# Compile regex pattern (for validation)
|
||||
try:
|
||||
regex = re.compile(pattern)
|
||||
except re.error as e:
|
||||
return f"Invalid regex pattern: {e}"
|
||||
|
||||
if include and not _is_valid_include_pattern(include):
|
||||
return "Invalid include pattern"
|
||||
|
||||
# Search files
|
||||
files = cast("dict[str, Any]", state.get(self.state_key, {}))
|
||||
results: dict[str, list[tuple[int, str]]] = {}
|
||||
|
||||
for file_path, file_data in files.items():
|
||||
if not file_path.startswith(base_path):
|
||||
continue
|
||||
|
||||
# Check include filter
|
||||
if include:
|
||||
basename = Path(file_path).name
|
||||
if not _match_include_pattern(basename, include):
|
||||
continue
|
||||
|
||||
# Search file content
|
||||
for line_num, line in enumerate(file_data["content"], 1):
|
||||
if regex.search(line):
|
||||
if file_path not in results:
|
||||
results[file_path] = []
|
||||
results[file_path].append((line_num, line))
|
||||
|
||||
if not results:
|
||||
return "No matches found"
|
||||
|
||||
# Format output based on mode
|
||||
return self._format_grep_results(results, output_mode)
|
||||
return self._handle_grep_search(
|
||||
pattern, path, include, output_mode, runtime.state
|
||||
)
|
||||
|
||||
self.glob_search = glob_search
|
||||
self.grep_search = grep_search
|
||||
self.tools = [glob_search, grep_search]
|
||||
|
||||
def _handle_glob_search(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str,
|
||||
state: AnthropicToolsState,
|
||||
) -> str:
|
||||
"""Handle glob search operation.
|
||||
|
||||
Args:
|
||||
pattern: The glob pattern to match files against.
|
||||
path: The directory to search in.
|
||||
state: The current agent state.
|
||||
|
||||
Returns:
|
||||
Newline-separated list of matching file paths, sorted by modification
|
||||
time (most recently modified first). Returns "No files found" if no
|
||||
matches.
|
||||
"""
|
||||
# Normalize base path
|
||||
base_path = path if path.startswith("/") else "/" + path
|
||||
|
||||
# Get files from state
|
||||
files = cast("dict[str, Any]", state.get(self.state_key, {}))
|
||||
|
||||
# Match files
|
||||
matches = []
|
||||
for file_path, file_data in files.items():
|
||||
if file_path.startswith(base_path):
|
||||
# Get relative path from base
|
||||
if base_path == "/":
|
||||
relative = file_path[1:] # Remove leading /
|
||||
elif file_path == base_path:
|
||||
relative = Path(file_path).name
|
||||
elif file_path.startswith(base_path + "/"):
|
||||
relative = file_path[len(base_path) + 1 :]
|
||||
else:
|
||||
continue
|
||||
|
||||
# Match against pattern
|
||||
# Handle ** pattern which requires special care
|
||||
# PurePosixPath.match doesn't match single-level paths
|
||||
# against **/pattern
|
||||
is_match = PurePosixPath(relative).match(pattern)
|
||||
if not is_match and pattern.startswith("**/"):
|
||||
# Also try matching without the **/ prefix for files in base dir
|
||||
is_match = PurePosixPath(relative).match(pattern[3:])
|
||||
|
||||
if is_match:
|
||||
matches.append((file_path, file_data["modified_at"]))
|
||||
|
||||
if not matches:
|
||||
return "No files found"
|
||||
|
||||
# Sort by modification time
|
||||
matches.sort(key=lambda x: x[1], reverse=True)
|
||||
file_paths = [path for path, _ in matches]
|
||||
|
||||
return "\n".join(file_paths)
|
||||
|
||||
def _handle_grep_search(
|
||||
self,
|
||||
pattern: str,
|
||||
path: str,
|
||||
include: str | None,
|
||||
output_mode: str,
|
||||
state: AnthropicToolsState,
|
||||
) -> str:
|
||||
"""Handle grep search operation.
|
||||
|
||||
Args:
|
||||
pattern: The regular expression pattern to search for in file contents.
|
||||
path: The directory to search in.
|
||||
include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
|
||||
output_mode: Output format.
|
||||
state: The current agent state.
|
||||
|
||||
Returns:
|
||||
Search results formatted according to output_mode. Returns "No matches
|
||||
found" if no results.
|
||||
"""
|
||||
# Normalize base path
|
||||
base_path = path if path.startswith("/") else "/" + path
|
||||
|
||||
# Compile regex pattern (for validation)
|
||||
try:
|
||||
regex = re.compile(pattern)
|
||||
except re.error as e:
|
||||
return f"Invalid regex pattern: {e}"
|
||||
|
||||
if include and not _is_valid_include_pattern(include):
|
||||
return "Invalid include pattern"
|
||||
|
||||
# Search files
|
||||
files = cast("dict[str, Any]", state.get(self.state_key, {}))
|
||||
results: dict[str, list[tuple[int, str]]] = {}
|
||||
|
||||
for file_path, file_data in files.items():
|
||||
if not file_path.startswith(base_path):
|
||||
continue
|
||||
|
||||
# Check include filter
|
||||
if include:
|
||||
basename = Path(file_path).name
|
||||
if not _match_include_pattern(basename, include):
|
||||
continue
|
||||
|
||||
# Search file content
|
||||
for line_num, line in enumerate(file_data["content"], 1):
|
||||
if regex.search(line):
|
||||
if file_path not in results:
|
||||
results[file_path] = []
|
||||
results[file_path].append((line_num, line))
|
||||
|
||||
if not results:
|
||||
return "No matches found"
|
||||
|
||||
# Format output based on mode
|
||||
return self._format_grep_results(results, output_mode)
|
||||
|
||||
def _format_grep_results(
|
||||
self,
|
||||
results: dict[str, list[tuple[int, str]]],
|
||||
|
||||
@@ -9,11 +9,11 @@ license = { text = "MIT" }
|
||||
readme = "README.md"
|
||||
authors = []
|
||||
|
||||
version = "1.0.2"
|
||||
version = "1.0.4"
|
||||
requires-python = ">=3.10.0,<4.0.0"
|
||||
dependencies = [
|
||||
"anthropic>=0.69.0,<1.0.0",
|
||||
"langchain-core>=1.0.3,<2.0.0",
|
||||
"langchain-core>=1.0.4,<2.0.0",
|
||||
"pydantic>=2.7.4,<3.0.0",
|
||||
]
|
||||
|
||||
|
||||
@@ -3,67 +3,59 @@ from __future__ import annotations
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
|
||||
pytest.importorskip(
|
||||
"anthropic", reason="Anthropic SDK is required for Claude middleware tests"
|
||||
)
|
||||
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from langchain_anthropic.middleware.bash import ClaudeBashToolMiddleware
|
||||
|
||||
|
||||
def test_wrap_tool_call_handles_claude_bash(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_creates_bash_tool(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test that ClaudeBashToolMiddleware creates a tool named 'bash'."""
|
||||
middleware = ClaudeBashToolMiddleware()
|
||||
sentinel = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
|
||||
|
||||
monkeypatch.setattr(middleware, "_run_shell_tool", MagicMock(return_value=sentinel))
|
||||
monkeypatch.setattr(
|
||||
middleware, "_ensure_resources", MagicMock(return_value=MagicMock())
|
||||
# Should have exactly one tool registered (from parent)
|
||||
assert len(middleware.tools) == 1
|
||||
|
||||
# Tool is named "bash" (via tool_name parameter)
|
||||
bash_tool = middleware.tools[0]
|
||||
assert bash_tool.name == "bash"
|
||||
|
||||
|
||||
def test_replaces_tool_with_claude_descriptor() -> None:
|
||||
"""Test wrap_model_call replaces bash tool with Claude's bash descriptor."""
|
||||
from langchain.agents.middleware.types import ModelRequest
|
||||
|
||||
middleware = ClaudeBashToolMiddleware()
|
||||
|
||||
# Create a mock request with the bash tool (inherited from parent)
|
||||
bash_tool = middleware.tools[0]
|
||||
request = ModelRequest(
|
||||
model=MagicMock(),
|
||||
system_prompt=None,
|
||||
messages=[],
|
||||
tool_choice=None,
|
||||
tools=[bash_tool],
|
||||
response_format=None,
|
||||
state={"messages": []},
|
||||
runtime=MagicMock(),
|
||||
)
|
||||
|
||||
tool_call: ToolCall = {
|
||||
# Mock handler that captures the modified request
|
||||
captured_request = None
|
||||
|
||||
def handler(req: ModelRequest) -> MagicMock:
|
||||
nonlocal captured_request
|
||||
captured_request = req
|
||||
return MagicMock()
|
||||
|
||||
middleware.wrap_model_call(request, handler)
|
||||
|
||||
# The bash tool should be replaced with Claude's native bash descriptor
|
||||
assert captured_request is not None
|
||||
assert len(captured_request.tools) == 1
|
||||
assert captured_request.tools[0] == {
|
||||
"type": "bash_20250124",
|
||||
"name": "bash",
|
||||
"args": {"command": "echo hi"},
|
||||
"id": "call-1",
|
||||
}
|
||||
request = ToolCallRequest(
|
||||
tool_call=tool_call,
|
||||
tool=MagicMock(),
|
||||
state={},
|
||||
runtime=None, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
handler_called = False
|
||||
|
||||
def handler(_: ToolCallRequest) -> ToolMessage:
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
return ToolMessage(content="should not be used", tool_call_id="call-1")
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
assert result is sentinel
|
||||
assert handler_called is False
|
||||
|
||||
|
||||
def test_wrap_tool_call_passes_through_other_tools(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
middleware = ClaudeBashToolMiddleware()
|
||||
tool_call: ToolCall = {"name": "other", "args": {}, "id": "call-2"}
|
||||
request = ToolCallRequest(
|
||||
tool_call=tool_call,
|
||||
tool=MagicMock(),
|
||||
state={},
|
||||
runtime=None, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
sentinel = ToolMessage(content="handled", tool_call_id="call-2", name="other")
|
||||
|
||||
def handler(_: ToolCallRequest) -> ToolMessage:
|
||||
return sentinel
|
||||
|
||||
result = middleware.wrap_tool_call(request, handler)
|
||||
assert result is sentinel
|
||||
|
||||
@@ -49,8 +49,10 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
# Call tool function directly (state is injected in real usage)
|
||||
result = middleware.glob_search.func(pattern="*.py", state=test_state) # type: ignore[attr-defined]
|
||||
# Call internal handler method directly
|
||||
result = middleware._handle_glob_search(
|
||||
pattern="*.py", path="/", state=test_state
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -82,7 +84,9 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func(pattern="**/*.py", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(
|
||||
pattern="**/*.py", path="/", state=state
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
lines = result.split("\n")
|
||||
@@ -109,7 +113,7 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func( # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(
|
||||
pattern="**/*.py", path="/src", state=state
|
||||
)
|
||||
|
||||
@@ -132,7 +136,7 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func(pattern="*.ts", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(pattern="*.ts", path="/", state=state)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == "No files found"
|
||||
@@ -157,7 +161,7 @@ class TestGlobSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func(pattern="*.py", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(pattern="*.py", path="/", state=state)
|
||||
|
||||
lines = result.split("\n")
|
||||
# Most recent first
|
||||
@@ -193,7 +197,13 @@ class TestGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r"def \w+\(\):", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"def \w+\(\):",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -216,8 +226,12 @@ class TestGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern=r"def", include="*.{py", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"def",
|
||||
path="/",
|
||||
include="*.{py",
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert result == "Invalid include pattern"
|
||||
@@ -241,8 +255,12 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern=r"def \w+\(\):", output_mode="content", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"def \w+\(\):",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="content",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
@@ -271,8 +289,8 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern=r"TODO", output_mode="count", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"TODO", path="/", include=None, output_mode="count", state=state
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
@@ -300,8 +318,12 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern="import", include="*.py", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern="import",
|
||||
path="/",
|
||||
include="*.py",
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
@@ -333,8 +355,12 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func( # type: ignore[attr-defined]
|
||||
pattern="const", include="*.{ts,tsx}", state=state
|
||||
result = middleware._handle_grep_search(
|
||||
pattern="const",
|
||||
path="/",
|
||||
include="*.{ts,tsx}",
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
@@ -362,7 +388,13 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern="import", path="/src", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern="import",
|
||||
path="/src",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -383,7 +415,13 @@ class TestFilesystemGrepSearch:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"TODO",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == "No matches found"
|
||||
@@ -397,7 +435,13 @@ class TestFilesystemGrepSearch:
|
||||
"text_editor_files": {},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r"[unclosed", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"[unclosed",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "Invalid regex pattern" in result
|
||||
@@ -428,7 +472,7 @@ class TestSearchWithDifferentBackends:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.glob_search.func(pattern="**/*", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_glob_search(pattern="**/*", path="/", state=state)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -457,7 +501,13 @@ class TestSearchWithDifferentBackends:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r"TODO", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r"TODO",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
@@ -486,7 +536,13 @@ class TestSearchWithDifferentBackends:
|
||||
},
|
||||
}
|
||||
|
||||
result = middleware.grep_search.func(pattern=r".*", state=state) # type: ignore[attr-defined]
|
||||
result = middleware._handle_grep_search(
|
||||
pattern=r".*",
|
||||
path="/",
|
||||
include=None,
|
||||
output_mode="files_with_matches",
|
||||
state=state,
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "/src/main.py" in result
|
||||
|
||||
10
libs/partners/anthropic/uv.lock
generated
10
libs/partners/anthropic/uv.lock
generated
@@ -21,7 +21,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "anthropic"
|
||||
version = "0.71.0"
|
||||
version = "0.72.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
@@ -33,9 +33,9 @@ dependencies = [
|
||||
{ name = "sniffio" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/82/4f/70682b068d897841f43223df82d96ec1d617435a8b759c4a2d901a50158b/anthropic-0.71.0.tar.gz", hash = "sha256:eb8e6fa86d049061b3ef26eb4cbae0174ebbff21affa6de7b3098da857d8de6a", size = 489102, upload-time = "2025-10-16T15:54:40.08Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/dd/f3/feb750a21461090ecf48bbebcaa261cd09003cc1d14e2fa9643ad59edd4d/anthropic-0.72.1.tar.gz", hash = "sha256:a6d1d660e1f4af91dddc732f340786d19acaffa1ae8e69442e56be5fa6539d51", size = 415395, upload-time = "2025-11-11T16:53:29.001Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/77/073e8ac488f335aec7001952825275582fb8f433737e90f24eeef9d878f6/anthropic-0.71.0-py3-none-any.whl", hash = "sha256:85c5015fcdbdc728390f11b17642a65a4365d03b12b799b18b6cc57e71fdb327", size = 355035, upload-time = "2025-10-16T15:54:38.238Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/05/d9d45edad1aa28330cea09a3b35e1590f7279f91bb5ab5237c70a0884ea3/anthropic-0.72.1-py3-none-any.whl", hash = "sha256:81e73cca55e8924776c8c4418003defe6bf9eaf0cd92beb94c8dbf537b95316f", size = 357373, upload-time = "2025-11-11T16:53:27.438Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -477,7 +477,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain"
|
||||
version = "1.0.4"
|
||||
version = "1.0.6"
|
||||
source = { editable = "../../langchain_v1" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
@@ -541,7 +541,7 @@ typing = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-anthropic"
|
||||
version = "1.0.2"
|
||||
version = "1.0.4"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "anthropic" },
|
||||
|
||||
Reference in New Issue
Block a user