fix: resumable shell, works w/ interrupts (#33978)

fixes https://github.com/langchain-ai/langchain/issues/33684

With this change, the following minimal snippet runs successfully:

```py
import os

from langchain.agents import create_agent
from langchain.agents.middleware import (
    HostExecutionPolicy,
    HumanInTheLoopMiddleware,
    ShellToolMiddleware,
)
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import Command


shell_middleware = ShellToolMiddleware(
    workspace_root=os.getcwd(),
    env=os.environ,  # danger
    execution_policy=HostExecutionPolicy()
)

hil_middleware = HumanInTheLoopMiddleware(interrupt_on={"shell": True})

checkpointer = InMemorySaver()

agent = create_agent(
    "openai:gpt-4.1-mini",
    middleware=[shell_middleware, hil_middleware],
    checkpointer=checkpointer,
)

input_message = {"role": "user", "content": "run `which python`"}

config = {"configurable": {"thread_id": "1"}}

result = agent.invoke(
    {"messages": [input_message]},
    config=config,
    durability="exit",
)
```
This commit is contained in:
Sydney Runkle
2025-11-14 15:32:25 -05:00
committed by GitHub
parent 6aa3794b74
commit 9bd401a6d4
3 changed files with 136 additions and 38 deletions

View File

@@ -15,7 +15,7 @@ import uuid
import weakref
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import ToolException
@@ -339,7 +339,7 @@ class _ShellToolInput(BaseModel):
restart: bool | None = None
"""Whether to restart the shell session."""
runtime: Annotated[Any, SkipJsonSchema] = None
runtime: Annotated[Any, SkipJsonSchema()] = None
"""The runtime for the shell tool.
Included as a workaround at the moment bc args_schema doesn't work with
@@ -445,7 +445,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
command: str | None = None,
restart: bool = False,
) -> ToolMessage | str:
resources = self._ensure_resources(runtime.state)
resources = self._get_or_create_resources(runtime.state)
return self._run_shell_tool(
resources,
{"command": command, "restart": restart},
@@ -491,7 +491,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Start the shell session and run startup commands."""
resources = self._create_resources()
resources = self._get_or_create_resources(state)
return {"shell_session_resources": resources}
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
@@ -500,7 +500,10 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
"""Run shutdown commands and release resources when an agent completes."""
resources = self._ensure_resources(state)
resources = state.get("shell_session_resources")
if not isinstance(resources, _SessionResources):
# Resources were never created, nothing to clean up
return
try:
self._run_shutdown_commands(resources.session)
finally:
@@ -510,17 +513,26 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
"""Async run shutdown commands and release resources when an agent completes."""
return self.after_agent(state, runtime)
def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources:
"""Get existing resources from state or create new ones if they don't exist.
This method enables resumability by checking if resources already exist in the state
(e.g., after an interrupt), and only creating new resources if they're not present.
Args:
state: The agent state which may contain shell session resources.
Returns:
Session resources, either retrieved from state or newly created.
"""
resources = state.get("shell_session_resources")
if resources is not None and not isinstance(resources, _SessionResources):
resources = None
if resources is None:
msg = (
"Shell session resources are unavailable. Ensure `before_agent` ran successfully "
"before invoking the shell tool."
)
raise ToolException(msg)
return resources
if isinstance(resources, _SessionResources):
return resources
new_resources = self._create_resources()
# Cast needed to make state dict-like for mutation
cast("dict[str, Any]", state)["shell_session_resources"] = new_resources
return new_resources
def _create_resources(self) -> _SessionResources:
workspace = self._workspace_root

View File

@@ -32,7 +32,7 @@ def test_executes_command_and_persists_state(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
middleware._run_shell_tool(resources, {"command": "cd /"}, tool_call_id=None)
result = middleware._run_shell_tool(resources, {"command": "pwd"}, tool_call_id=None)
@@ -55,14 +55,14 @@ def test_restart_resets_session_environment(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
middleware._run_shell_tool(resources, {"command": "export FOO=bar"}, tool_call_id=None)
restart_message = middleware._run_shell_tool(
resources, {"restart": True}, tool_call_id=None
)
assert "restarted" in restart_message.lower()
resources = middleware._ensure_resources(state) # reacquire after restart
resources = middleware._get_or_create_resources(state) # reacquire after restart
result = middleware._run_shell_tool(
resources, {"command": "echo ${FOO:-unset}"}, tool_call_id=None
)
@@ -81,7 +81,7 @@ def test_truncation_indicator_present(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(resources, {"command": "seq 1 20"}, tool_call_id=None)
assert "Output truncated" in result
finally:
@@ -98,7 +98,7 @@ def test_timeout_returns_error(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
start = time.monotonic()
result = middleware._run_shell_tool(resources, {"command": "sleep 2"}, tool_call_id=None)
elapsed = time.monotonic() - start
@@ -120,7 +120,7 @@ def test_redaction_policy_applies(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
message = middleware._run_shell_tool(
resources,
{"command": "printf 'Contact: user@example.com\\n'"},
@@ -222,7 +222,7 @@ def test_normalize_env_coercion(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(
resources, {"command": "echo $NUM $BOOL"}, tool_call_id=None
)
@@ -242,7 +242,7 @@ def test_shell_tool_missing_command_string(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
with pytest.raises(ToolException, match="expects a 'command' string"):
middleware._run_shell_tool(resources, {"command": None}, tool_call_id=None)
@@ -267,7 +267,7 @@ def test_tool_message_formatting_with_id(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(
resources, {"command": "echo test"}, tool_call_id="test-id-123"
@@ -292,7 +292,7 @@ def test_nonzero_exit_code_returns_error(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(
resources,
@@ -319,7 +319,7 @@ def test_truncation_by_bytes(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(
resources, {"command": "python3 -c 'print(\"x\" * 100)'"}, tool_call_id=None
@@ -379,15 +379,6 @@ def test_shutdown_command_timeout_logged(tmp_path: Path) -> None:
middleware.after_agent(state, None)
def test_ensure_resources_missing_state(tmp_path: Path) -> None:
"""Test that _ensure_resources raises when resources are missing."""
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
state: AgentState = _empty_state()
with pytest.raises(ToolException, match="Shell session resources are unavailable"):
middleware._ensure_resources(state) # type: ignore[attr-defined]
def test_empty_output_replaced_with_no_output(tmp_path: Path) -> None:
"""Test that empty command output is replaced with '<no output>'."""
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
@@ -396,7 +387,7 @@ def test_empty_output_replaced_with_no_output(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(
resources,
@@ -419,7 +410,7 @@ def test_stderr_output_labeling(tmp_path: Path) -> None:
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(
resources, {"command": "echo error >&2"}, tool_call_id=None
@@ -468,3 +459,98 @@ def test_async_methods_delegate_to_sync(tmp_path: Path) -> None:
asyncio.run(middleware.aafter_agent(state, None))
finally:
pass
def test_shell_middleware_resumable_after_interrupt(tmp_path: Path) -> None:
    """Test that shell middleware is resumable after an interrupt.

    This test simulates a scenario where:

    1. The middleware creates a shell session
    2. A command is executed
    3. The agent is interrupted (state is preserved)
    4. The agent resumes with the same state
    5. The shell session is reused (not recreated)

    Cleanup runs in a ``finally`` block so the shell session is released even
    when one of the assertions fails.
    """
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)
    state: AgentState = _empty_state()

    try:
        # Simulate first execution (before interrupt)
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)

        # Get the resources and verify they exist
        resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
        initial_session = resources.session
        initial_tempdir = resources.tempdir

        # Execute a command to set state inside the shell session
        middleware._run_shell_tool(
            resources, {"command": "export TEST_VAR=hello"}, tool_call_id=None
        )

        # Simulate interrupt - state is preserved, but we don't call after_agent
        # In a real scenario, the state would be checkpointed here

        # Simulate resumption - call before_agent again with same state
        # This should reuse existing resources, not create new ones
        updates = middleware.before_agent(state, None)
        if updates:
            state.update(updates)

        # Get resources again - should be the same session
        resumed_resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

        # Verify the session was reused (same object reference)
        assert resumed_resources.session is initial_session
        assert resumed_resources.tempdir is initial_tempdir

        # Verify the session state persisted (environment variable still set)
        result = middleware._run_shell_tool(
            resumed_resources, {"command": "echo ${TEST_VAR:-unset}"}, tool_call_id=None
        )
        assert "hello" in result
        assert "unset" not in result
    finally:
        # Clean up; after_agent returns early if no resources were ever stored in state
        middleware.after_agent(state, None)
def test_get_or_create_resources_creates_when_missing(tmp_path: Path) -> None:
    """Test that _get_or_create_resources creates resources when they don't exist.

    The created resources must also be written back into the state under the
    ``shell_session_resources`` key. The finalizer is invoked in a ``finally``
    block so the session is released even if an assertion fails.
    """
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)
    state: AgentState = _empty_state()

    # State has no resources initially
    assert "shell_session_resources" not in state

    # Call _get_or_create_resources - should create new resources
    resources = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
    try:
        assert isinstance(resources, _SessionResources)
        assert resources.session is not None
        assert state.get("shell_session_resources") is resources
    finally:
        # Clean up
        resources._finalizer()
def test_get_or_create_resources_reuses_existing(tmp_path: Path) -> None:
    """Test that _get_or_create_resources reuses existing resources.

    Two consecutive calls with the same state must return the very same
    resources object. The finalizer is invoked in a ``finally`` block so the
    session is released even if an assertion fails.
    """
    workspace = tmp_path / "workspace"
    middleware = ShellToolMiddleware(workspace_root=workspace)
    state: AgentState = _empty_state()

    # Create resources first time
    resources1 = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]
    try:
        # Call again - should return the same resources
        resources2 = middleware._get_or_create_resources(state)  # type: ignore[attr-defined]

        assert resources1 is resources2
        assert resources1.session is resources2.session
    finally:
        # Clean up
        resources1._finalizer()

View File

@@ -2174,7 +2174,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "1.0.4"
version = "1.0.5"
source = { editable = "../core" }
dependencies = [
{ name = "jsonpatch" },