feat(langchain): improved HITL patterns (#32996)
# Main changes / new features

## Better support for parallel tool calls

1. Support for multiple tool calls requiring human input
2. Support for a combination of tool calls requiring human input and tool calls that are auto-approved
3. Support for structured output w/ tool calls requiring human input
4. Support for structured output w/ standard tool calls

## Shortcut for allowed actions

Adds a shortcut where a tool config can be specified as a `bool`, meaning "all actions allowed":

```py
HumanInTheLoopMiddleware(tool_configs={"expensive_tool": True})
```

## A few design decisions here

* We only raise one interrupt containing all `HumanInterrupt`s; currently we won't execute any of the tools until all of them are resolved. This isn't super blocking because we can't re-invoke the model until all tools have finished execution anyway. That said, if you have a long-running auto-approved tool, this could slow things down.

## TODOs

* Ideally, we would rename `accept` -> `approve`
* Ideally, we would rename `respond` -> `reject`
* Docs update (@sydney-runkle to own)
* In another PR I'd like to refactor testing to have one file for each prebuilt middleware :)

Fast follow to https://github.com/langchain-ai/langchain/pull/32962, which was deemed too breaking.
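To make the single-interrupt design above concrete: resuming a paused run supplies one response per interrupted tool call, in the same order the requests were raised. A minimal sketch, assuming `agent` is a compiled agent built with this middleware and a checkpointer; the thread ID and feedback text are illustrative:

```py
from langgraph.types import Command

config = {"configurable": {"thread_id": "example-thread"}}  # hypothetical thread ID

# One response per interrupted tool call, in the same order the requests were raised
agent.invoke(
    Command(
        resume=[
            {"type": "accept"},  # first flagged tool call: run with its original args
            {"type": "response", "args": "Too expensive, pick a cheaper option."},  # second: reject with feedback
        ]
    ),
    config=config,
)
```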
@@ -1,97 +0,0 @@
"""Interrupt types to use with agent inbox like setups."""

from typing import Literal

from typing_extensions import TypedDict


class HumanInterruptConfig(TypedDict):
    """Configuration that defines what actions are allowed for a human interrupt.

    This controls the available interaction options when the graph is paused for human input.

    Attributes:
        allow_ignore: Whether the human can choose to ignore/skip the current step
        allow_respond: Whether the human can provide a text response/feedback
        allow_edit: Whether the human can edit the provided content/state
        allow_accept: Whether the human can accept/approve the current state
    """

    allow_ignore: bool
    allow_respond: bool
    allow_edit: bool
    allow_accept: bool


class ActionRequest(TypedDict):
    """Represents a request for human action within the graph execution.

    Contains the action type and any associated arguments needed for the action.

    Attributes:
        action: The type or name of action being requested (e.g., "Approve XYZ action")
        args: Key-value pairs of arguments needed for the action
    """

    action: str
    args: dict


class HumanInterrupt(TypedDict):
    """Represents an interrupt triggered by the graph that requires human intervention.

    This is passed to the `interrupt` function when execution is paused for human input.

    Attributes:
        action_request: The specific action being requested from the human
        config: Configuration defining what actions are allowed
        description: Optional detailed description of what input is needed

    Example:
        ```python
        # Extract a tool call from the state and create an interrupt request
        request = HumanInterrupt(
            action_request=ActionRequest(
                action="run_command",  # The action being requested
                args={"command": "ls", "args": ["-l"]},  # Arguments for the action
            ),
            config=HumanInterruptConfig(
                allow_ignore=True,  # Allow skipping this step
                allow_respond=True,  # Allow text feedback
                allow_edit=False,  # Don't allow editing
                allow_accept=True,  # Allow direct acceptance
            ),
            description="Please review the command before execution",
        )
        # Send the interrupt request and get the response
        response = interrupt([request])[0]
        ```
    """

    action_request: ActionRequest
    config: HumanInterruptConfig
    description: str | None


class HumanResponse(TypedDict):
    """Human response.

    The response provided by a human to an interrupt,
    which is returned when graph execution resumes.

    Attributes:
        type: The type of response:

            - "accept": Approves the current state without changes
            - "ignore": Skips/ignores the current step
            - "response": Provides text feedback or instructions
            - "edit": Modifies the current state/content
        args: The response payload:

            - None: For ignore/accept actions
            - str: For text responses
            - ActionRequest: For edit actions with updated content
    """

    type: Literal["accept", "ignore", "response", "edit"]
    args: None | str | ActionRequest
@@ -1,11 +0,0 @@
"""Utility functions for middleware."""

from typing import Any


def _generate_correction_tool_messages(content: str, tool_calls: list) -> list[dict[str, Any]]:
    """Generate tool messages for model behavior correction."""
    return [
        {"role": "tool", "content": content, "tool_call_id": tool_call["id"]}
        for tool_call in tool_calls
    ]
@@ -1,19 +1,110 @@
"""Human in the loop middleware."""

from typing import Any
from typing import Any, Literal

from langgraph.prebuilt.interrupt import (
    ActionRequest,
    HumanInterrupt,
    HumanInterruptConfig,
    HumanResponse,
)
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langgraph.types import interrupt
from typing_extensions import NotRequired, TypedDict

from langchain.agents.middleware._utils import _generate_correction_tool_messages
from langchain.agents.middleware.types import AgentMiddleware, AgentState

ToolInterruptConfig = dict[str, HumanInterruptConfig]


class HumanInTheLoopConfig(TypedDict):
    """Configuration that defines what actions are allowed for a human interrupt.

    This controls the available interaction options when the graph is paused for human input.
    """

    allow_accept: NotRequired[bool]
    """Whether the human can approve the current action without changes."""
    allow_edit: NotRequired[bool]
    """Whether the human can approve the current action with edited content."""
    allow_respond: NotRequired[bool]
    """Whether the human can reject the current action with feedback."""


class ActionRequest(TypedDict):
    """Represents a request with a name and arguments."""

    action: str
    """The type or name of action being requested (e.g., "add_numbers")."""
    args: dict
    """Key-value pairs of arguments needed for the action (e.g., {"a": 1, "b": 2})."""


class HumanInTheLoopRequest(TypedDict):
    """Represents an interrupt triggered by the graph that requires human intervention.

    Example:
        ```python
        # Extract a tool call from the state and create an interrupt request
        request = HumanInterrupt(
            action_request=ActionRequest(
                action="run_command",  # The action being requested
                args={"command": "ls", "args": ["-l"]},  # Arguments for the action
            ),
            config=HumanInTheLoopConfig(
                allow_accept=True,  # Allow approval
                allow_respond=True,  # Allow rejection with feedback
                allow_edit=False,  # Don't allow approval with edits
            ),
            description="Please review the command before execution",
        )
        # Send the interrupt request and get the response
        response = interrupt([request])[0]
        ```
    """

    action_request: ActionRequest
    """The specific action being requested from the human."""
    config: HumanInTheLoopConfig
    """Configuration defining what response types are allowed."""
    description: str | None
    """Optional detailed description of what input is needed."""


class AcceptPayload(TypedDict):
    """Response when a human approves the action."""

    type: Literal["accept"]
    """The type of response when a human approves the action."""


class ResponsePayload(TypedDict):
    """Response when a human rejects the action."""

    type: Literal["response"]
    """The type of response when a human rejects the action."""

    args: NotRequired[str]
    """The message to be sent to the model explaining why the action was rejected."""


class EditPayload(TypedDict):
    """Response when a human edits the action."""

    type: Literal["edit"]
    """The type of response when a human edits the action."""

    args: ActionRequest
    """The action request with the edited content."""


HumanInTheLoopResponse = AcceptPayload | ResponsePayload | EditPayload
"""Aggregated response type for all possible human in the loop responses."""


class ToolConfig(TypedDict):
    """Configuration for a tool requiring human in the loop."""

    allow_accept: NotRequired[bool]
    """Whether the human can approve the current action without changes."""
    allow_edit: NotRequired[bool]
    """Whether the human can approve the current action with edited content."""
    allow_respond: NotRequired[bool]
    """Whether the human can reject the current action with feedback."""
    description: NotRequired[str]
    """The description attached to the request for human input."""


class HumanInTheLoopMiddleware(AgentMiddleware):
@@ -21,18 +112,36 @@ class HumanInTheLoopMiddleware(AgentMiddleware):

    def __init__(
        self,
        tool_configs: ToolInterruptConfig,
        message_prefix: str = "Tool execution requires approval",
        tool_configs: dict[str, bool | ToolConfig],
        *,
        description_prefix: str = "Tool execution requires approval",
    ) -> None:
        """Initialize the human in the loop middleware.

        Args:
            tool_configs: The tool interrupt configs to use for the middleware.
            message_prefix: The message prefix to use when constructing interrupt content.
            tool_configs: Mapping of tool name to allowed actions.
                If a tool doesn't have an entry, it's auto-approved by default.

                * `True` indicates all actions are allowed: accept, edit, and respond.
                * `False` indicates that the tool is auto-approved.
                * `ToolConfig` indicates the specific actions allowed for this tool.
            description_prefix: The prefix to use when constructing action requests.
                This is used to provide context about the tool call and the action being requested.
                Not used if a tool has a description in its `ToolConfig`.
        """
        super().__init__()
        self.tool_configs = tool_configs
        self.message_prefix = message_prefix
        resolved_tool_configs: dict[str, ToolConfig] = {}
        for tool_name, tool_config in tool_configs.items():
            if isinstance(tool_config, bool):
                if tool_config is True:
                    resolved_tool_configs[tool_name] = ToolConfig(
                        allow_accept=True,
                        allow_edit=True,
                        allow_respond=True,
                    )
            else:
                resolved_tool_configs[tool_name] = tool_config
        self.tool_configs = resolved_tool_configs
        self.description_prefix = description_prefix

    def after_model(self, state: AgentState) -> dict[str, Any] | None:
        """Trigger HITL flows for relevant tool calls after an AIMessage."""
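To illustrate the three accepted forms of a `tool_configs` entry described in the docstring above, here is a minimal sketch (the tool names are made up):

```py
# Hypothetical tool names; shows the three accepted config forms.
middleware = HumanInTheLoopMiddleware(
    tool_configs={
        "delete_database": True,  # bool True: all actions allowed (accept, edit, respond)
        "read_file": False,  # bool False: auto-approved, never interrupts
        "send_email": {  # ToolConfig: only approve or reject with feedback
            "allow_accept": True,
            "allow_respond": True,
            "allow_edit": False,
            "description": "Review the outgoing email before it is sent",
        },
    },
    description_prefix="Tool execution requires approval",
)
```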
@@ -40,96 +149,105 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
        if not messages:
            return None

        last_message = messages[-1]

        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
        last_ai_msg = next((msg for msg in messages if isinstance(msg, AIMessage)), None)
        if not last_ai_msg or not last_ai_msg.tool_calls:
            return None

        # Separate tool calls that need interrupts from those that don't
        interrupt_tool_calls = []
        hitl_tool_calls: list[ToolCall] = []
        auto_approved_tool_calls = []

        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            if tool_name in self.tool_configs:
                interrupt_tool_calls.append(tool_call)
            else:
                auto_approved_tool_calls.append(tool_call)
        for tool_call in last_ai_msg.tool_calls:
            hitl_tool_calls.append(tool_call) if tool_call[
                "name"
            ] in self.tool_configs else auto_approved_tool_calls.append(tool_call)

        # If no interrupts needed, return early
        if not interrupt_tool_calls:
        if not hitl_tool_calls:
            return None

        approved_tool_calls = auto_approved_tool_calls.copy()
        # Process all tool calls that require interrupts
        approved_tool_calls: list[ToolCall] = auto_approved_tool_calls.copy()
        artificial_tool_messages: list[ToolMessage] = []

        # Right now, we do not support multiple tool calls with interrupts
        if len(interrupt_tool_calls) > 1:
            tool_names = [t["name"] for t in interrupt_tool_calls]
            msg = (
                f"Called the following tools which require interrupts: {tool_names}\n\n"
                "You may only call ONE tool that requires an interrupt at a time"
            )
            return {
                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
                "jump_to": "model",
            }

        # Right now, we do not support interrupting a tool call if other tool calls exist
        if auto_approved_tool_calls:
            tool_names = [t["name"] for t in interrupt_tool_calls]
            msg = (
                f"Called the following tools which require interrupts: {tool_names}. "
                "You also called other tools that do not require interrupts. "
                "If you call a tool that requires and interrupt, you may ONLY call that tool."
            )
            return {
                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
                "jump_to": "model",
            }

        # Only one tool call will need interrupts
        tool_call = interrupt_tool_calls[0]
        # Create interrupt requests for all tools that need approval
        hitl_requests: list[HumanInTheLoopRequest] = []
        for tool_call in hitl_tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
            tool_config = self.tool_configs[tool_name]
            config = self.tool_configs[tool_name]
            description = (
                config.get("description")
                or f"{self.description_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
            )

            request: HumanInterrupt = {
            request: HumanInTheLoopRequest = {
                "action_request": ActionRequest(
                    action=tool_name,
                    args=tool_args,
                ),
                "config": tool_config,
                "config": config,
                "description": description,
            }
            hitl_requests.append(request)

        responses: list[HumanResponse] = interrupt([request])
        response = responses[0]
        responses: list[HumanInTheLoopResponse] = interrupt(hitl_requests)

        if response["type"] == "accept":
            approved_tool_calls.append(tool_call)
        elif response["type"] == "edit":
            edited: ActionRequest = response["args"]  # type: ignore[assignment]
            new_tool_call = {
                "type": "tool_call",
                "name": tool_call["name"],
                "args": edited["args"],
                "id": tool_call["id"],
            }
            approved_tool_calls.append(new_tool_call)
        elif response["type"] == "ignore":
            return {"jump_to": "__end__"}
        elif response["type"] == "response":
            tool_message = {
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "content": response["args"],
            }
            return {"messages": [tool_message], "jump_to": "model"}
        else:
            msg = f"Unknown response type: {response['type']}"
        # Validate that the number of responses matches the number of interrupt tool calls
        if (responses_len := len(responses)) != (hitl_tool_calls_len := len(hitl_tool_calls)):
            msg = (
                f"Number of human responses ({responses_len}) does not match "
                f"number of hanging tool calls ({hitl_tool_calls_len})."
            )
            raise ValueError(msg)

        last_message.tool_calls = approved_tool_calls
        for i, response in enumerate(responses):
            tool_call = hitl_tool_calls[i]
            config = self.tool_configs[tool_call["name"]]

        return {"messages": [last_message]}
            if response["type"] == "accept" and config.get("allow_accept"):
                approved_tool_calls.append(tool_call)
            elif response["type"] == "edit" and config.get("allow_edit"):
                edited_action = response["args"]
                approved_tool_calls.append(
                    ToolCall(
                        type="tool_call",
                        name=edited_action["action"],
                        args=edited_action["args"],
                        id=tool_call["id"],
                    )
                )
            elif response["type"] == "response" and config.get("allow_respond"):
                # Create a tool message with the human's text response
                content = response.get("args") or (
                    f"User rejected the tool call for `{tool_call['name']}` "
                    f"with id {tool_call['id']}"
                )
                tool_message = ToolMessage(
                    content=content,
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                    status="error",
                )
                artificial_tool_messages.append(tool_message)
            else:
                allowed_actions = [
                    action
                    for action in ["accept", "edit", "response"]
                    if config.get(f"allow_{'respond' if action == 'response' else action}")
                ]
                msg = (
                    f"Unexpected human response: {response}. "
                    f"Response action '{response.get('type')}' "
                    f"is not allowed for tool '{tool_call['name']}'. "
                    f"Expected one of {allowed_actions} based on the tool's configuration."
                )
                raise ValueError(msg)

        # Update the AI message to only include approved tool calls
        last_ai_msg.tool_calls = approved_tool_calls

        if len(approved_tool_calls) > 0:
            return {"messages": [last_ai_msg, *artificial_tool_messages]}

        return {"jump_to": "model", "messages": artificial_tool_messages}
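Putting the middleware to work, the interrupt/resume round trip looks roughly like the sketch below. This is a minimal sketch, assuming `create_agent` accepts a `middleware` list as introduced in this PR and that a checkpointer is attached at compile time; the model, tool, and thread ID are illustrative:

```py
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import Command

agent = create_agent(
    model=model,  # assumed chat model
    tools=[send_email],  # assumed tool that should require approval
    middleware=[HumanInTheLoopMiddleware(tool_configs={"send_email": True})],
).compile(checkpointer=InMemorySaver())

config = {"configurable": {"thread_id": "demo"}}

# First invocation pauses with one HumanInTheLoopRequest per flagged tool call.
agent.invoke({"messages": [("user", "Email the report to the team")]}, config)

# Resume with one response per interrupted tool call (same order as the requests).
agent.invoke(Command(resume=[{"type": "accept"}]), config)
```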
@@ -2,14 +2,15 @@

import itertools
from collections.abc import Callable, Sequence
from typing import Any
from typing import Any, cast

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
from langgraph.types import Send
from langgraph.typing import ContextT
from typing_extensions import TypedDict, TypeVar
@@ -211,15 +212,13 @@ def create_agent(  # noqa: PLR0915
        context_schema=context_schema,
    )

    def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
    def _handle_model_output(output: AIMessage) -> dict[str, Any]:
        """Handle model output including structured responses."""
        # Handle structured output with native strategy
        if isinstance(response_format, ProviderStrategy):
            if not output.tool_calls and native_output_binding:
                structured_response = native_output_binding.parse(output)
                return {"messages": [output], "response": structured_response}
            if state.get("response") is not None:
                return {"messages": [output], "response": None}
            return {"messages": [output]}

        # Handle structured output with tools strategy
@@ -297,9 +296,6 @@ def create_agent(  # noqa: PLR0915
                ],
            }

        # Standard response handling
        if state.get("response") is not None:
            return {"messages": [output], "response": None}
        return {"messages": [output]}

    def _get_bound_model(request: ModelRequest) -> Runnable:
@@ -346,7 +342,7 @@ def create_agent(  # noqa: PLR0915
            messages = [SystemMessage(request.system_prompt), *messages]

        output = model_.invoke(messages)
        return _handle_model_output(state, output)
        return _handle_model_output(output)

    async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
        """Async model request handler with sequential middleware processing."""
@@ -373,7 +369,7 @@ def create_agent(  # noqa: PLR0915
            messages = [SystemMessage(request.system_prompt), *messages]

        output = await model_.ainvoke(messages)
        return _handle_model_output(state, output)
        return _handle_model_output(output)

    # Use sync or async based on model capabilities
    from langgraph._internal._runnable import RunnableCallable
@@ -417,12 +413,12 @@ def create_agent(  # noqa: PLR0915
    if tool_node is not None:
        graph.add_conditional_edges(
            "tools",
            _make_tools_to_model_edge(tool_node, first_node),
            _make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
            [first_node, END],
        )
        graph.add_conditional_edges(
            last_node,
            _make_model_to_tools_edge(first_node, structured_output_tools),
            _make_model_to_tools_edge(first_node, structured_output_tools, tool_node),
            [first_node, "tools", END],
        )
    elif last_node == "model_request":
@@ -481,27 +477,48 @@ def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
    return None


def _fetch_last_ai_and_tool_messages(
    messages: list[AnyMessage],
) -> tuple[AIMessage, list[ToolMessage]]:
    last_ai_index: int
    last_ai_message: AIMessage

    for i in range(len(messages) - 1, -1, -1):
        if isinstance(messages[i], AIMessage):
            last_ai_index = i
            last_ai_message = cast("AIMessage", messages[i])
            break

    tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
    return last_ai_message, tool_messages


def _make_model_to_tools_edge(
    first_node: str, structured_output_tools: dict[str, OutputToolBinding]
) -> Callable[[AgentState], str | None]:
    def model_to_tools(state: AgentState) -> str | None:
    first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
) -> Callable[[AgentState], str | list[Send] | None]:
    def model_to_tools(state: AgentState) -> str | list[Send] | None:
        if jump_to := state.get("jump_to"):
            return _resolve_jump(jump_to, first_node)

        message = state["messages"][-1]
        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
        tool_message_ids = [m.tool_call_id for m in tool_messages]

        # Check if this is a ToolMessage from structured output - if so, end
        if isinstance(message, ToolMessage) and message.name in structured_output_tools:
            return END

        # Check for tool calls
        if isinstance(message, AIMessage) and message.tool_calls:
            # If all tool calls are for structured output, don't go to tools
            non_structured_calls = [
                tc for tc in message.tool_calls if tc["name"] not in structured_output_tools
        pending_tool_calls = [
            c
            for c in last_ai_message.tool_calls
            if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
        ]
            if non_structured_calls:
                return "tools"

        if pending_tool_calls:
            # imo we should not be injecting state, store here,
            # this should be done by the tool node itself ideally but this is a consequence
            # of using Send w/ tool calls directly which allows more intuitive interrupt behavior
            # largely internal so can be fixed later
            pending_tool_calls = [
                tool_node.inject_tool_args(call, state, None)  # type: ignore[arg-type]
                for call in pending_tool_calls
            ]
            return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]

        return END
@@ -509,17 +526,21 @@ def _make_model_to_tools_edge(


def _make_tools_to_model_edge(
    tool_node: ToolNode, next_node: str
    tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
) -> Callable[[AgentState], str | None]:
    def tools_to_model(state: AgentState) -> str | None:
        ai_message = [m for m in state["messages"] if isinstance(m, AIMessage)][-1]
        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])

        if all(
            tool_node.tools_by_name[c["name"]].return_direct
            for c in ai_message.tool_calls
            for c in last_ai_message.tool_calls
            if c["name"] in tool_node.tools_by_name
        ):
            return END

        if any(t.name in structured_output_tools for t in tool_messages):
            return END

        return next_node

    return tools_to_model
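The reworked `model_to_tools` edge above only dispatches tool calls that don't yet have a matching `ToolMessage`, so a run resumed after an interrupt doesn't re-execute tools that already finished. A standalone toy sketch of that filtering idea (not the actual edge function):

```py
from langchain_core.messages import AIMessage, ToolMessage


def pending_calls(ai_message: AIMessage, tool_messages: list[ToolMessage]) -> list[dict]:
    """Return tool calls from the AI message that have no ToolMessage yet."""
    done_ids = {m.tool_call_id for m in tool_messages}
    return [c for c in ai_message.tool_calls if c["id"] not in done_ids]


ai = AIMessage(
    content="",
    tool_calls=[
        {"name": "get_forecast", "args": {"city": "SF"}, "id": "1"},
        {"name": "get_temperature", "args": {"city": "SF"}, "id": "2"},
    ],
)
finished = [ToolMessage(content="72F", tool_call_id="2", name="get_temperature")]
assert [c["id"] for c in pending_calls(ai, finished)] == ["1"]
```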
@@ -4,6 +4,7 @@ from unittest.mock import patch

from syrupy.assertion import SnapshotAssertion

from pydantic import BaseModel, Field
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
@@ -14,19 +15,23 @@ from langchain_core.messages import (
    ToolMessage,
)
from langchain_core.tools import tool
from langgraph.types import Command

from langchain.agents.middleware_agent import create_agent
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
from langchain.agents.middleware.human_in_the_loop import (
    HumanInTheLoopMiddleware,
    HumanInTheLoopConfig,
    ActionRequest,
)
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState
from langchain_core.tools import BaseTool

from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import END
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.prebuilt.interrupt import ActionRequest, HumanInterruptConfig
from langchain.agents.structured_output import ToolStrategy

from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
from .model import FakeToolCallingModel
@@ -355,27 +360,28 @@ def test_create_agent_jump(
# Tests for HumanInTheLoopMiddleware
def test_human_in_the_loop_middleware_initialization() -> None:
    """Test HumanInTheLoopMiddleware initialization."""

    middleware = HumanInTheLoopMiddleware(
        tool_configs={
            "test_tool": HumanInterruptConfig(
                allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True
            "test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
        },
        description_prefix="Custom prefix",
    )

    assert middleware.tool_configs == {
        "test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
    }

    middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs, message_prefix="Custom prefix")

    assert middleware.tool_configs == tool_configs
    assert middleware.message_prefix == "Custom prefix"
    assert middleware.description_prefix == "Custom prefix"


def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
    """Test HumanInTheLoopMiddleware when no interrupts are needed."""
    tool_configs = {
        "test_tool": HumanInterruptConfig(
            allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True
        )
    }

    middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs)
    middleware = HumanInTheLoopMiddleware(
        tool_configs={
            "test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
        }
    )

    # Test with no messages
    state: dict[str, Any] = {"messages": []}
@@ -397,15 +403,14 @@ def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
    assert result is None


def test_human_in_the_loop_middleware_interrupt_responses() -> None:
    """Test HumanInTheLoopMiddleware with different interrupt response types."""
    tool_configs = {
        "test_tool": HumanInterruptConfig(
            allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True
        )
    }
def test_human_in_the_loop_middleware_single_tool_accept() -> None:
    """Test HumanInTheLoopMiddleware with single tool accept response."""

    middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs)
    middleware = HumanInTheLoopMiddleware(
        tool_configs={
            "test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
        }
    )

    ai_message = AIMessage(
        content="I'll help you",
@@ -413,55 +418,443 @@ def test_human_in_the_loop_middleware_interrupt_responses() -> None:
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    # Test accept response
    def mock_accept(requests):
        return [{"type": "accept", "args": None}]

    with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
        result = middleware.after_model(state)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1
        assert result["messages"][0] == ai_message
        assert result["messages"][0].tool_calls == ai_message.tool_calls

    # Test edit response


def test_human_in_the_loop_middleware_single_tool_edit() -> None:
    """Test HumanInTheLoopMiddleware with single tool edit response."""
    middleware = HumanInTheLoopMiddleware(
        tool_configs={
            "test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
        }
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    def mock_edit(requests):
        return [
            {"type": "edit", "args": ActionRequest(action="test_tool", args={"input": "edited"})}
            {
                "type": "edit",
                "args": ActionRequest(
                    action="test_tool",
                    args={"input": "edited"},
                ),
            }
        ]

    with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit):
        result = middleware.after_model(state)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1
        assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
        assert result["messages"][0].tool_calls[0]["id"] == "1"  # ID should be preserved

    # Test ignore response
    def mock_ignore(requests):
        return [{"type": "ignore", "args": None}]

    with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_ignore):
        result = middleware.after_model(state)
        assert result is not None
        assert result["jump_to"] == "__end__"
def test_human_in_the_loop_middleware_single_tool_response() -> None:
    """Test HumanInTheLoopMiddleware with single tool response with custom message."""

    middleware = HumanInTheLoopMiddleware(
        tool_configs={
            "test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
        }
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    # Test response type
    def mock_response(requests):
        return [{"type": "response", "args": "Custom response"}]
        return [{"type": "response", "args": "Custom response message"}]

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
    ):
        result = middleware.after_model(state)
        assert result is not None
        assert result["jump_to"] == "model"
        assert result["messages"][0]["role"] == "tool"
        assert result["messages"][0]["content"] == "Custom response"
        assert "messages" in result
        assert len(result["messages"]) == 1  # Only tool message when no approved tool calls
        assert isinstance(result["messages"][0], ToolMessage)
        assert result["messages"][0].content == "Custom response message"
        assert result["messages"][0].name == "test_tool"
        assert result["messages"][0].tool_call_id == "1"
def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
    """Test HumanInTheLoopMiddleware with multiple tools and mixed response types."""

    middleware = HumanInTheLoopMiddleware(
        tool_configs={
            "get_forecast": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
            "get_temperature": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
        }
    )

    ai_message = AIMessage(
        content="I'll help you with weather",
        tool_calls=[
            {"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
            {"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
        ],
    )
    state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}

    def mock_mixed_responses(requests):
        return [
            {"type": "accept", "args": None},
            {"type": "response", "args": "User rejected this tool call"},
        ]

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
    ):
        result = middleware.after_model(state)
        assert result is not None
        assert "messages" in result
        assert (
            len(result["messages"]) == 2
        )  # AI message with accepted tool call + tool message for rejected

        # First message should be the AI message with updated tool calls
        updated_ai_message = result["messages"][0]
        assert len(updated_ai_message.tool_calls) == 1  # Only accepted tool call
        assert updated_ai_message.tool_calls[0]["name"] == "get_forecast"  # Accepted

        # Second message should be the tool message for the rejected tool call
        tool_message = result["messages"][1]
        assert isinstance(tool_message, ToolMessage)
        assert tool_message.content == "User rejected this tool call"
        assert tool_message.name == "get_temperature"
def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None:
    """Test HumanInTheLoopMiddleware with multiple tools and edit responses."""

    middleware = HumanInTheLoopMiddleware(
        tool_configs={
            "get_forecast": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
            "get_temperature": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
        }
    )

    ai_message = AIMessage(
        content="I'll help you with weather",
        tool_calls=[
            {"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
            {"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
        ],
    )
    state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}

    def mock_edit_responses(requests):
        return [
            {
                "type": "edit",
                "args": ActionRequest(
                    action="get_forecast",
                    args={"location": "New York"},
                ),
            },
            {
                "type": "edit",
                "args": ActionRequest(
                    action="get_temperature",
                    args={"location": "New York"},
                ),
            },
        ]

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
    ):
        result = middleware.after_model(state)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1

        updated_ai_message = result["messages"][0]
        assert updated_ai_message.tool_calls[0]["args"] == {"location": "New York"}
        assert updated_ai_message.tool_calls[0]["id"] == "1"  # ID preserved
        assert updated_ai_message.tool_calls[1]["args"] == {"location": "New York"}
        assert updated_ai_message.tool_calls[1]["id"] == "2"  # ID preserved
def test_human_in_the_loop_middleware_edit_with_modified_args() -> None:
    """Test HumanInTheLoopMiddleware with edit action that includes modified args."""

    middleware = HumanInTheLoopMiddleware(
        tool_configs={
            "test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
        }
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    def mock_edit_with_args(requests):
        return [
            {
                "type": "edit",
                "args": ActionRequest(
                    action="test_tool",
                    args={"input": "modified"},
                ),
            }
        ]

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        side_effect=mock_edit_with_args,
    ):
        result = middleware.after_model(state)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1

        # Should have modified args
        updated_ai_message = result["messages"][0]
        assert updated_ai_message.tool_calls[0]["args"] == {"input": "modified"}
        assert updated_ai_message.tool_calls[0]["id"] == "1"  # ID preserved
def test_human_in_the_loop_middleware_unknown_response_type() -> None:
    """Test HumanInTheLoopMiddleware with unknown response type."""
    middleware = HumanInTheLoopMiddleware(
        tool_configs={
            "test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
        }
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    # Test unknown response type
    def mock_unknown(requests):
        return [{"type": "unknown", "args": None}]

    with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown):
        with pytest.raises(ValueError, match="Unknown response type: unknown"):
        with pytest.raises(
            ValueError,
            match=r"Unexpected human response: {'type': 'unknown', 'args': None}. Response action 'unknown' is not allowed for tool 'test_tool'. Expected one of \['accept', 'edit', 'response'\] based on the tool's configuration.",
        ):
            middleware.after_model(state)
def test_human_in_the_loop_middleware_disallowed_action() -> None:
    """Test HumanInTheLoopMiddleware with action not allowed by tool config."""

    # edit is not allowed by tool config
    middleware = HumanInTheLoopMiddleware(
        tool_configs={
            "test_tool": {"allow_respond": True, "allow_edit": False, "allow_accept": True}
        }
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    def mock_disallowed_action(requests):
        return [
            {
                "type": "edit",
                "args": ActionRequest(
                    action="test_tool",
                    args={"input": "modified"},
                ),
            }
        ]

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        side_effect=mock_disallowed_action,
    ):
        with pytest.raises(
            ValueError,
            match=r"Unexpected human response: {'type': 'edit', 'args': {'action': 'test_tool', 'args': {'input': 'modified'}}}. Response action 'edit' is not allowed for tool 'test_tool'. Expected one of \['accept', 'response'\] based on the tool's configuration.",
        ):
            middleware.after_model(state)
def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> None:
    """Test HumanInTheLoopMiddleware with mix of auto-approved and interrupt tools."""

    middleware = HumanInTheLoopMiddleware(
        tool_configs={
            "interrupt_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
        }
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[
            {"name": "auto_tool", "args": {"input": "auto"}, "id": "1"},
            {"name": "interrupt_tool", "args": {"input": "interrupt"}, "id": "2"},
        ],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    def mock_accept(requests):
        return [{"type": "accept", "args": None}]

    with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
        result = middleware.after_model(state)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1

        updated_ai_message = result["messages"][0]
        # Should have both tools: auto-approved first, then interrupt tool
        assert len(updated_ai_message.tool_calls) == 2
        assert updated_ai_message.tool_calls[0]["name"] == "auto_tool"
        assert updated_ai_message.tool_calls[1]["name"] == "interrupt_tool"
def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
    """Test that interrupt requests are structured correctly."""

    middleware = HumanInTheLoopMiddleware(
        tool_configs={
            "test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
        },
        description_prefix="Custom prefix",
    )

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test", "location": "SF"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    captured_requests = []

    def mock_capture_requests(requests):
        captured_requests.extend(requests)
        return [{"type": "accept", "args": None}]

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
    ):
        middleware.after_model(state)

    assert len(captured_requests) == 1
    request = captured_requests[0]

    assert "action_request" in request
    assert "config" in request
    assert "description" in request

    assert request["action_request"]["action"] == "test_tool"
    assert request["action_request"]["args"] == {"input": "test", "location": "SF"}
    expected_config = {"allow_accept": True, "allow_edit": True, "allow_respond": True}
    assert request["config"] == expected_config
    assert "Custom prefix" in request["description"]
    assert "Tool: test_tool" in request["description"]
    assert "Args: {'input': 'test', 'location': 'SF'}" in request["description"]
def test_human_in_the_loop_middleware_boolean_configs() -> None:
    """Test HITL middleware with boolean tool configs."""
    middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True})

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    # Test accept
    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        return_value=[{"type": "accept", "args": None}],
    ):
        result = middleware.after_model(state)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1
        assert result["messages"][0].tool_calls == ai_message.tool_calls

    # Test edit
    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        return_value=[
            {
                "type": "edit",
                "args": ActionRequest(
                    action="test_tool",
                    args={"input": "edited"},
                ),
            }
        ],
    ):
        result = middleware.after_model(state)
        assert result is not None
        assert "messages" in result
        assert len(result["messages"]) == 1
        assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}

    middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": False})

    result = middleware.after_model(state)
    # No interruption should occur
    assert result is None
def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
    """Test that sequence mismatch in resume raises an error."""
    middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True})

    ai_message = AIMessage(
        content="I'll help you",
        tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
    )
    state = {"messages": [HumanMessage(content="Hello"), ai_message]}

    # Test with too few responses
    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        return_value=[],  # No responses for 1 tool call
    ):
        with pytest.raises(
            ValueError,
            match=r"Number of human responses \(0\) does not match number of hanging tool calls \(1\)\.",
        ):
            middleware.after_model(state)

    # Test with too many responses
    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        return_value=[
            {"type": "accept", "args": None},
            {"type": "accept", "args": None},
        ],  # 2 responses for 1 tool call
    ):
        with pytest.raises(
            ValueError,
            match=r"Number of human responses \(2\) does not match number of hanging tool calls \(1\)\.",
        ):
            middleware.after_model(state)
@@ -733,3 +1126,78 @@ def test_modify_model_request() -> None:
    assert (
        result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!"
    )


def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
    """Test that when there are both structured and regular tool calls, we execute regular and jump to END."""

    class WeatherResponse(BaseModel):
        """Weather response."""

        temperature: float = Field(description="Temperature in fahrenheit")
        condition: str = Field(description="Weather condition")

    @tool
    def regular_tool(query: str) -> str:
        """A regular tool that returns a string."""
        return f"Regular tool result for: {query}"

    # Create a fake model that returns both structured and regular tool calls
    class FakeModelWithBothToolCalls(FakeToolCallingModel):
        def __init__(self):
            super().__init__()
            self.tool_calls = [
                [
                    ToolCall(
                        name="WeatherResponse",
                        args={"temperature": 72.0, "condition": "sunny"},
                        id="structured_call_1",
                    ),
                    ToolCall(
                        name="regular_tool", args={"query": "test query"}, id="regular_call_1"
                    ),
                ]
            ]

    # Create agent with both structured output and regular tools
    agent = create_agent(
        model=FakeModelWithBothToolCalls(),
        tools=[regular_tool],
        response_format=ToolStrategy(schema=WeatherResponse),
    )

    # Compile and invoke the agent
    compiled_agent = agent.compile()
    result = compiled_agent.invoke(
        {"messages": [HumanMessage("What's the weather and help me with a query?")]}
    )

    # Verify that we have the expected messages:
    # 1. Human message
    # 2. AI message with both tool calls
    # 3. Tool message from structured tool call
    # 4. Tool message from regular tool call

    messages = result["messages"]
    assert len(messages) >= 4

    # Check that we have the AI message with both tool calls
    ai_message = messages[1]
    assert isinstance(ai_message, AIMessage)
    assert len(ai_message.tool_calls) == 2

    # Check that we have a tool message from the regular tool
    tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
    assert len(tool_messages) >= 1

    # The regular tool should have been executed
    regular_tool_message = next((m for m in tool_messages if m.name == "regular_tool"), None)
    assert regular_tool_message is not None
    assert "Regular tool result for: test query" in regular_tool_message.content

    # Verify that the structured response is available in the result
    assert "response" in result
    assert result["response"] is not None
    assert hasattr(result["response"], "temperature")
    assert result["response"].temperature == 72.0
    assert result["response"].condition == "sunny"