mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
fix: resumable shell, works w/ interrupts (#33978)
fixes https://github.com/langchain-ai/langchain/issues/33684 Now able to run this minimal snippet successfully ```py import os from langchain.agents import create_agent from langchain.agents.middleware import ( HostExecutionPolicy, HumanInTheLoopMiddleware, ShellToolMiddleware, ) from langgraph.checkpoint.memory import InMemorySaver from langgraph.types import Command shell_middleware = ShellToolMiddleware( workspace_root=os.getcwd(), env=os.environ, # danger execution_policy=HostExecutionPolicy() ) hil_middleware = HumanInTheLoopMiddleware(interrupt_on={"shell": True}) checkpointer = InMemorySaver() agent = create_agent( "openai:gpt-4.1-mini", middleware=[shell_middleware, hil_middleware], checkpointer=checkpointer, ) input_message = {"role": "user", "content": "run `which python`"} config = {"configurable": {"thread_id": "1"}} result = agent.invoke( {"messages": [input_message]}, config=config, durability="exit", ) ```
This commit is contained in:
@@ -15,7 +15,7 @@ import uuid
|
||||
import weakref
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools.base import ToolException
|
||||
@@ -339,7 +339,7 @@ class _ShellToolInput(BaseModel):
|
||||
restart: bool | None = None
|
||||
"""Whether to restart the shell session."""
|
||||
|
||||
runtime: Annotated[Any, SkipJsonSchema] = None
|
||||
runtime: Annotated[Any, SkipJsonSchema()] = None
|
||||
"""The runtime for the shell tool.
|
||||
|
||||
Included as a workaround at the moment bc args_schema doesn't work with
|
||||
@@ -445,7 +445,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
command: str | None = None,
|
||||
restart: bool = False,
|
||||
) -> ToolMessage | str:
|
||||
resources = self._ensure_resources(runtime.state)
|
||||
resources = self._get_or_create_resources(runtime.state)
|
||||
return self._run_shell_tool(
|
||||
resources,
|
||||
{"command": command, "restart": restart},
|
||||
@@ -491,7 +491,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
|
||||
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Start the shell session and run startup commands."""
|
||||
resources = self._create_resources()
|
||||
resources = self._get_or_create_resources(state)
|
||||
return {"shell_session_resources": resources}
|
||||
|
||||
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
@@ -500,7 +500,10 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
|
||||
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
|
||||
"""Run shutdown commands and release resources when an agent completes."""
|
||||
resources = self._ensure_resources(state)
|
||||
resources = state.get("shell_session_resources")
|
||||
if not isinstance(resources, _SessionResources):
|
||||
# Resources were never created, nothing to clean up
|
||||
return
|
||||
try:
|
||||
self._run_shutdown_commands(resources.session)
|
||||
finally:
|
||||
@@ -510,17 +513,26 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||||
"""Async run shutdown commands and release resources when an agent completes."""
|
||||
return self.after_agent(state, runtime)
|
||||
|
||||
def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
|
||||
def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources:
|
||||
"""Get existing resources from state or create new ones if they don't exist.
|
||||
|
||||
This method enables resumability by checking if resources already exist in the state
|
||||
(e.g., after an interrupt), and only creating new resources if they're not present.
|
||||
|
||||
Args:
|
||||
state: The agent state which may contain shell session resources.
|
||||
|
||||
Returns:
|
||||
Session resources, either retrieved from state or newly created.
|
||||
"""
|
||||
resources = state.get("shell_session_resources")
|
||||
if resources is not None and not isinstance(resources, _SessionResources):
|
||||
resources = None
|
||||
if resources is None:
|
||||
msg = (
|
||||
"Shell session resources are unavailable. Ensure `before_agent` ran successfully "
|
||||
"before invoking the shell tool."
|
||||
)
|
||||
raise ToolException(msg)
|
||||
return resources
|
||||
if isinstance(resources, _SessionResources):
|
||||
return resources
|
||||
|
||||
new_resources = self._create_resources()
|
||||
# Cast needed to make state dict-like for mutation
|
||||
cast("dict[str, Any]", state)["shell_session_resources"] = new_resources
|
||||
return new_resources
|
||||
|
||||
def _create_resources(self) -> _SessionResources:
|
||||
workspace = self._workspace_root
|
||||
|
||||
@@ -32,7 +32,7 @@ def test_executes_command_and_persists_state(tmp_path: Path) -> None:
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
middleware._run_shell_tool(resources, {"command": "cd /"}, tool_call_id=None)
|
||||
result = middleware._run_shell_tool(resources, {"command": "pwd"}, tool_call_id=None)
|
||||
@@ -55,14 +55,14 @@ def test_restart_resets_session_environment(tmp_path: Path) -> None:
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
middleware._run_shell_tool(resources, {"command": "export FOO=bar"}, tool_call_id=None)
|
||||
restart_message = middleware._run_shell_tool(
|
||||
resources, {"restart": True}, tool_call_id=None
|
||||
)
|
||||
assert "restarted" in restart_message.lower()
|
||||
resources = middleware._ensure_resources(state) # reacquire after restart
|
||||
resources = middleware._get_or_create_resources(state) # reacquire after restart
|
||||
result = middleware._run_shell_tool(
|
||||
resources, {"command": "echo ${FOO:-unset}"}, tool_call_id=None
|
||||
)
|
||||
@@ -81,7 +81,7 @@ def test_truncation_indicator_present(tmp_path: Path) -> None:
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
result = middleware._run_shell_tool(resources, {"command": "seq 1 20"}, tool_call_id=None)
|
||||
assert "Output truncated" in result
|
||||
finally:
|
||||
@@ -98,7 +98,7 @@ def test_timeout_returns_error(tmp_path: Path) -> None:
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
start = time.monotonic()
|
||||
result = middleware._run_shell_tool(resources, {"command": "sleep 2"}, tool_call_id=None)
|
||||
elapsed = time.monotonic() - start
|
||||
@@ -120,7 +120,7 @@ def test_redaction_policy_applies(tmp_path: Path) -> None:
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
message = middleware._run_shell_tool(
|
||||
resources,
|
||||
{"command": "printf 'Contact: user@example.com\\n'"},
|
||||
@@ -222,7 +222,7 @@ def test_normalize_env_coercion(tmp_path: Path) -> None:
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
result = middleware._run_shell_tool(
|
||||
resources, {"command": "echo $NUM $BOOL"}, tool_call_id=None
|
||||
)
|
||||
@@ -242,7 +242,7 @@ def test_shell_tool_missing_command_string(tmp_path: Path) -> None:
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
with pytest.raises(ToolException, match="expects a 'command' string"):
|
||||
middleware._run_shell_tool(resources, {"command": None}, tool_call_id=None)
|
||||
@@ -267,7 +267,7 @@ def test_tool_message_formatting_with_id(tmp_path: Path) -> None:
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
result = middleware._run_shell_tool(
|
||||
resources, {"command": "echo test"}, tool_call_id="test-id-123"
|
||||
@@ -292,7 +292,7 @@ def test_nonzero_exit_code_returns_error(tmp_path: Path) -> None:
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
result = middleware._run_shell_tool(
|
||||
resources,
|
||||
@@ -319,7 +319,7 @@ def test_truncation_by_bytes(tmp_path: Path) -> None:
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
result = middleware._run_shell_tool(
|
||||
resources, {"command": "python3 -c 'print(\"x\" * 100)'"}, tool_call_id=None
|
||||
@@ -379,15 +379,6 @@ def test_shutdown_command_timeout_logged(tmp_path: Path) -> None:
|
||||
middleware.after_agent(state, None)
|
||||
|
||||
|
||||
def test_ensure_resources_missing_state(tmp_path: Path) -> None:
|
||||
"""Test that _ensure_resources raises when resources are missing."""
|
||||
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
|
||||
state: AgentState = _empty_state()
|
||||
|
||||
with pytest.raises(ToolException, match="Shell session resources are unavailable"):
|
||||
middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_empty_output_replaced_with_no_output(tmp_path: Path) -> None:
|
||||
"""Test that empty command output is replaced with '<no output>'."""
|
||||
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
|
||||
@@ -396,7 +387,7 @@ def test_empty_output_replaced_with_no_output(tmp_path: Path) -> None:
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
result = middleware._run_shell_tool(
|
||||
resources,
|
||||
@@ -419,7 +410,7 @@ def test_stderr_output_labeling(tmp_path: Path) -> None:
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
result = middleware._run_shell_tool(
|
||||
resources, {"command": "echo error >&2"}, tool_call_id=None
|
||||
@@ -468,3 +459,98 @@ def test_async_methods_delegate_to_sync(tmp_path: Path) -> None:
|
||||
asyncio.run(middleware.aafter_agent(state, None))
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
def test_shell_middleware_resumable_after_interrupt(tmp_path: Path) -> None:
|
||||
"""Test that shell middleware is resumable after an interrupt.
|
||||
|
||||
This test simulates a scenario where:
|
||||
1. The middleware creates a shell session
|
||||
2. A command is executed
|
||||
3. The agent is interrupted (state is preserved)
|
||||
4. The agent resumes with the same state
|
||||
5. The shell session is reused (not recreated)
|
||||
"""
|
||||
workspace = tmp_path / "workspace"
|
||||
middleware = ShellToolMiddleware(workspace_root=workspace)
|
||||
|
||||
# Simulate first execution (before interrupt)
|
||||
state: AgentState = _empty_state()
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
|
||||
# Get the resources and verify they exist
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
initial_session = resources.session
|
||||
initial_tempdir = resources.tempdir
|
||||
|
||||
# Execute a command to set state
|
||||
middleware._run_shell_tool(resources, {"command": "export TEST_VAR=hello"}, tool_call_id=None)
|
||||
|
||||
# Simulate interrupt - state is preserved, but we don't call after_agent
|
||||
# In a real scenario, the state would be checkpointed here
|
||||
|
||||
# Simulate resumption - call before_agent again with same state
|
||||
# This should reuse existing resources, not create new ones
|
||||
updates = middleware.before_agent(state, None)
|
||||
if updates:
|
||||
state.update(updates)
|
||||
|
||||
# Get resources again - should be the same session
|
||||
resumed_resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
# Verify the session was reused (same object reference)
|
||||
assert resumed_resources.session is initial_session
|
||||
assert resumed_resources.tempdir is initial_tempdir
|
||||
|
||||
# Verify the session state persisted (environment variable still set)
|
||||
result = middleware._run_shell_tool(
|
||||
resumed_resources, {"command": "echo ${TEST_VAR:-unset}"}, tool_call_id=None
|
||||
)
|
||||
assert "hello" in result
|
||||
assert "unset" not in result
|
||||
|
||||
# Clean up
|
||||
middleware.after_agent(state, None)
|
||||
|
||||
|
||||
def test_get_or_create_resources_creates_when_missing(tmp_path: Path) -> None:
|
||||
"""Test that _get_or_create_resources creates resources when they don't exist."""
|
||||
workspace = tmp_path / "workspace"
|
||||
middleware = ShellToolMiddleware(workspace_root=workspace)
|
||||
|
||||
state: AgentState = _empty_state()
|
||||
|
||||
# State has no resources initially
|
||||
assert "shell_session_resources" not in state
|
||||
|
||||
# Call _get_or_create_resources - should create new resources
|
||||
resources = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
assert isinstance(resources, _SessionResources)
|
||||
assert resources.session is not None
|
||||
assert state.get("shell_session_resources") is resources
|
||||
|
||||
# Clean up
|
||||
resources._finalizer()
|
||||
|
||||
|
||||
def test_get_or_create_resources_reuses_existing(tmp_path: Path) -> None:
|
||||
"""Test that _get_or_create_resources reuses existing resources."""
|
||||
workspace = tmp_path / "workspace"
|
||||
middleware = ShellToolMiddleware(workspace_root=workspace)
|
||||
|
||||
state: AgentState = _empty_state()
|
||||
|
||||
# Create resources first time
|
||||
resources1 = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
# Call again - should return the same resources
|
||||
resources2 = middleware._get_or_create_resources(state) # type: ignore[attr-defined]
|
||||
|
||||
assert resources1 is resources2
|
||||
assert resources1.session is resources2.session
|
||||
|
||||
# Clean up
|
||||
resources1._finalizer()
|
||||
|
||||
2
libs/langchain_v1/uv.lock
generated
2
libs/langchain_v1/uv.lock
generated
@@ -2174,7 +2174,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "1.0.4"
|
||||
version = "1.0.5"
|
||||
source = { editable = "../core" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
|
||||
Reference in New Issue
Block a user