diff --git a/libs/langchain_v1/langchain/agents/interrupt.py b/libs/langchain_v1/langchain/agents/interrupt.py deleted file mode 100644 index cbba63b51f5..00000000000 --- a/libs/langchain_v1/langchain/agents/interrupt.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Interrupt types to use with agent inbox like setups.""" - -from typing import Literal - -from typing_extensions import TypedDict - - -class HumanInterruptConfig(TypedDict): - """Configuration that defines what actions are allowed for a human interrupt. - - This controls the available interaction options when the graph is paused for human input. - - Attributes: - allow_ignore: Whether the human can choose to ignore/skip the current step - allow_respond: Whether the human can provide a text response/feedback - allow_edit: Whether the human can edit the provided content/state - allow_accept: Whether the human can accept/approve the current state - """ - - allow_ignore: bool - allow_respond: bool - allow_edit: bool - allow_accept: bool - - -class ActionRequest(TypedDict): - """Represents a request for human action within the graph execution. - - Contains the action type and any associated arguments needed for the action. - - Attributes: - action: The type or name of action being requested (e.g., "Approve XYZ action") - args: Key-value pairs of arguments needed for the action - """ - - action: str - args: dict - - -class HumanInterrupt(TypedDict): - """Represents an interrupt triggered by the graph that requires human intervention. - - This is passed to the `interrupt` function when execution is paused for human input. - - Attributes: - action_request: The specific action being requested from the human - config: Configuration defining what actions are allowed - description: Optional detailed description of what input is needed - - Example: - ```python - # Extract a tool call from the state and create an interrupt request - request = HumanInterrupt( - action_request=ActionRequest( - action="run_command", # The action being requested - args={"command": "ls", "args": ["-l"]}, # Arguments for the action - ), - config=HumanInterruptConfig( - allow_ignore=True, # Allow skipping this step - allow_respond=True, # Allow text feedback - allow_edit=False, # Don't allow editing - allow_accept=True, # Allow direct acceptance - ), - description="Please review the command before execution", - ) - # Send the interrupt request and get the response - response = interrupt([request])[0] - ``` - """ - - action_request: ActionRequest - config: HumanInterruptConfig - description: str | None - - -class HumanResponse(TypedDict): - """Human response. - - The response provided by a human to an interrupt, - which is returned when graph execution resumes. 
- - Attributes: - type: The type of response: - - - "accept": Approves the current state without changes - - "ignore": Skips/ignores the current step - - "response": Provides text feedback or instructions - - "edit": Modifies the current state/content - args: The response payload: - - - None: For ignore/accept actions - - str: For text responses - - ActionRequest: For edit actions with updated content - """ - - type: Literal["accept", "ignore", "response", "edit"] - args: None | str | ActionRequest diff --git a/libs/langchain_v1/langchain/agents/middleware/_utils.py b/libs/langchain_v1/langchain/agents/middleware/_utils.py deleted file mode 100644 index 2debfd4c884..00000000000 --- a/libs/langchain_v1/langchain/agents/middleware/_utils.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Utility functions for middleware.""" - -from typing import Any - - -def _generate_correction_tool_messages(content: str, tool_calls: list) -> list[dict[str, Any]]: - """Generate tool messages for model behavior correction.""" - return [ - {"role": "tool", "content": content, "tool_call_id": tool_call["id"]} - for tool_call in tool_calls - ] diff --git a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py index 4718e704685..46ac1a75990 100644 --- a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py +++ b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py @@ -1,19 +1,110 @@ """Human in the loop middleware.""" -from typing import Any +from typing import Any, Literal -from langgraph.prebuilt.interrupt import ( - ActionRequest, - HumanInterrupt, - HumanInterruptConfig, - HumanResponse, -) +from langchain_core.messages import AIMessage, ToolCall, ToolMessage from langgraph.types import interrupt +from typing_extensions import NotRequired, TypedDict -from langchain.agents.middleware._utils import _generate_correction_tool_messages from langchain.agents.middleware.types import AgentMiddleware, AgentState -ToolInterruptConfig = dict[str, HumanInterruptConfig] + +class HumanInTheLoopConfig(TypedDict): + """Configuration that defines what actions are allowed for a human interrupt. + + This controls the available interaction options when the graph is paused for human input. + """ + + allow_accept: NotRequired[bool] + """Whether the human can approve the current action without changes.""" + allow_edit: NotRequired[bool] + """Whether the human can approve the current action with edited content.""" + allow_respond: NotRequired[bool] + """Whether the human can reject the current action with feedback.""" + + +class ActionRequest(TypedDict): + """Represents a request with a name and arguments.""" + + action: str + """The type or name of action being requested (e.g., "add_numbers").""" + args: dict + """Key-value pairs of arguments needed for the action (e.g., {"a": 1, "b": 2}).""" + + +class HumanInTheLoopRequest(TypedDict): + """Represents an interrupt triggered by the graph that requires human intervention. 
+ + Example: + ```python + # Extract a tool call from the state and create an interrupt request + request = HumanInterrupt( + action_request=ActionRequest( + action="run_command", # The action being requested + args={"command": "ls", "args": ["-l"]}, # Arguments for the action + ), + config=HumanInTheLoopConfig( + allow_accept=True, # Allow approval + allow_respond=True, # Allow rejection with feedback + allow_edit=False, # Don't allow approval with edits + ), + description="Please review the command before execution", + ) + # Send the interrupt request and get the response + response = interrupt([request])[0] + ``` + """ + + action_request: ActionRequest + """The specific action being requested from the human.""" + config: HumanInTheLoopConfig + """Configuration defining what response types are allowed.""" + description: str | None + """Optional detailed description of what input is needed.""" + + +class AcceptPayload(TypedDict): + """Response when a human approves the action.""" + + type: Literal["accept"] + """The type of response when a human approves the action.""" + + +class ResponsePayload(TypedDict): + """Response when a human rejects the action.""" + + type: Literal["response"] + """The type of response when a human rejects the action.""" + + args: NotRequired[str] + """The message to be sent to the model explaining why the action was rejected.""" + + +class EditPayload(TypedDict): + """Response when a human edits the action.""" + + type: Literal["edit"] + """The type of response when a human edits the action.""" + + args: ActionRequest + """The action request with the edited content.""" + + +HumanInTheLoopResponse = AcceptPayload | ResponsePayload | EditPayload +"""Aggregated response type for all possible human in the loop responses.""" + + +class ToolConfig(TypedDict): + """Configuration for a tool requiring human in the loop.""" + + allow_accept: NotRequired[bool] + """Whether the human can approve the current action without changes.""" + allow_edit: NotRequired[bool] + """Whether the human can approve the current action with edited content.""" + allow_respond: NotRequired[bool] + """Whether the human can reject the current action with feedback.""" + description: NotRequired[str] + """The description attached to the request for human input.""" class HumanInTheLoopMiddleware(AgentMiddleware): @@ -21,18 +112,36 @@ class HumanInTheLoopMiddleware(AgentMiddleware): def __init__( self, - tool_configs: ToolInterruptConfig, - message_prefix: str = "Tool execution requires approval", + tool_configs: dict[str, bool | ToolConfig], + *, + description_prefix: str = "Tool execution requires approval", ) -> None: """Initialize the human in the loop middleware. Args: - tool_configs: The tool interrupt configs to use for the middleware. - message_prefix: The message prefix to use when constructing interrupt content. + tool_configs: Mapping of tool name to allowed actions. + If a tool doesn't have an entry, it's auto-approved by default. + * `True` indicates all actions are allowed: accept, edit, and respond. + * `False` indicates that the tool is auto-approved. + * ToolConfig indicates the specific actions allowed for this tool. + description_prefix: The prefix to use when constructing action requests. + This is used to provide context about the tool call and the action being requested. + Not used if a tool has a description in its ToolConfig. 
""" super().__init__() - self.tool_configs = tool_configs - self.message_prefix = message_prefix + resolved_tool_configs: dict[str, ToolConfig] = {} + for tool_name, tool_config in tool_configs.items(): + if isinstance(tool_config, bool): + if tool_config is True: + resolved_tool_configs[tool_name] = ToolConfig( + allow_accept=True, + allow_edit=True, + allow_respond=True, + ) + else: + resolved_tool_configs[tool_name] = tool_config + self.tool_configs = resolved_tool_configs + self.description_prefix = description_prefix def after_model(self, state: AgentState) -> dict[str, Any] | None: """Trigger HITL flows for relevant tool calls after an AIMessage.""" @@ -40,96 +149,105 @@ class HumanInTheLoopMiddleware(AgentMiddleware): if not messages: return None - last_message = messages[-1] - - if not hasattr(last_message, "tool_calls") or not last_message.tool_calls: + last_ai_msg = next((msg for msg in messages if isinstance(msg, AIMessage)), None) + if not last_ai_msg or not last_ai_msg.tool_calls: return None # Separate tool calls that need interrupts from those that don't - interrupt_tool_calls = [] + hitl_tool_calls: list[ToolCall] = [] auto_approved_tool_calls = [] - for tool_call in last_message.tool_calls: - tool_name = tool_call["name"] - if tool_name in self.tool_configs: - interrupt_tool_calls.append(tool_call) - else: - auto_approved_tool_calls.append(tool_call) + for tool_call in last_ai_msg.tool_calls: + hitl_tool_calls.append(tool_call) if tool_call[ + "name" + ] in self.tool_configs else auto_approved_tool_calls.append(tool_call) # If no interrupts needed, return early - if not interrupt_tool_calls: + if not hitl_tool_calls: return None - approved_tool_calls = auto_approved_tool_calls.copy() + # Process all tool calls that require interrupts + approved_tool_calls: list[ToolCall] = auto_approved_tool_calls.copy() + artificial_tool_messages: list[ToolMessage] = [] - # Right now, we do not support multiple tool calls with interrupts - if len(interrupt_tool_calls) > 1: - tool_names = [t["name"] for t in interrupt_tool_calls] - msg = ( - f"Called the following tools which require interrupts: {tool_names}\n\n" - "You may only call ONE tool that requires an interrupt at a time" + # Create interrupt requests for all tools that need approval + hitl_requests: list[HumanInTheLoopRequest] = [] + for tool_call in hitl_tool_calls: + tool_name = tool_call["name"] + tool_args = tool_call["args"] + config = self.tool_configs[tool_name] + description = ( + config.get("description") + or f"{self.description_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}" ) - return { - "messages": _generate_correction_tool_messages(msg, last_message.tool_calls), - "jump_to": "model", - } - # Right now, we do not support interrupting a tool call if other tool calls exist - if auto_approved_tool_calls: - tool_names = [t["name"] for t in interrupt_tool_calls] + request: HumanInTheLoopRequest = { + "action_request": ActionRequest( + action=tool_name, + args=tool_args, + ), + "config": config, + "description": description, + } + hitl_requests.append(request) + + responses: list[HumanInTheLoopResponse] = interrupt(hitl_requests) + + # Validate that the number of responses matches the number of interrupt tool calls + if (responses_len := len(responses)) != (hitl_tool_calls_len := len(hitl_tool_calls)): msg = ( - f"Called the following tools which require interrupts: {tool_names}. " - "You also called other tools that do not require interrupts. 
" - "If you call a tool that requires and interrupt, you may ONLY call that tool." + f"Number of human responses ({responses_len}) does not match " + f"number of hanging tool calls ({hitl_tool_calls_len})." ) - return { - "messages": _generate_correction_tool_messages(msg, last_message.tool_calls), - "jump_to": "model", - } - - # Only one tool call will need interrupts - tool_call = interrupt_tool_calls[0] - tool_name = tool_call["name"] - tool_args = tool_call["args"] - description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}" - tool_config = self.tool_configs[tool_name] - - request: HumanInterrupt = { - "action_request": ActionRequest( - action=tool_name, - args=tool_args, - ), - "config": tool_config, - "description": description, - } - - responses: list[HumanResponse] = interrupt([request]) - response = responses[0] - - if response["type"] == "accept": - approved_tool_calls.append(tool_call) - elif response["type"] == "edit": - edited: ActionRequest = response["args"] # type: ignore[assignment] - new_tool_call = { - "type": "tool_call", - "name": tool_call["name"], - "args": edited["args"], - "id": tool_call["id"], - } - approved_tool_calls.append(new_tool_call) - elif response["type"] == "ignore": - return {"jump_to": "__end__"} - elif response["type"] == "response": - tool_message = { - "role": "tool", - "tool_call_id": tool_call["id"], - "content": response["args"], - } - return {"messages": [tool_message], "jump_to": "model"} - else: - msg = f"Unknown response type: {response['type']}" raise ValueError(msg) - last_message.tool_calls = approved_tool_calls + for i, response in enumerate(responses): + tool_call = hitl_tool_calls[i] + config = self.tool_configs[tool_call["name"]] - return {"messages": [last_message]} + if response["type"] == "accept" and config.get("allow_accept"): + approved_tool_calls.append(tool_call) + elif response["type"] == "edit" and config.get("allow_edit"): + edited_action = response["args"] + approved_tool_calls.append( + ToolCall( + type="tool_call", + name=edited_action["action"], + args=edited_action["args"], + id=tool_call["id"], + ) + ) + elif response["type"] == "response" and config.get("allow_respond"): + # Create a tool message with the human's text response + content = response.get("args") or ( + f"User rejected the tool call for `{tool_call['name']}` " + f"with id {tool_call['id']}" + ) + tool_message = ToolMessage( + content=content, + name=tool_call["name"], + tool_call_id=tool_call["id"], + status="error", + ) + artificial_tool_messages.append(tool_message) + else: + allowed_actions = [ + action + for action in ["accept", "edit", "response"] + if config.get(f"allow_{'respond' if action == 'response' else action}") + ] + msg = ( + f"Unexpected human response: {response}. " + f"Response action '{response.get('type')}' " + f"is not allowed for tool '{tool_call['name']}'. " + f"Expected one of {allowed_actions} based on the tool's configuration." 
+ ) + raise ValueError(msg) + + # Update the AI message to only include approved tool calls + last_ai_msg.tool_calls = approved_tool_calls + + if len(approved_tool_calls) > 0: + return {"messages": [last_ai_msg, *artificial_tool_messages]} + + return {"jump_to": "model", "messages": artificial_tool_messages} diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index 53f94be3d60..4caa6c8af62 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -2,14 +2,15 @@ import itertools from collections.abc import Callable, Sequence -from typing import Any +from typing import Any, cast from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import AIMessage, SystemMessage, ToolMessage +from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool from langgraph.constants import END, START from langgraph.graph.state import StateGraph +from langgraph.types import Send from langgraph.typing import ContextT from typing_extensions import TypedDict, TypeVar @@ -211,15 +212,13 @@ def create_agent( # noqa: PLR0915 context_schema=context_schema, ) - def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]: + def _handle_model_output(output: AIMessage) -> dict[str, Any]: """Handle model output including structured responses.""" # Handle structured output with native strategy if isinstance(response_format, ProviderStrategy): if not output.tool_calls and native_output_binding: structured_response = native_output_binding.parse(output) return {"messages": [output], "response": structured_response} - if state.get("response") is not None: - return {"messages": [output], "response": None} return {"messages": [output]} # Handle structured output with tools strategy @@ -297,9 +296,6 @@ def create_agent( # noqa: PLR0915 ], } - # Standard response handling - if state.get("response") is not None: - return {"messages": [output], "response": None} return {"messages": [output]} def _get_bound_model(request: ModelRequest) -> Runnable: @@ -346,7 +342,7 @@ def create_agent( # noqa: PLR0915 messages = [SystemMessage(request.system_prompt), *messages] output = model_.invoke(messages) - return _handle_model_output(state, output) + return _handle_model_output(output) async def amodel_request(state: dict[str, Any]) -> dict[str, Any]: """Async model request handler with sequential middleware processing.""" @@ -373,7 +369,7 @@ def create_agent( # noqa: PLR0915 messages = [SystemMessage(request.system_prompt), *messages] output = await model_.ainvoke(messages) - return _handle_model_output(state, output) + return _handle_model_output(output) # Use sync or async based on model capabilities from langgraph._internal._runnable import RunnableCallable @@ -417,12 +413,12 @@ def create_agent( # noqa: PLR0915 if tool_node is not None: graph.add_conditional_edges( "tools", - _make_tools_to_model_edge(tool_node, first_node), + _make_tools_to_model_edge(tool_node, first_node, structured_output_tools), [first_node, END], ) graph.add_conditional_edges( last_node, - _make_model_to_tools_edge(first_node, structured_output_tools), + _make_model_to_tools_edge(first_node, structured_output_tools, tool_node), [first_node, "tools", END], ) elif last_node == "model_request": @@ -481,27 +477,48 @@ def _resolve_jump(jump_to: JumpTo 
| None, first_node: str) -> str | None: return None +def _fetch_last_ai_and_tool_messages( + messages: list[AnyMessage], +) -> tuple[AIMessage, list[ToolMessage]]: + last_ai_index: int + last_ai_message: AIMessage + + for i in range(len(messages) - 1, -1, -1): + if isinstance(messages[i], AIMessage): + last_ai_index = i + last_ai_message = cast("AIMessage", messages[i]) + break + + tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)] + return last_ai_message, tool_messages + + def _make_model_to_tools_edge( - first_node: str, structured_output_tools: dict[str, OutputToolBinding] -) -> Callable[[AgentState], str | None]: - def model_to_tools(state: AgentState) -> str | None: + first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode +) -> Callable[[AgentState], str | list[Send] | None]: + def model_to_tools(state: AgentState) -> str | list[Send] | None: if jump_to := state.get("jump_to"): return _resolve_jump(jump_to, first_node) - message = state["messages"][-1] + last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"]) + tool_message_ids = [m.tool_call_id for m in tool_messages] - # Check if this is a ToolMessage from structured output - if so, end - if isinstance(message, ToolMessage) and message.name in structured_output_tools: - return END + pending_tool_calls = [ + c + for c in last_ai_message.tool_calls + if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools + ] - # Check for tool calls - if isinstance(message, AIMessage) and message.tool_calls: - # If all tool calls are for structured output, don't go to tools - non_structured_calls = [ - tc for tc in message.tool_calls if tc["name"] not in structured_output_tools + if pending_tool_calls: + # imo we should not be injecting state, store here, + # this should be done by the tool node itself ideally but this is a consequence + # of using Send w/ tool calls directly which allows more intuitive interrupt behavior + # largely internal so can be fixed later + pending_tool_calls = [ + tool_node.inject_tool_args(call, state, None) # type: ignore[arg-type] + for call in pending_tool_calls ] - if non_structured_calls: - return "tools" + return [Send("tools", [tool_call]) for tool_call in pending_tool_calls] return END @@ -509,17 +526,21 @@ def _make_model_to_tools_edge( def _make_tools_to_model_edge( - tool_node: ToolNode, next_node: str + tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding] ) -> Callable[[AgentState], str | None]: def tools_to_model(state: AgentState) -> str | None: - ai_message = [m for m in state["messages"] if isinstance(m, AIMessage)][-1] + last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"]) + if all( tool_node.tools_by_name[c["name"]].return_direct - for c in ai_message.tool_calls + for c in last_ai_message.tool_calls if c["name"] in tool_node.tools_by_name ): return END + if any(t.name in structured_output_tools for t in tool_messages): + return END + return next_node return tools_to_model diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py index 830caa81a07..58138ac5227 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py @@ -4,6 +4,7 @@ from unittest.mock import patch from syrupy.assertion import SnapshotAssertion +from pydantic 
import BaseModel, Field from langchain_core.language_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import ( @@ -14,19 +15,23 @@ from langchain_core.messages import ( ToolMessage, ) from langchain_core.tools import tool +from langgraph.types import Command from langchain.agents.middleware_agent import create_agent -from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware +from langchain.agents.middleware.human_in_the_loop import ( + HumanInTheLoopMiddleware, + HumanInTheLoopConfig, + ActionRequest, +) from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware from langchain.agents.middleware.summarization import SummarizationMiddleware from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState -from langchain_core.tools import BaseTool from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.checkpoint.memory import InMemorySaver from langgraph.constants import END from langgraph.graph.message import REMOVE_ALL_MESSAGES -from langgraph.prebuilt.interrupt import ActionRequest, HumanInterruptConfig +from langchain.agents.structured_output import ToolStrategy from .messages import _AnyIdHumanMessage, _AnyIdToolMessage from .model import FakeToolCallingModel @@ -355,27 +360,28 @@ def test_create_agent_jump( # Tests for HumanInTheLoopMiddleware def test_human_in_the_loop_middleware_initialization() -> None: """Test HumanInTheLoopMiddleware initialization.""" - tool_configs = { - "test_tool": HumanInterruptConfig( - allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True - ) + + middleware = HumanInTheLoopMiddleware( + tool_configs={ + "test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True} + }, + description_prefix="Custom prefix", + ) + + assert middleware.tool_configs == { + "test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True} } - - middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs, message_prefix="Custom prefix") - - assert middleware.tool_configs == tool_configs - assert middleware.message_prefix == "Custom prefix" + assert middleware.description_prefix == "Custom prefix" def test_human_in_the_loop_middleware_no_interrupts_needed() -> None: """Test HumanInTheLoopMiddleware when no interrupts are needed.""" - tool_configs = { - "test_tool": HumanInterruptConfig( - allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True - ) - } - middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs) + middleware = HumanInTheLoopMiddleware( + tool_configs={ + "test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True} + } + ) # Test with no messages state: dict[str, Any] = {"messages": []} @@ -397,15 +403,14 @@ def test_human_in_the_loop_middleware_no_interrupts_needed() -> None: assert result is None -def test_human_in_the_loop_middleware_interrupt_responses() -> None: - """Test HumanInTheLoopMiddleware with different interrupt response types.""" - tool_configs = { - "test_tool": HumanInterruptConfig( - allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True - ) - } +def test_human_in_the_loop_middleware_single_tool_accept() -> None: + """Test HumanInTheLoopMiddleware with single tool accept response.""" - middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs) + middleware = HumanInTheLoopMiddleware( + tool_configs={ + "test_tool": {"allow_respond": True, "allow_edit": True, 
"allow_accept": True} + } + ) ai_message = AIMessage( content="I'll help you", @@ -413,55 +418,443 @@ def test_human_in_the_loop_middleware_interrupt_responses() -> None: ) state = {"messages": [HumanMessage(content="Hello"), ai_message]} - # Test accept response def mock_accept(requests): return [{"type": "accept", "args": None}] with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept): result = middleware.after_model(state) assert result is not None + assert "messages" in result + assert len(result["messages"]) == 1 assert result["messages"][0] == ai_message assert result["messages"][0].tool_calls == ai_message.tool_calls - # Test edit response + +def test_human_in_the_loop_middleware_single_tool_edit() -> None: + """Test HumanInTheLoopMiddleware with single tool edit response.""" + middleware = HumanInTheLoopMiddleware( + tool_configs={ + "test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True} + } + ) + + ai_message = AIMessage( + content="I'll help you", + tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}], + ) + state = {"messages": [HumanMessage(content="Hello"), ai_message]} + def mock_edit(requests): return [ - {"type": "edit", "args": ActionRequest(action="test_tool", args={"input": "edited"})} + { + "type": "edit", + "args": ActionRequest( + action="test_tool", + args={"input": "edited"}, + ), + } ] with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit): result = middleware.after_model(state) assert result is not None + assert "messages" in result + assert len(result["messages"]) == 1 assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"} + assert result["messages"][0].tool_calls[0]["id"] == "1" # ID should be preserved - # Test ignore response - def mock_ignore(requests): - return [{"type": "ignore", "args": None}] - with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_ignore): - result = middleware.after_model(state) - assert result is not None - assert result["jump_to"] == "__end__" +def test_human_in_the_loop_middleware_single_tool_response() -> None: + """Test HumanInTheLoopMiddleware with single tool response with custom message.""" + + middleware = HumanInTheLoopMiddleware( + tool_configs={ + "test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True} + } + ) + + ai_message = AIMessage( + content="I'll help you", + tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}], + ) + state = {"messages": [HumanMessage(content="Hello"), ai_message]} - # Test response type def mock_response(requests): - return [{"type": "response", "args": "Custom response"}] + return [{"type": "response", "args": "Custom response message"}] with patch( "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response ): result = middleware.after_model(state) assert result is not None - assert result["jump_to"] == "model" - assert result["messages"][0]["role"] == "tool" - assert result["messages"][0]["content"] == "Custom response" + assert "messages" in result + assert len(result["messages"]) == 1 # Only tool message when no approved tool calls + assert isinstance(result["messages"][0], ToolMessage) + assert result["messages"][0].content == "Custom response message" + assert result["messages"][0].name == "test_tool" + assert result["messages"][0].tool_call_id == "1" + + +def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None: + """Test 
HumanInTheLoopMiddleware with multiple tools and mixed response types.""" + + middleware = HumanInTheLoopMiddleware( + tool_configs={ + "get_forecast": {"allow_accept": True, "allow_edit": True, "allow_respond": True}, + "get_temperature": {"allow_accept": True, "allow_edit": True, "allow_respond": True}, + } + ) + + ai_message = AIMessage( + content="I'll help you with weather", + tool_calls=[ + {"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"}, + {"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"}, + ], + ) + state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]} + + def mock_mixed_responses(requests): + return [ + {"type": "accept", "args": None}, + {"type": "response", "args": "User rejected this tool call"}, + ] + + with patch( + "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses + ): + result = middleware.after_model(state) + assert result is not None + assert "messages" in result + assert ( + len(result["messages"]) == 2 + ) # AI message with accepted tool call + tool message for rejected + + # First message should be the AI message with updated tool calls + updated_ai_message = result["messages"][0] + assert len(updated_ai_message.tool_calls) == 1 # Only accepted tool call + assert updated_ai_message.tool_calls[0]["name"] == "get_forecast" # Accepted + + # Second message should be the tool message for the rejected tool call + tool_message = result["messages"][1] + assert isinstance(tool_message, ToolMessage) + assert tool_message.content == "User rejected this tool call" + assert tool_message.name == "get_temperature" + + +def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None: + """Test HumanInTheLoopMiddleware with multiple tools and edit responses.""" + + middleware = HumanInTheLoopMiddleware( + tool_configs={ + "get_forecast": {"allow_accept": True, "allow_edit": True, "allow_respond": True}, + "get_temperature": {"allow_accept": True, "allow_edit": True, "allow_respond": True}, + } + ) + + ai_message = AIMessage( + content="I'll help you with weather", + tool_calls=[ + {"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"}, + {"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"}, + ], + ) + state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]} + + def mock_edit_responses(requests): + return [ + { + "type": "edit", + "args": ActionRequest( + action="get_forecast", + args={"location": "New York"}, + ), + }, + { + "type": "edit", + "args": ActionRequest( + action="get_temperature", + args={"location": "New York"}, + ), + }, + ] + + with patch( + "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses + ): + result = middleware.after_model(state) + assert result is not None + assert "messages" in result + assert len(result["messages"]) == 1 + + updated_ai_message = result["messages"][0] + assert updated_ai_message.tool_calls[0]["args"] == {"location": "New York"} + assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved + assert updated_ai_message.tool_calls[1]["args"] == {"location": "New York"} + assert updated_ai_message.tool_calls[1]["id"] == "2" # ID preserved + + +def test_human_in_the_loop_middleware_edit_with_modified_args() -> None: + """Test HumanInTheLoopMiddleware with edit action that includes modified args.""" + + middleware = HumanInTheLoopMiddleware( + tool_configs={ + "test_tool": {"allow_accept": True, 
"allow_edit": True, "allow_respond": True} + } + ) + + ai_message = AIMessage( + content="I'll help you", + tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}], + ) + state = {"messages": [HumanMessage(content="Hello"), ai_message]} + + def mock_edit_with_args(requests): + return [ + { + "type": "edit", + "args": ActionRequest( + action="test_tool", + args={"input": "modified"}, + ), + } + ] + + with patch( + "langchain.agents.middleware.human_in_the_loop.interrupt", + side_effect=mock_edit_with_args, + ): + result = middleware.after_model(state) + assert result is not None + assert "messages" in result + assert len(result["messages"]) == 1 + + # Should have modified args + updated_ai_message = result["messages"][0] + assert updated_ai_message.tool_calls[0]["args"] == {"input": "modified"} + assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved + + +def test_human_in_the_loop_middleware_unknown_response_type() -> None: + """Test HumanInTheLoopMiddleware with unknown response type.""" + middleware = HumanInTheLoopMiddleware( + tool_configs={ + "test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True} + } + ) + + ai_message = AIMessage( + content="I'll help you", + tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}], + ) + state = {"messages": [HumanMessage(content="Hello"), ai_message]} - # Test unknown response type def mock_unknown(requests): return [{"type": "unknown", "args": None}] with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown): - with pytest.raises(ValueError, match="Unknown response type: unknown"): + with pytest.raises( + ValueError, + match=r"Unexpected human response: {'type': 'unknown', 'args': None}. Response action 'unknown' is not allowed for tool 'test_tool'. Expected one of \['accept', 'edit', 'response'\] based on the tool's configuration.", + ): + middleware.after_model(state) + + +def test_human_in_the_loop_middleware_disallowed_action() -> None: + """Test HumanInTheLoopMiddleware with action not allowed by tool config.""" + + # edit is not allowed by tool config + middleware = HumanInTheLoopMiddleware( + tool_configs={ + "test_tool": {"allow_respond": True, "allow_edit": False, "allow_accept": True} + } + ) + + ai_message = AIMessage( + content="I'll help you", + tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}], + ) + state = {"messages": [HumanMessage(content="Hello"), ai_message]} + + def mock_disallowed_action(requests): + return [ + { + "type": "edit", + "args": ActionRequest( + action="test_tool", + args={"input": "modified"}, + ), + } + ] + + with patch( + "langchain.agents.middleware.human_in_the_loop.interrupt", + side_effect=mock_disallowed_action, + ): + with pytest.raises( + ValueError, + match=r"Unexpected human response: {'type': 'edit', 'args': {'action': 'test_tool', 'args': {'input': 'modified'}}}. Response action 'edit' is not allowed for tool 'test_tool'. 
Expected one of \['accept', 'response'\] based on the tool's configuration.", + ): + middleware.after_model(state) + + +def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> None: + """Test HumanInTheLoopMiddleware with mix of auto-approved and interrupt tools.""" + + middleware = HumanInTheLoopMiddleware( + tool_configs={ + "interrupt_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True} + } + ) + + ai_message = AIMessage( + content="I'll help you", + tool_calls=[ + {"name": "auto_tool", "args": {"input": "auto"}, "id": "1"}, + {"name": "interrupt_tool", "args": {"input": "interrupt"}, "id": "2"}, + ], + ) + state = {"messages": [HumanMessage(content="Hello"), ai_message]} + + def mock_accept(requests): + return [{"type": "accept", "args": None}] + + with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept): + result = middleware.after_model(state) + assert result is not None + assert "messages" in result + assert len(result["messages"]) == 1 + + updated_ai_message = result["messages"][0] + # Should have both tools: auto-approved first, then interrupt tool + assert len(updated_ai_message.tool_calls) == 2 + assert updated_ai_message.tool_calls[0]["name"] == "auto_tool" + assert updated_ai_message.tool_calls[1]["name"] == "interrupt_tool" + + +def test_human_in_the_loop_middleware_interrupt_request_structure() -> None: + """Test that interrupt requests are structured correctly.""" + + middleware = HumanInTheLoopMiddleware( + tool_configs={ + "test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True} + }, + description_prefix="Custom prefix", + ) + + ai_message = AIMessage( + content="I'll help you", + tool_calls=[{"name": "test_tool", "args": {"input": "test", "location": "SF"}, "id": "1"}], + ) + state = {"messages": [HumanMessage(content="Hello"), ai_message]} + + captured_requests = [] + + def mock_capture_requests(requests): + captured_requests.extend(requests) + return [{"type": "accept", "args": None}] + + with patch( + "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests + ): + middleware.after_model(state) + + assert len(captured_requests) == 1 + request = captured_requests[0] + + assert "action_request" in request + assert "config" in request + assert "description" in request + + assert request["action_request"]["action"] == "test_tool" + assert request["action_request"]["args"] == {"input": "test", "location": "SF"} + expected_config = {"allow_accept": True, "allow_edit": True, "allow_respond": True} + assert request["config"] == expected_config + assert "Custom prefix" in request["description"] + assert "Tool: test_tool" in request["description"] + assert "Args: {'input': 'test', 'location': 'SF'}" in request["description"] + + +def test_human_in_the_loop_middleware_boolean_configs() -> None: + """Test HITL middleware with boolean tool configs.""" + middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True}) + + ai_message = AIMessage( + content="I'll help you", + tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}], + ) + state = {"messages": [HumanMessage(content="Hello"), ai_message]} + + # Test accept + with patch( + "langchain.agents.middleware.human_in_the_loop.interrupt", + return_value=[{"type": "accept", "args": None}], + ): + result = middleware.after_model(state) + assert result is not None + assert "messages" in result + assert len(result["messages"]) == 1 + assert result["messages"][0].tool_calls == 
ai_message.tool_calls + + # Test edit + with patch( + "langchain.agents.middleware.human_in_the_loop.interrupt", + return_value=[ + { + "type": "edit", + "args": ActionRequest( + action="test_tool", + args={"input": "edited"}, + ), + } + ], + ): + result = middleware.after_model(state) + assert result is not None + assert "messages" in result + assert len(result["messages"]) == 1 + assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"} + + middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": False}) + + result = middleware.after_model(state) + # No interruption should occur + assert result is None + + +def test_human_in_the_loop_middleware_sequence_mismatch() -> None: + """Test that sequence mismatch in resume raises an error.""" + middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True}) + + ai_message = AIMessage( + content="I'll help you", + tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}], + ) + state = {"messages": [HumanMessage(content="Hello"), ai_message]} + + # Test with too few responses + with patch( + "langchain.agents.middleware.human_in_the_loop.interrupt", + return_value=[], # No responses for 1 tool call + ): + with pytest.raises( + ValueError, + match=r"Number of human responses \(0\) does not match number of hanging tool calls \(1\)\.", + ): + middleware.after_model(state) + + # Test with too many responses + with patch( + "langchain.agents.middleware.human_in_the_loop.interrupt", + return_value=[ + {"type": "accept", "args": None}, + {"type": "accept", "args": None}, + ], # 2 responses for 1 tool call + ): + with pytest.raises( + ValueError, + match=r"Number of human responses \(2\) does not match number of hanging tool calls \(1\)\.", + ): middleware.after_model(state) @@ -733,3 +1126,78 @@ def test_modify_model_request() -> None: assert ( result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!" ) + + +def test_tools_to_model_edge_with_structured_and_regular_tool_calls(): + """Test that when there are both structured and regular tool calls, we execute regular and jump to END.""" + + class WeatherResponse(BaseModel): + """Weather response.""" + + temperature: float = Field(description="Temperature in fahrenheit") + condition: str = Field(description="Weather condition") + + @tool + def regular_tool(query: str) -> str: + """A regular tool that returns a string.""" + return f"Regular tool result for: {query}" + + # Create a fake model that returns both structured and regular tool calls + class FakeModelWithBothToolCalls(FakeToolCallingModel): + def __init__(self): + super().__init__() + self.tool_calls = [ + [ + ToolCall( + name="WeatherResponse", + args={"temperature": 72.0, "condition": "sunny"}, + id="structured_call_1", + ), + ToolCall( + name="regular_tool", args={"query": "test query"}, id="regular_call_1" + ), + ] + ] + + # Create agent with both structured output and regular tools + agent = create_agent( + model=FakeModelWithBothToolCalls(), + tools=[regular_tool], + response_format=ToolStrategy(schema=WeatherResponse), + ) + + # Compile and invoke the agent + compiled_agent = agent.compile() + result = compiled_agent.invoke( + {"messages": [HumanMessage("What's the weather and help me with a query?")]} + ) + + # Verify that we have the expected messages: + # 1. Human message + # 2. AI message with both tool calls + # 3. Tool message from structured tool call + # 4. 
Tool message from regular tool call + + messages = result["messages"] + assert len(messages) >= 4 + + # Check that we have the AI message with both tool calls + ai_message = messages[1] + assert isinstance(ai_message, AIMessage) + assert len(ai_message.tool_calls) == 2 + + # Check that we have a tool message from the regular tool + tool_messages = [m for m in messages if isinstance(m, ToolMessage)] + assert len(tool_messages) >= 1 + + # The regular tool should have been executed + regular_tool_message = next((m for m in tool_messages if m.name == "regular_tool"), None) + assert regular_tool_message is not None + assert "Regular tool result for: test query" in regular_tool_message.content + + # Verify that the structured response is available in the result + assert "response" in result + assert result["response"] is not None + assert hasattr(result["response"], "temperature") + assert result["response"].temperature == 72.0 + assert result["response"].condition == "sunny"
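
For reference, a minimal sketch of how the reworked `HumanInTheLoopMiddleware` is wired up and resumed. It assumes `create_agent` accepts a `middleware=[...]` argument and that the compiled graph runs with a checkpointer (required for `interrupt` to pause execution); the model identifier, tool bodies, and thread id are illustrative placeholders rather than part of the patch above.

```python
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
from langchain.agents.middleware_agent import create_agent
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import Command


@tool
def get_forecast(location: str) -> str:
    """Return a short weather forecast for a location."""
    return f"Sunny in {location}"


@tool
def send_email(to: str, body: str) -> str:
    """Send an email (stubbed out for this sketch)."""
    return "sent"


hitl = HumanInTheLoopMiddleware(
    tool_configs={
        # True: accept, edit, and respond are all allowed for this tool.
        "get_forecast": True,
        # A ToolConfig narrows the allowed responses and customizes the prompt;
        # tools with no entry at all are auto-approved.
        "send_email": {
            "allow_accept": True,
            "allow_respond": True,
            "description": "Review the outgoing email before it is sent.",
        },
    },
    description_prefix="Tool execution requires approval",
)

agent = create_agent(
    model=init_chat_model("openai:gpt-4o-mini"),  # placeholder model
    tools=[get_forecast, send_email],
    middleware=[hitl],  # assumed parameter name
)
graph = agent.compile(checkpointer=InMemorySaver())
config = {"configurable": {"thread_id": "thread-1"}}

# The first invocation pauses at the interrupt raised in `after_model`
# once the model emits calls to the configured tools.
graph.invoke({"messages": [HumanMessage("Email me tomorrow's forecast")]}, config)

# Resume with one response per interrupted tool call, in request order:
# approve the forecast call, reject the email with feedback.
graph.invoke(
    Command(
        resume=[
            {"type": "accept"},
            {"type": "response", "args": "Don't send emails without a subject line."},
        ]
    ),
    config,
)
```

Responses are matched positionally against the interrupted tool calls, which is why the resume list carries one entry per configured tool call the model issued; a length mismatch raises the `Number of human responses ... does not match number of hanging tool calls` error exercised in `test_human_in_the_loop_middleware_sequence_mismatch`.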