Going back to sequence-based approach

Author: Sydney Runkle
Date: 2025-09-16 15:16:20 -04:00
Parent: 4f33433ce5
Commit: 6745b79755
2 changed files with 38 additions and 50 deletions
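In short: human responses are now matched to pending tool calls by their position in the list, rather than by an explicit tool_call_id key, so that field disappears from the interrupt request and from every response payload. A minimal sketch of the resulting resume contract, assuming an already-interrupted agent run (the agent object and thread id are illustrative, not part of this diff):

    from langgraph.types import Command

    # One response per interrupted tool call, in the same order in which
    # the interrupt requests were raised; no tool_call_id is needed.
    agent.invoke(
        Command(resume=[{"type": "approve"}, {"type": "ignore"}]),
        config={"configurable": {"thread_id": "1"}},
    )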

langchain/agents/middleware/human_in_the_loop.py

@@ -1,6 +1,6 @@
 """Human in the loop middleware."""

-from typing import Any, Literal, cast
+from typing import Any, Literal

 from langchain_core.messages import AIMessage, ToolCall, ToolMessage
 from langgraph.types import interrupt
@@ -40,7 +40,6 @@ class HumanInterrupt(TypedDict):
         args: Arguments for the action (tool kwargs)
         config: Configuration defining what actions are allowed
         description: Optional detailed description of what input is needed
-        tool_call_id: Identifier for the associated tool call

     Example:
         # Send the interrupt request and get the response
@@ -54,7 +53,6 @@ class HumanInterrupt(TypedDict):
                 allow_approve=True,  # Allow direct acceptance
             ),
             description="Please review the command before execution",
-            tool_call_id="call_123",
         )
         # Send the interrupt request and get the response
         response = interrupt([request])[0]
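The docstring example above ends with response = interrupt([request])[0]; a hypothetical sketch of consuming such a response under the payload types defined below (apply_response and its dispatch are illustrative, not code from this commit):

    def apply_response(tool_call: ToolCall, response: HumanResponse) -> ToolCall:
        # Illustrative only: approve keeps the call unchanged, edit swaps in
        # the human-provided action name and args.
        if response["type"] == "approve":
            return tool_call
        if response["type"] == "edit":
            return {**tool_call, "name": response["action"], "args": response["args"]}
        msg = f"Unhandled response type: {response['type']}"
        raise ValueError(msg)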
@@ -65,21 +63,18 @@ class HumanInterrupt(TypedDict):
     args: dict
     config: HumanInterruptConfig
     description: str | None
-    tool_call_id: str


 class ApprovePayload(TypedDict):
     """Human chose to approve the current state without changes."""

     type: Literal["approve"]
-    tool_call_id: str


 class IgnorePayload(TypedDict):
     """Human chose to ignore/skip the current step with optional tool message customization."""

     type: Literal["ignore"]
-    tool_call_id: str
     tool_message: NotRequired[str | ToolMessage]
@@ -87,7 +82,6 @@ class ResponsePayload(TypedDict):
     """Human provided text feedback or instructions."""

     type: Literal["response"]
-    tool_call_id: str
     tool_message: str | ToolMessage
@@ -95,7 +89,6 @@ class EditPayload(TypedDict):
     """Human chose to edit/modify the current state/content."""

     type: Literal["edit"]
-    tool_call_id: str
     action: str
     args: dict
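With tool_call_id gone, each response payload reduces to its type plus type-specific fields. A minimal sketch of the four shapes after this change (values are illustrative):

    approve: ApprovePayload = {"type": "approve"}
    ignore: IgnorePayload = {"type": "ignore"}  # tool_message is NotRequired
    respond: ResponsePayload = {"type": "response", "tool_message": "use zip code 10001 instead"}
    edit: EditPayload = {"type": "edit", "action": "get_forecast", "args": {"location": "NYC"}}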
@@ -173,7 +166,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         # Create interrupt requests for all tools that need approval
         interrupt_requests: list[HumanInterrupt] = []
-        for tool_call in interrupt_tool_calls.values():
+        interrupt_tool_calls_list = list(interrupt_tool_calls.values())
+        for tool_call in interrupt_tool_calls_list:
             tool_name = tool_call["name"]
             tool_args = tool_call["args"]
             description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
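Materializing the values view is what makes positional matching possible: a dict preserves insertion order, but its dict_values view cannot be indexed. A two-line illustration with toy values:

    calls = {"call_1": "get_forecast", "call_2": "get_temperature"}
    list(calls.values())[0]  # "get_forecast"; calls.values()[0] would raise TypeError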
@@ -184,22 +178,21 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
                 "args": tool_args,
                 "config": tool_config,
                 "description": description,
-                # ids should always be present on tool calls
-                "tool_call_id": cast("str", tool_call["id"]),
             }
             interrupt_requests.append(request)

         responses: list[HumanResponse] = interrupt(interrupt_requests)

-        for response in responses:
-            try:
-                tool_call = interrupt_tool_calls[response["tool_call_id"]]
-            except KeyError:
-                msg = (
-                    f"Unexpected human response: {response}. "
-                    f"Expected one with `'tool_call_id'` in {list(interrupt_tool_calls.keys())}."
-                )
-                raise ValueError(msg)
+        # Validate that the number of responses matches the number of interrupt tool calls
+        if len(responses) != len(interrupt_tool_calls_list):
+            msg = (
+                f"Number of human responses ({len(responses)}) does not match "
+                f"number of hanging tool calls ({len(interrupt_tool_calls_list)})."
+            )
+            raise ValueError(msg)
+
+        for i, response in enumerate(responses):
+            tool_call = interrupt_tool_calls_list[i]
             tool_config = self.tool_configs[tool_call["name"]]
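The explicit length check above behaves much like iterating with zip(..., strict=True); keeping the manual check presumably buys the friendlier error message. The equivalent pairing, sketched with the local names from this hunk:

    # Pair each pending tool call with the response at the same index;
    # strict=True raises ValueError on a length mismatch.
    for tool_call, response in zip(interrupt_tool_calls_list, responses, strict=True):
        ...  # handle the response for this tool call, as in the loop above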

Unit tests for HumanInTheLoopMiddleware

@@ -416,7 +416,7 @@ def test_human_in_the_loop_middleware_single_tool_accept() -> None:
     state = {"messages": [HumanMessage(content="Hello"), ai_message]}

     def mock_accept(requests):
-        return [{"type": "approve", "args": None, "tool_call_id": "1"}]
+        return [{"type": "approve"}]

     with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
         result = middleware.after_model(state)
@@ -449,7 +449,6 @@ def test_human_in_the_loop_middleware_single_tool_edit() -> None:
                 "type": "edit",
                 "action": "test_tool",
                 "args": {"input": "edited"},
-                "tool_call_id": "1",
             }
         ]
@@ -479,7 +478,7 @@ def test_human_in_the_loop_middleware_single_tool_ignore() -> None:
     state = {"messages": [HumanMessage(content="Hello"), ai_message]}

     def mock_ignore(requests):
-        return [{"type": "ignore", "args": None, "tool_call_id": "1"}]
+        return [{"type": "ignore"}]

     with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_ignore):
         result = middleware.after_model(state)
@@ -511,7 +510,7 @@ def test_human_in_the_loop_middleware_single_tool_response() -> None:
     state = {"messages": [HumanMessage(content="Hello"), ai_message]}

     def mock_response(requests):
-        return [{"type": "response", "tool_message": "Custom response", "tool_call_id": "1"}]
+        return [{"type": "response", "tool_message": "Custom response"}]

     with patch(
         "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
@@ -552,8 +551,8 @@ def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
     def mock_mixed_responses(requests):
         return [
-            {"type": "approve", "args": None, "tool_call_id": "1"},
-            {"type": "ignore", "args": None, "tool_call_id": "2"},
+            {"type": "approve"},
+            {"type": "ignore"},
         ]

     with patch(
@@ -605,13 +604,11 @@ def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None:
                 "type": "edit",
                 "action": "get_forecast",
                 "args": {"location": "New York"},
-                "tool_call_id": "1",
             },
             {
                 "type": "edit",
                 "action": "get_temperature",
                 "args": {"location": "New York"},
-                "tool_call_id": "2",
             },
         ]
@@ -657,12 +654,10 @@ def test_human_in_the_loop_middleware_multiple_tools_response_types() -> None:
             {
                 "type": "response",
                 "tool_message": "actually, please get the conditions in NYC",
-                "tool_call_id": "1",
             },
             {
                 "type": "response",
                 "tool_message": "actually, please get the temperature in NYC",
-                "tool_call_id": "2",
             },
         ]
@@ -704,12 +699,12 @@ def test_human_in_the_loop_middleware_unknown_response_type() -> None:
     state = {"messages": [HumanMessage(content="Hello"), ai_message]}

     def mock_unknown(requests):
-        return [{"type": "unknown", "args": None, "tool_call_id": "1"}]
+        return [{"type": "unknown"}]

     with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown):
         with pytest.raises(
             ValueError,
-            match="Unexpected human response: {'type': 'unknown', 'args': None, 'tool_call_id': '1'}. Response type 'unknown' is not allowed for tool 'test_tool'. Expected one with `'type'` in \\['accept', 'edit', 'response', 'ignore'\\] based on the tool's interrupt configuration.",
+            match="Unexpected human response: {'type': 'unknown'}. Response type 'unknown' is not allowed for tool 'test_tool'. Expected one with `'type'` in \\['accept', 'edit', 'response', 'ignore'\\] based on the tool's interrupt configuration.",
         ):
             middleware.after_model(state)
@@ -731,7 +726,7 @@ def test_human_in_the_loop_middleware_disallowed_response_type() -> None:
     state = {"messages": [HumanMessage(content="Hello"), ai_message]}

     def mock_disallowed_response(requests):
-        return [{"type": "response", "args": "Custom response", "tool_call_id": "1"}]
+        return [{"type": "response", "tool_message": "Custom response"}]

     with patch(
         "langchain.agents.middleware.human_in_the_loop.interrupt",
@@ -739,7 +734,7 @@ def test_human_in_the_loop_middleware_disallowed_response_type() -> None:
     ):
         with pytest.raises(
             ValueError,
-            match="Unexpected human response: {'type': 'response', 'args': 'Custom response', 'tool_call_id': '1'}. Response type 'response' is not allowed for tool 'test_tool'. Expected one with `'type'` in \\['accept', 'ignore'\\] based on the tool's interrupt configuration.",
+            match="Unexpected human response: {'type': 'response', 'tool_message': 'Custom response'}. Response type 'response' is not allowed for tool 'test_tool'. Expected one with `'type'` in \\['accept', 'ignore'\\] based on the tool's interrupt configuration.",
         ):
             middleware.after_model(state)
@@ -766,7 +761,6 @@ def test_human_in_the_loop_middleware_disallowed_edit_type() -> None:
                 "type": "edit",
                 "action": "test_tool",
                 "args": {"input": "edited"},
-                "tool_call_id": "1",
             }
         ]
@@ -775,7 +769,7 @@ def test_human_in_the_loop_middleware_disallowed_edit_type() -> None:
     ):
         with pytest.raises(
             ValueError,
-            match="Unexpected human response: {'type': 'edit', 'action': 'test_tool', 'args': {'input': 'edited'}, 'tool_call_id': '1'}. Response type 'edit' is not allowed for tool 'test_tool'. Expected one with `'type'` in \\['accept', 'response', 'ignore'\\] based on the tool's interrupt configuration.",
+            match="Unexpected human response: {'type': 'edit', 'action': 'test_tool', 'args': {'input': 'edited'}}. Response type 'edit' is not allowed for tool 'test_tool'. Expected one with `'type'` in \\['accept', 'response', 'ignore'\\] based on the tool's interrupt configuration.",
         ):
             middleware.after_model(state)
@@ -800,7 +794,7 @@ def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> None:
     state = {"messages": [HumanMessage(content="Hello"), ai_message]}

     def mock_accept(requests):
-        return [{"type": "approve", "args": None, "tool_call_id": "2"}]
+        return [{"type": "approve"}]

     with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
         result = middleware.after_model(state)
@@ -839,8 +833,8 @@ def test_human_in_the_loop_middleware_all_ignored() -> None:
     def mock_all_ignore(requests):
         return [
-            {"type": "ignore", "args": None, "tool_call_id": "1"},
-            {"type": "ignore", "args": None, "tool_call_id": "2"},
+            {"type": "ignore"},
+            {"type": "ignore"},
         ]

     with patch(
@@ -883,7 +877,7 @@ def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
     def mock_capture_requests(requests):
         captured_requests.extend(requests)
-        return [{"type": "approve", "args": None, "tool_call_id": "1"}]
+        return [{"type": "approve"}]

     with patch(
         "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
@@ -919,7 +913,7 @@ def test_human_in_the_loop_middleware_boolean_configs() -> None:
     # Test approve
     with patch(
         "langchain.agents.middleware.human_in_the_loop.interrupt",
-        return_value=[{"type": "approve", "tool_call_id": "1"}],
+        return_value=[{"type": "approve"}],
     ):
         result = middleware.after_model(state)
         assert result is not None
@@ -933,7 +927,6 @@ def test_human_in_the_loop_middleware_boolean_configs() -> None:
         return_value=[
             {
                 "type": "edit",
-                "tool_call_id": "1",
                 "action": "test_tool",
                 "args": {"input": "edited"},
             }
@@ -948,7 +941,7 @@ def test_human_in_the_loop_middleware_boolean_configs() -> None:
     # Test ignore
     with patch(
         "langchain.agents.middleware.human_in_the_loop.interrupt",
-        return_value=[{"type": "ignore", "tool_call_id": "1"}],
+        return_value=[{"type": "ignore"}],
     ):
         result = middleware.after_model(state)
         assert result is not None
@@ -963,7 +956,7 @@ def test_human_in_the_loop_middleware_boolean_configs() -> None:
     # Test response
     with patch(
         "langchain.agents.middleware.human_in_the_loop.interrupt",
-        return_value=[{"type": "response", "tool_call_id": "1", "tool_message": "Custom response"}],
+        return_value=[{"type": "response", "tool_message": "Custom response"}],
     ):
         result = middleware.after_model(state)
         assert result is not None
@@ -981,8 +974,8 @@ def test_human_in_the_loop_middleware_boolean_configs() -> None:
     assert result is None


-def test_human_in_the_loop_middleware_missing_tool_call_id() -> None:
-    """Test that missing tool call ID in resume raises an error."""
+def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
+    """Test that sequence mismatch in resume raises an error."""
     middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True})

     ai_message = AIMessage(
@@ -991,23 +984,25 @@ def test_human_in_the_loop_middleware_missing_tool_call_id() -> None:
     )
     state = {"messages": [HumanMessage(content="Hello"), ai_message]}

+    # Test with too few responses
     with patch(
         "langchain.agents.middleware.human_in_the_loop.interrupt",
-        return_value=[{"type": "approve"}],  # Missing tool_call_id
+        return_value=[],  # No responses for 1 tool call
     ):
         with pytest.raises(
             ValueError,
-            match=r"Unexpected human response: \{'type': 'approve'\}\. Expected one with `'tool_call_id'` in \['1'\]\.",
+            match=r"Number of human responses \(0\) does not match number of hanging tool calls \(1\)\.",
         ):
             middleware.after_model(state)

+    # Test with too many responses
     with patch(
         "langchain.agents.middleware.human_in_the_loop.interrupt",
-        return_value=[{"type": "approve", "tool_call_id": "nonexistent"}],
+        return_value=[{"type": "approve"}, {"type": "approve"}],  # 2 responses for 1 tool call
     ):
         with pytest.raises(
             ValueError,
-            match=r"Unexpected human response: \{'type': 'approve', 'tool_call_id': 'nonexistent'\}\. Expected one with `'tool_call_id'` in \['1'\]\.",
+            match=r"Number of human responses \(2\) does not match number of hanging tool calls \(1\)\.",
         ):
             middleware.after_model(state)
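A closing note on semantics: with sequence-based matching, response order is load-bearing, since nothing else ties a response to a particular tool call. A sketch in the style of the tests above, assuming a state with two interrupted tool calls as in the mixed-responses test:

    def mock_order_matters(requests):
        # Index 0 answers the first interrupted tool call, index 1 the second;
        # reversing this list would approve and ignore the opposite calls.
        return [{"type": "approve"}, {"type": "ignore"}]

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        side_effect=mock_order_matters,
    ):
        result = middleware.after_model(state)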