mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-20 18:12:35 +00:00
feat(langchain): improved HITL patterns (#32996)
# Main changes / new features

## Better support for parallel tool calls

1. Support for multiple tool calls requiring human input
2. Support for a combination of tool calls requiring human input and tool calls that are auto-approved
3. Support for structured output with tool calls requiring human input
4. Support for structured output with standard tool calls

## Shortcut for allowed actions

Adds a shortcut where a tool's config can be specified as a `bool`, where `True` means "all actions allowed":

```py
HumanInTheLoopMiddleware(tool_configs={"expensive_tool": True})
```

## A few design decisions here

* We only raise one interrupt with all `HumanInterrupt`s, so currently we won't be able to execute the tools until all of these are resolved. This isn't super blocking because we can't re-invoke the model until all tools have finished execution. That being said, if you have a long-running auto-approved tool, this could slow things down.

## TODOs

* Ideally, we would rename `accept` -> `approve`
* Ideally, we would rename `respond` -> `reject`
* Docs update (@sydney-runkle to own)
* In another PR I'd like to refactor testing to have one file for each prebuilt middleware :)

Fast follow to https://github.com/langchain-ai/langchain/pull/32962, which was deemed too breaking.
This commit is contained in:
@@ -1,97 +0,0 @@
|
|||||||
"""Interrupt types to use with agent inbox like setups."""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInterruptConfig(TypedDict):
|
|
||||||
"""Configuration that defines what actions are allowed for a human interrupt.
|
|
||||||
|
|
||||||
This controls the available interaction options when the graph is paused for human input.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
allow_ignore: Whether the human can choose to ignore/skip the current step
|
|
||||||
allow_respond: Whether the human can provide a text response/feedback
|
|
||||||
allow_edit: Whether the human can edit the provided content/state
|
|
||||||
allow_accept: Whether the human can accept/approve the current state
|
|
||||||
"""
|
|
||||||
|
|
||||||
allow_ignore: bool
|
|
||||||
allow_respond: bool
|
|
||||||
allow_edit: bool
|
|
||||||
allow_accept: bool
|
|
||||||
|
|
||||||
|
|
||||||
class ActionRequest(TypedDict):
|
|
||||||
"""Represents a request for human action within the graph execution.
|
|
||||||
|
|
||||||
Contains the action type and any associated arguments needed for the action.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
action: The type or name of action being requested (e.g., "Approve XYZ action")
|
|
||||||
args: Key-value pairs of arguments needed for the action
|
|
||||||
"""
|
|
||||||
|
|
||||||
action: str
|
|
||||||
args: dict
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInterrupt(TypedDict):
|
|
||||||
"""Represents an interrupt triggered by the graph that requires human intervention.
|
|
||||||
|
|
||||||
This is passed to the `interrupt` function when execution is paused for human input.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
action_request: The specific action being requested from the human
|
|
||||||
config: Configuration defining what actions are allowed
|
|
||||||
description: Optional detailed description of what input is needed
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```python
|
|
||||||
# Extract a tool call from the state and create an interrupt request
|
|
||||||
request = HumanInterrupt(
|
|
||||||
action_request=ActionRequest(
|
|
||||||
action="run_command", # The action being requested
|
|
||||||
args={"command": "ls", "args": ["-l"]}, # Arguments for the action
|
|
||||||
),
|
|
||||||
config=HumanInterruptConfig(
|
|
||||||
allow_ignore=True, # Allow skipping this step
|
|
||||||
allow_respond=True, # Allow text feedback
|
|
||||||
allow_edit=False, # Don't allow editing
|
|
||||||
allow_accept=True, # Allow direct acceptance
|
|
||||||
),
|
|
||||||
description="Please review the command before execution",
|
|
||||||
)
|
|
||||||
# Send the interrupt request and get the response
|
|
||||||
response = interrupt([request])[0]
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
action_request: ActionRequest
|
|
||||||
config: HumanInterruptConfig
|
|
||||||
description: str | None
|
|
||||||
|
|
||||||
|
|
||||||
class HumanResponse(TypedDict):
|
|
||||||
"""Human response.
|
|
||||||
|
|
||||||
The response provided by a human to an interrupt,
|
|
||||||
which is returned when graph execution resumes.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
type: The type of response:
|
|
||||||
|
|
||||||
- "accept": Approves the current state without changes
|
|
||||||
- "ignore": Skips/ignores the current step
|
|
||||||
- "response": Provides text feedback or instructions
|
|
||||||
- "edit": Modifies the current state/content
|
|
||||||
args: The response payload:
|
|
||||||
|
|
||||||
- None: For ignore/accept actions
|
|
||||||
- str: For text responses
|
|
||||||
- ActionRequest: For edit actions with updated content
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: Literal["accept", "ignore", "response", "edit"]
|
|
||||||
args: None | str | ActionRequest
|
|
@@ -1,11 +0,0 @@
|
|||||||
"""Utility functions for middleware."""
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_correction_tool_messages(content: str, tool_calls: list) -> list[dict[str, Any]]:
|
|
||||||
"""Generate tool messages for model behavior correction."""
|
|
||||||
return [
|
|
||||||
{"role": "tool", "content": content, "tool_call_id": tool_call["id"]}
|
|
||||||
for tool_call in tool_calls
|
|
||||||
]
|
|
@@ -1,19 +1,110 @@
|
|||||||
"""Human in the loop middleware."""
|
"""Human in the loop middleware."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
from langgraph.prebuilt.interrupt import (
|
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
||||||
ActionRequest,
|
|
||||||
HumanInterrupt,
|
|
||||||
HumanInterruptConfig,
|
|
||||||
HumanResponse,
|
|
||||||
)
|
|
||||||
from langgraph.types import interrupt
|
from langgraph.types import interrupt
|
||||||
|
from typing_extensions import NotRequired, TypedDict
|
||||||
|
|
||||||
from langchain.agents.middleware._utils import _generate_correction_tool_messages
|
|
||||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
||||||
|
|
||||||
ToolInterruptConfig = dict[str, HumanInterruptConfig]
|
|
||||||
|
class HumanInTheLoopConfig(TypedDict):
|
||||||
|
"""Configuration that defines what actions are allowed for a human interrupt.
|
||||||
|
|
||||||
|
This controls the available interaction options when the graph is paused for human input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
allow_accept: NotRequired[bool]
|
||||||
|
"""Whether the human can approve the current action without changes."""
|
||||||
|
allow_edit: NotRequired[bool]
|
||||||
|
"""Whether the human can approve the current action with edited content."""
|
||||||
|
allow_respond: NotRequired[bool]
|
||||||
|
"""Whether the human can reject the current action with feedback."""
|
||||||
|
|
||||||
|
|
||||||
|
class ActionRequest(TypedDict):
    """Represents a request with a name and arguments."""

    action: str
    """The type or name of action being requested (e.g., "add_numbers")."""

    # Parameterized instead of a bare ``dict`` so type checkers know the keys
    # are argument names (strings) while values stay unconstrained.
    args: dict[str, Any]
    """Key-value pairs of arguments needed for the action (e.g., {"a": 1, "b": 2})."""
|
||||||
|
|
||||||
|
|
||||||
|
class HumanInTheLoopRequest(TypedDict):
    """Represents an interrupt triggered by the graph that requires human intervention.

    A list of these requests is passed to `interrupt` when execution is paused
    for human input.

    Example:
        ```python
        # Extract a tool call from the state and create an interrupt request
        request = HumanInTheLoopRequest(
            action_request=ActionRequest(
                action="run_command",  # The action being requested
                args={"command": "ls", "args": ["-l"]},  # Arguments for the action
            ),
            config=HumanInTheLoopConfig(
                allow_accept=True,  # Allow approval
                allow_respond=True,  # Allow rejection with feedback
                allow_edit=False,  # Don't allow approval with edits
            ),
            description="Please review the command before execution",
        )
        # Send the interrupt request and get the response
        response = interrupt([request])[0]
        ```
    """

    action_request: ActionRequest
    """The specific action being requested from the human."""

    config: HumanInTheLoopConfig
    """Configuration defining what response types are allowed."""

    description: str | None
    """Optional detailed description of what input is needed."""
|
||||||
|
|
||||||
|
|
||||||
|
class AcceptPayload(TypedDict):
    """Payload a human sends back to approve an action unchanged."""

    type: Literal["accept"]
    """Discriminator value marking this payload as an approval."""
|
||||||
|
|
||||||
|
|
||||||
|
class ResponsePayload(TypedDict):
|
||||||
|
"""Response when a human rejects the action."""
|
||||||
|
|
||||||
|
type: Literal["response"]
|
||||||
|
"""The type of response when a human rejects the action."""
|
||||||
|
|
||||||
|
args: NotRequired[str]
|
||||||
|
"""The message to be sent to the model explaining why the action was rejected."""
|
||||||
|
|
||||||
|
|
||||||
|
class EditPayload(TypedDict):
    """Payload a human sends back to approve an action with modifications."""

    type: Literal["edit"]
    """Discriminator value marking this payload as an edit."""

    args: ActionRequest
    """Replacement action request carrying the edited content."""
|
||||||
|
|
||||||
|
|
||||||
|
# Union of the three payload shapes; each member declares a distinct literal
# value for its "type" key ("accept", "response" or "edit").
HumanInTheLoopResponse = AcceptPayload | ResponsePayload | EditPayload
|
||||||
|
"""Aggregated response type for all possible human in the loop responses."""
|
||||||
|
|
||||||
|
|
||||||
|
class ToolConfig(TypedDict):
|
||||||
|
"""Configuration for a tool requiring human in the loop."""
|
||||||
|
|
||||||
|
allow_accept: NotRequired[bool]
|
||||||
|
"""Whether the human can approve the current action without changes."""
|
||||||
|
allow_edit: NotRequired[bool]
|
||||||
|
"""Whether the human can approve the current action with edited content."""
|
||||||
|
allow_respond: NotRequired[bool]
|
||||||
|
"""Whether the human can reject the current action with feedback."""
|
||||||
|
description: NotRequired[str]
|
||||||
|
"""The description attached to the request for human input."""
|
||||||
|
|
||||||
|
|
||||||
class HumanInTheLoopMiddleware(AgentMiddleware):
|
class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||||
@@ -21,18 +112,36 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tool_configs: ToolInterruptConfig,
|
tool_configs: dict[str, bool | ToolConfig],
|
||||||
message_prefix: str = "Tool execution requires approval",
|
*,
|
||||||
|
description_prefix: str = "Tool execution requires approval",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the human in the loop middleware.
|
"""Initialize the human in the loop middleware.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_configs: The tool interrupt configs to use for the middleware.
|
tool_configs: Mapping of tool name to allowed actions.
|
||||||
message_prefix: The message prefix to use when constructing interrupt content.
|
If a tool doesn't have an entry, it's auto-approved by default.
|
||||||
|
* `True` indicates all actions are allowed: accept, edit, and respond.
|
||||||
|
* `False` indicates that the tool is auto-approved.
|
||||||
|
* ToolConfig indicates the specific actions allowed for this tool.
|
||||||
|
description_prefix: The prefix to use when constructing action requests.
|
||||||
|
This is used to provide context about the tool call and the action being requested.
|
||||||
|
Not used if a tool has a description in its ToolConfig.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tool_configs = tool_configs
|
resolved_tool_configs: dict[str, ToolConfig] = {}
|
||||||
self.message_prefix = message_prefix
|
for tool_name, tool_config in tool_configs.items():
|
||||||
|
if isinstance(tool_config, bool):
|
||||||
|
if tool_config is True:
|
||||||
|
resolved_tool_configs[tool_name] = ToolConfig(
|
||||||
|
allow_accept=True,
|
||||||
|
allow_edit=True,
|
||||||
|
allow_respond=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
resolved_tool_configs[tool_name] = tool_config
|
||||||
|
self.tool_configs = resolved_tool_configs
|
||||||
|
self.description_prefix = description_prefix
|
||||||
|
|
||||||
def after_model(self, state: AgentState) -> dict[str, Any] | None:
|
def after_model(self, state: AgentState) -> dict[str, Any] | None:
|
||||||
"""Trigger HITL flows for relevant tool calls after an AIMessage."""
|
"""Trigger HITL flows for relevant tool calls after an AIMessage."""
|
||||||
@@ -40,96 +149,105 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|||||||
if not messages:
|
if not messages:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
last_message = messages[-1]
|
# Walk the history from the end: the HITL check must apply to the most recent
# AIMessage, not the first one ever produced in the conversation.
last_ai_msg = next((msg for msg in reversed(messages) if isinstance(msg, AIMessage)), None)
|
||||||
|
if not last_ai_msg or not last_ai_msg.tool_calls:
|
||||||
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Separate tool calls that need interrupts from those that don't
|
# Separate tool calls that need interrupts from those that don't
|
||||||
interrupt_tool_calls = []
|
hitl_tool_calls: list[ToolCall] = []
|
||||||
auto_approved_tool_calls = []
|
auto_approved_tool_calls = []
|
||||||
|
|
||||||
for tool_call in last_message.tool_calls:
|
for tool_call in last_ai_msg.tool_calls:
|
||||||
tool_name = tool_call["name"]
|
hitl_tool_calls.append(tool_call) if tool_call[
|
||||||
if tool_name in self.tool_configs:
|
"name"
|
||||||
interrupt_tool_calls.append(tool_call)
|
] in self.tool_configs else auto_approved_tool_calls.append(tool_call)
|
||||||
else:
|
|
||||||
auto_approved_tool_calls.append(tool_call)
|
|
||||||
|
|
||||||
# If no interrupts needed, return early
|
# If no interrupts needed, return early
|
||||||
if not interrupt_tool_calls:
|
if not hitl_tool_calls:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
approved_tool_calls = auto_approved_tool_calls.copy()
|
# Process all tool calls that require interrupts
|
||||||
|
approved_tool_calls: list[ToolCall] = auto_approved_tool_calls.copy()
|
||||||
|
artificial_tool_messages: list[ToolMessage] = []
|
||||||
|
|
||||||
# Right now, we do not support multiple tool calls with interrupts
|
# Create interrupt requests for all tools that need approval
|
||||||
if len(interrupt_tool_calls) > 1:
|
hitl_requests: list[HumanInTheLoopRequest] = []
|
||||||
tool_names = [t["name"] for t in interrupt_tool_calls]
|
for tool_call in hitl_tool_calls:
|
||||||
msg = (
|
|
||||||
f"Called the following tools which require interrupts: {tool_names}\n\n"
|
|
||||||
"You may only call ONE tool that requires an interrupt at a time"
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
|
|
||||||
"jump_to": "model",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Right now, we do not support interrupting a tool call if other tool calls exist
|
|
||||||
if auto_approved_tool_calls:
|
|
||||||
tool_names = [t["name"] for t in interrupt_tool_calls]
|
|
||||||
msg = (
|
|
||||||
f"Called the following tools which require interrupts: {tool_names}. "
|
|
||||||
"You also called other tools that do not require interrupts. "
|
|
||||||
"If you call a tool that requires and interrupt, you may ONLY call that tool."
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
|
|
||||||
"jump_to": "model",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Only one tool call will need interrupts
|
|
||||||
tool_call = interrupt_tool_calls[0]
|
|
||||||
tool_name = tool_call["name"]
|
tool_name = tool_call["name"]
|
||||||
tool_args = tool_call["args"]
|
tool_args = tool_call["args"]
|
||||||
description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
|
config = self.tool_configs[tool_name]
|
||||||
tool_config = self.tool_configs[tool_name]
|
description = (
|
||||||
|
config.get("description")
|
||||||
|
or f"{self.description_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
|
||||||
|
)
|
||||||
|
|
||||||
request: HumanInterrupt = {
|
request: HumanInTheLoopRequest = {
|
||||||
"action_request": ActionRequest(
|
"action_request": ActionRequest(
|
||||||
action=tool_name,
|
action=tool_name,
|
||||||
args=tool_args,
|
args=tool_args,
|
||||||
),
|
),
|
||||||
"config": tool_config,
|
"config": config,
|
||||||
"description": description,
|
"description": description,
|
||||||
}
|
}
|
||||||
|
hitl_requests.append(request)
|
||||||
|
|
||||||
responses: list[HumanResponse] = interrupt([request])
|
responses: list[HumanInTheLoopResponse] = interrupt(hitl_requests)
|
||||||
response = responses[0]
|
|
||||||
|
|
||||||
if response["type"] == "accept":
|
# Validate that the number of responses matches the number of interrupt tool calls
|
||||||
approved_tool_calls.append(tool_call)
|
if (responses_len := len(responses)) != (hitl_tool_calls_len := len(hitl_tool_calls)):
|
||||||
elif response["type"] == "edit":
|
msg = (
|
||||||
edited: ActionRequest = response["args"] # type: ignore[assignment]
|
f"Number of human responses ({responses_len}) does not match "
|
||||||
new_tool_call = {
|
f"number of hanging tool calls ({hitl_tool_calls_len})."
|
||||||
"type": "tool_call",
|
)
|
||||||
"name": tool_call["name"],
|
|
||||||
"args": edited["args"],
|
|
||||||
"id": tool_call["id"],
|
|
||||||
}
|
|
||||||
approved_tool_calls.append(new_tool_call)
|
|
||||||
elif response["type"] == "ignore":
|
|
||||||
return {"jump_to": "__end__"}
|
|
||||||
elif response["type"] == "response":
|
|
||||||
tool_message = {
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tool_call["id"],
|
|
||||||
"content": response["args"],
|
|
||||||
}
|
|
||||||
return {"messages": [tool_message], "jump_to": "model"}
|
|
||||||
else:
|
|
||||||
msg = f"Unknown response type: {response['type']}"
|
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
last_message.tool_calls = approved_tool_calls
|
for i, response in enumerate(responses):
|
||||||
|
tool_call = hitl_tool_calls[i]
|
||||||
|
config = self.tool_configs[tool_call["name"]]
|
||||||
|
|
||||||
return {"messages": [last_message]}
|
if response["type"] == "accept" and config.get("allow_accept"):
|
||||||
|
approved_tool_calls.append(tool_call)
|
||||||
|
elif response["type"] == "edit" and config.get("allow_edit"):
|
||||||
|
edited_action = response["args"]
|
||||||
|
approved_tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
type="tool_call",
|
||||||
|
name=edited_action["action"],
|
||||||
|
args=edited_action["args"],
|
||||||
|
id=tool_call["id"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif response["type"] == "response" and config.get("allow_respond"):
|
||||||
|
# Create a tool message with the human's text response
|
||||||
|
content = response.get("args") or (
|
||||||
|
f"User rejected the tool call for `{tool_call['name']}` "
|
||||||
|
f"with id {tool_call['id']}"
|
||||||
|
)
|
||||||
|
tool_message = ToolMessage(
|
||||||
|
content=content,
|
||||||
|
name=tool_call["name"],
|
||||||
|
tool_call_id=tool_call["id"],
|
||||||
|
status="error",
|
||||||
|
)
|
||||||
|
artificial_tool_messages.append(tool_message)
|
||||||
|
else:
|
||||||
|
allowed_actions = [
|
||||||
|
action
|
||||||
|
for action in ["accept", "edit", "response"]
|
||||||
|
if config.get(f"allow_{'respond' if action == 'response' else action}")
|
||||||
|
]
|
||||||
|
msg = (
|
||||||
|
f"Unexpected human response: {response}. "
|
||||||
|
f"Response action '{response.get('type')}' "
|
||||||
|
f"is not allowed for tool '{tool_call['name']}'. "
|
||||||
|
f"Expected one of {allowed_actions} based on the tool's configuration."
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Update the AI message to only include approved tool calls
|
||||||
|
last_ai_msg.tool_calls = approved_tool_calls
|
||||||
|
|
||||||
|
if len(approved_tool_calls) > 0:
|
||||||
|
return {"messages": [last_ai_msg, *artificial_tool_messages]}
|
||||||
|
|
||||||
|
return {"jump_to": "model", "messages": artificial_tool_messages}
|
||||||
|
@@ -2,14 +2,15 @@
|
|||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
|
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
|
||||||
from langchain_core.runnables import Runnable
|
from langchain_core.runnables import Runnable
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.constants import END, START
|
from langgraph.constants import END, START
|
||||||
from langgraph.graph.state import StateGraph
|
from langgraph.graph.state import StateGraph
|
||||||
|
from langgraph.types import Send
|
||||||
from langgraph.typing import ContextT
|
from langgraph.typing import ContextT
|
||||||
from typing_extensions import TypedDict, TypeVar
|
from typing_extensions import TypedDict, TypeVar
|
||||||
|
|
||||||
@@ -211,15 +212,13 @@ def create_agent( # noqa: PLR0915
|
|||||||
context_schema=context_schema,
|
context_schema=context_schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
|
def _handle_model_output(output: AIMessage) -> dict[str, Any]:
|
||||||
"""Handle model output including structured responses."""
|
"""Handle model output including structured responses."""
|
||||||
# Handle structured output with native strategy
|
# Handle structured output with native strategy
|
||||||
if isinstance(response_format, ProviderStrategy):
|
if isinstance(response_format, ProviderStrategy):
|
||||||
if not output.tool_calls and native_output_binding:
|
if not output.tool_calls and native_output_binding:
|
||||||
structured_response = native_output_binding.parse(output)
|
structured_response = native_output_binding.parse(output)
|
||||||
return {"messages": [output], "response": structured_response}
|
return {"messages": [output], "response": structured_response}
|
||||||
if state.get("response") is not None:
|
|
||||||
return {"messages": [output], "response": None}
|
|
||||||
return {"messages": [output]}
|
return {"messages": [output]}
|
||||||
|
|
||||||
# Handle structured output with tools strategy
|
# Handle structured output with tools strategy
|
||||||
@@ -297,9 +296,6 @@ def create_agent( # noqa: PLR0915
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Standard response handling
|
|
||||||
if state.get("response") is not None:
|
|
||||||
return {"messages": [output], "response": None}
|
|
||||||
return {"messages": [output]}
|
return {"messages": [output]}
|
||||||
|
|
||||||
def _get_bound_model(request: ModelRequest) -> Runnable:
|
def _get_bound_model(request: ModelRequest) -> Runnable:
|
||||||
@@ -346,7 +342,7 @@ def create_agent( # noqa: PLR0915
|
|||||||
messages = [SystemMessage(request.system_prompt), *messages]
|
messages = [SystemMessage(request.system_prompt), *messages]
|
||||||
|
|
||||||
output = model_.invoke(messages)
|
output = model_.invoke(messages)
|
||||||
return _handle_model_output(state, output)
|
return _handle_model_output(output)
|
||||||
|
|
||||||
async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
|
async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Async model request handler with sequential middleware processing."""
|
"""Async model request handler with sequential middleware processing."""
|
||||||
@@ -373,7 +369,7 @@ def create_agent( # noqa: PLR0915
|
|||||||
messages = [SystemMessage(request.system_prompt), *messages]
|
messages = [SystemMessage(request.system_prompt), *messages]
|
||||||
|
|
||||||
output = await model_.ainvoke(messages)
|
output = await model_.ainvoke(messages)
|
||||||
return _handle_model_output(state, output)
|
return _handle_model_output(output)
|
||||||
|
|
||||||
# Use sync or async based on model capabilities
|
# Use sync or async based on model capabilities
|
||||||
from langgraph._internal._runnable import RunnableCallable
|
from langgraph._internal._runnable import RunnableCallable
|
||||||
@@ -417,12 +413,12 @@ def create_agent( # noqa: PLR0915
|
|||||||
if tool_node is not None:
|
if tool_node is not None:
|
||||||
graph.add_conditional_edges(
|
graph.add_conditional_edges(
|
||||||
"tools",
|
"tools",
|
||||||
_make_tools_to_model_edge(tool_node, first_node),
|
_make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
|
||||||
[first_node, END],
|
[first_node, END],
|
||||||
)
|
)
|
||||||
graph.add_conditional_edges(
|
graph.add_conditional_edges(
|
||||||
last_node,
|
last_node,
|
||||||
_make_model_to_tools_edge(first_node, structured_output_tools),
|
_make_model_to_tools_edge(first_node, structured_output_tools, tool_node),
|
||||||
[first_node, "tools", END],
|
[first_node, "tools", END],
|
||||||
)
|
)
|
||||||
elif last_node == "model_request":
|
elif last_node == "model_request":
|
||||||
@@ -481,27 +477,48 @@ def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_last_ai_and_tool_messages(
    messages: list[AnyMessage],
) -> tuple[AIMessage, list[ToolMessage]]:
    """Return the last AIMessage and the ToolMessages that come after it.

    Args:
        messages: The message list to search; it is scanned from the end.

    Returns:
        A tuple of the last `AIMessage` found and a list of every
        `ToolMessage` positioned after it (possibly empty).

    Raises:
        ValueError: If `messages` contains no `AIMessage`.
    """
    for index in range(len(messages) - 1, -1, -1):
        candidate = messages[index]
        if isinstance(candidate, AIMessage):
            # Only ToolMessages after this AIMessage are collected.
            trailing_tool_messages = [
                m for m in messages[index + 1 :] if isinstance(m, ToolMessage)
            ]
            return candidate, trailing_tool_messages

    # Fail explicitly instead of hitting an unbound local variable below.
    msg = "Expected at least one AIMessage in messages, but found none."
    raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
def _make_model_to_tools_edge(
|
def _make_model_to_tools_edge(
|
||||||
first_node: str, structured_output_tools: dict[str, OutputToolBinding]
|
first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
|
||||||
) -> Callable[[AgentState], str | None]:
|
) -> Callable[[AgentState], str | list[Send] | None]:
|
||||||
def model_to_tools(state: AgentState) -> str | None:
|
def model_to_tools(state: AgentState) -> str | list[Send] | None:
|
||||||
if jump_to := state.get("jump_to"):
|
if jump_to := state.get("jump_to"):
|
||||||
return _resolve_jump(jump_to, first_node)
|
return _resolve_jump(jump_to, first_node)
|
||||||
|
|
||||||
message = state["messages"][-1]
|
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
||||||
|
tool_message_ids = [m.tool_call_id for m in tool_messages]
|
||||||
|
|
||||||
# Check if this is a ToolMessage from structured output - if so, end
|
pending_tool_calls = [
|
||||||
if isinstance(message, ToolMessage) and message.name in structured_output_tools:
|
c
|
||||||
return END
|
for c in last_ai_message.tool_calls
|
||||||
|
if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
|
||||||
# Check for tool calls
|
|
||||||
if isinstance(message, AIMessage) and message.tool_calls:
|
|
||||||
# If all tool calls are for structured output, don't go to tools
|
|
||||||
non_structured_calls = [
|
|
||||||
tc for tc in message.tool_calls if tc["name"] not in structured_output_tools
|
|
||||||
]
|
]
|
||||||
if non_structured_calls:
|
|
||||||
return "tools"
|
if pending_tool_calls:
|
||||||
|
# imo we should not be injecting state, store here,
|
||||||
|
# this should be done by the tool node itself ideally but this is a consequence
|
||||||
|
# of using Send w/ tool calls directly which allows more intuitive interrupt behavior
|
||||||
|
# largely internal so can be fixed later
|
||||||
|
pending_tool_calls = [
|
||||||
|
tool_node.inject_tool_args(call, state, None) # type: ignore[arg-type]
|
||||||
|
for call in pending_tool_calls
|
||||||
|
]
|
||||||
|
return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
|
||||||
|
|
||||||
return END
|
return END
|
||||||
|
|
||||||
@@ -509,17 +526,21 @@ def _make_model_to_tools_edge(
|
|||||||
|
|
||||||
|
|
||||||
def _make_tools_to_model_edge(
|
def _make_tools_to_model_edge(
|
||||||
tool_node: ToolNode, next_node: str
|
tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
|
||||||
) -> Callable[[AgentState], str | None]:
|
) -> Callable[[AgentState], str | None]:
|
||||||
def tools_to_model(state: AgentState) -> str | None:
|
def tools_to_model(state: AgentState) -> str | None:
|
||||||
ai_message = [m for m in state["messages"] if isinstance(m, AIMessage)][-1]
|
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
||||||
|
|
||||||
if all(
|
if all(
|
||||||
tool_node.tools_by_name[c["name"]].return_direct
|
tool_node.tools_by_name[c["name"]].return_direct
|
||||||
for c in ai_message.tool_calls
|
for c in last_ai_message.tool_calls
|
||||||
if c["name"] in tool_node.tools_by_name
|
if c["name"] in tool_node.tools_by_name
|
||||||
):
|
):
|
||||||
return END
|
return END
|
||||||
|
|
||||||
|
if any(t.name in structured_output_tools for t in tool_messages):
|
||||||
|
return END
|
||||||
|
|
||||||
return next_node
|
return next_node
|
||||||
|
|
||||||
return tools_to_model
|
return tools_to_model
|
||||||
|
@@ -4,6 +4,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
@@ -14,19 +15,23 @@ from langchain_core.messages import (
|
|||||||
ToolMessage,
|
ToolMessage,
|
||||||
)
|
)
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
from langchain.agents.middleware_agent import create_agent
|
from langchain.agents.middleware_agent import create_agent
|
||||||
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
|
from langchain.agents.middleware.human_in_the_loop import (
|
||||||
|
HumanInTheLoopMiddleware,
|
||||||
|
HumanInTheLoopConfig,
|
||||||
|
ActionRequest,
|
||||||
|
)
|
||||||
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
|
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
|
||||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||||
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState
|
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState
|
||||||
from langchain_core.tools import BaseTool
|
|
||||||
|
|
||||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
from langgraph.constants import END
|
from langgraph.constants import END
|
||||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||||
from langgraph.prebuilt.interrupt import ActionRequest, HumanInterruptConfig
|
from langchain.agents.structured_output import ToolStrategy
|
||||||
|
|
||||||
from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
|
from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
|
||||||
from .model import FakeToolCallingModel
|
from .model import FakeToolCallingModel
|
||||||
@@ -355,27 +360,28 @@ def test_create_agent_jump(
|
|||||||
# Tests for HumanInTheLoopMiddleware
|
# Tests for HumanInTheLoopMiddleware
|
||||||
def test_human_in_the_loop_middleware_initialization() -> None:
|
def test_human_in_the_loop_middleware_initialization() -> None:
|
||||||
"""Test HumanInTheLoopMiddleware initialization."""
|
"""Test HumanInTheLoopMiddleware initialization."""
|
||||||
|
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
tool_configs={
|
tool_configs={
|
||||||
"test_tool": HumanInterruptConfig(
|
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
|
||||||
allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True
|
},
|
||||||
|
description_prefix="Custom prefix",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert middleware.tool_configs == {
|
||||||
|
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
|
||||||
}
|
}
|
||||||
|
assert middleware.description_prefix == "Custom prefix"
|
||||||
middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs, message_prefix="Custom prefix")
|
|
||||||
|
|
||||||
assert middleware.tool_configs == tool_configs
|
|
||||||
assert middleware.message_prefix == "Custom prefix"
|
|
||||||
|
|
||||||
|
|
||||||
def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
|
def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
|
||||||
"""Test HumanInTheLoopMiddleware when no interrupts are needed."""
|
"""Test HumanInTheLoopMiddleware when no interrupts are needed."""
|
||||||
tool_configs = {
|
|
||||||
"test_tool": HumanInterruptConfig(
|
|
||||||
allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs)
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
tool_configs={
|
||||||
|
"test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Test with no messages
|
# Test with no messages
|
||||||
state: dict[str, Any] = {"messages": []}
|
state: dict[str, Any] = {"messages": []}
|
||||||
@@ -397,15 +403,14 @@ def test_human_in_the_loop_middleware_no_interrupts_needed() -> None:
|
|||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
def test_human_in_the_loop_middleware_interrupt_responses() -> None:
|
def test_human_in_the_loop_middleware_single_tool_accept() -> None:
|
||||||
"""Test HumanInTheLoopMiddleware with different interrupt response types."""
|
"""Test HumanInTheLoopMiddleware with single tool accept response."""
|
||||||
tool_configs = {
|
|
||||||
"test_tool": HumanInterruptConfig(
|
|
||||||
allow_ignore=True, allow_respond=True, allow_edit=True, allow_accept=True
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
middleware = HumanInTheLoopMiddleware(tool_configs=tool_configs)
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
tool_configs={
|
||||||
|
"test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
ai_message = AIMessage(
|
ai_message = AIMessage(
|
||||||
content="I'll help you",
|
content="I'll help you",
|
||||||
@@ -413,55 +418,443 @@ def test_human_in_the_loop_middleware_interrupt_responses() -> None:
|
|||||||
)
|
)
|
||||||
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||||
|
|
||||||
# Test accept response
|
|
||||||
def mock_accept(requests):
|
def mock_accept(requests):
|
||||||
return [{"type": "accept", "args": None}]
|
return [{"type": "accept", "args": None}]
|
||||||
|
|
||||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
|
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
|
||||||
result = middleware.after_model(state)
|
result = middleware.after_model(state)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
assert "messages" in result
|
||||||
|
assert len(result["messages"]) == 1
|
||||||
assert result["messages"][0] == ai_message
|
assert result["messages"][0] == ai_message
|
||||||
assert result["messages"][0].tool_calls == ai_message.tool_calls
|
assert result["messages"][0].tool_calls == ai_message.tool_calls
|
||||||
|
|
||||||
# Test edit response
|
|
||||||
|
def test_human_in_the_loop_middleware_single_tool_edit() -> None:
|
||||||
|
"""Test HumanInTheLoopMiddleware with single tool edit response."""
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
tool_configs={
|
||||||
|
"test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="I'll help you",
|
||||||
|
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||||
|
|
||||||
def mock_edit(requests):
|
def mock_edit(requests):
|
||||||
return [
|
return [
|
||||||
{"type": "edit", "args": ActionRequest(action="test_tool", args={"input": "edited"})}
|
{
|
||||||
|
"type": "edit",
|
||||||
|
"args": ActionRequest(
|
||||||
|
action="test_tool",
|
||||||
|
args={"input": "edited"},
|
||||||
|
),
|
||||||
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit):
|
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit):
|
||||||
result = middleware.after_model(state)
|
result = middleware.after_model(state)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
assert "messages" in result
|
||||||
|
assert len(result["messages"]) == 1
|
||||||
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
|
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
|
||||||
|
assert result["messages"][0].tool_calls[0]["id"] == "1" # ID should be preserved
|
||||||
|
|
||||||
# Test ignore response
|
|
||||||
def mock_ignore(requests):
|
|
||||||
return [{"type": "ignore", "args": None}]
|
|
||||||
|
|
||||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_ignore):
|
def test_human_in_the_loop_middleware_single_tool_response() -> None:
|
||||||
result = middleware.after_model(state)
|
"""Test HumanInTheLoopMiddleware with single tool response with custom message."""
|
||||||
assert result is not None
|
|
||||||
assert result["jump_to"] == "__end__"
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
tool_configs={
|
||||||
|
"test_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="I'll help you",
|
||||||
|
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||||
|
|
||||||
# Test response type
|
|
||||||
def mock_response(requests):
|
def mock_response(requests):
|
||||||
return [{"type": "response", "args": "Custom response"}]
|
return [{"type": "response", "args": "Custom response message"}]
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
|
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_response
|
||||||
):
|
):
|
||||||
result = middleware.after_model(state)
|
result = middleware.after_model(state)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["jump_to"] == "model"
|
assert "messages" in result
|
||||||
assert result["messages"][0]["role"] == "tool"
|
assert len(result["messages"]) == 1 # Only tool message when no approved tool calls
|
||||||
assert result["messages"][0]["content"] == "Custom response"
|
assert isinstance(result["messages"][0], ToolMessage)
|
||||||
|
assert result["messages"][0].content == "Custom response message"
|
||||||
|
assert result["messages"][0].name == "test_tool"
|
||||||
|
assert result["messages"][0].tool_call_id == "1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
|
||||||
|
"""Test HumanInTheLoopMiddleware with multiple tools and mixed response types."""
|
||||||
|
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
tool_configs={
|
||||||
|
"get_forecast": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
|
||||||
|
"get_temperature": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="I'll help you with weather",
|
||||||
|
tool_calls=[
|
||||||
|
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
|
||||||
|
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
|
||||||
|
|
||||||
|
def mock_mixed_responses(requests):
|
||||||
|
return [
|
||||||
|
{"type": "accept", "args": None},
|
||||||
|
{"type": "response", "args": "User rejected this tool call"},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_mixed_responses
|
||||||
|
):
|
||||||
|
result = middleware.after_model(state)
|
||||||
|
assert result is not None
|
||||||
|
assert "messages" in result
|
||||||
|
assert (
|
||||||
|
len(result["messages"]) == 2
|
||||||
|
) # AI message with accepted tool call + tool message for rejected
|
||||||
|
|
||||||
|
# First message should be the AI message with updated tool calls
|
||||||
|
updated_ai_message = result["messages"][0]
|
||||||
|
assert len(updated_ai_message.tool_calls) == 1 # Only accepted tool call
|
||||||
|
assert updated_ai_message.tool_calls[0]["name"] == "get_forecast" # Accepted
|
||||||
|
|
||||||
|
# Second message should be the tool message for the rejected tool call
|
||||||
|
tool_message = result["messages"][1]
|
||||||
|
assert isinstance(tool_message, ToolMessage)
|
||||||
|
assert tool_message.content == "User rejected this tool call"
|
||||||
|
assert tool_message.name == "get_temperature"
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_in_the_loop_middleware_multiple_tools_edit_responses() -> None:
|
||||||
|
"""Test HumanInTheLoopMiddleware with multiple tools and edit responses."""
|
||||||
|
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
tool_configs={
|
||||||
|
"get_forecast": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
|
||||||
|
"get_temperature": {"allow_accept": True, "allow_edit": True, "allow_respond": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="I'll help you with weather",
|
||||||
|
tool_calls=[
|
||||||
|
{"name": "get_forecast", "args": {"location": "San Francisco"}, "id": "1"},
|
||||||
|
{"name": "get_temperature", "args": {"location": "San Francisco"}, "id": "2"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="What's the weather?"), ai_message]}
|
||||||
|
|
||||||
|
def mock_edit_responses(requests):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "edit",
|
||||||
|
"args": ActionRequest(
|
||||||
|
action="get_forecast",
|
||||||
|
args={"location": "New York"},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "edit",
|
||||||
|
"args": ActionRequest(
|
||||||
|
action="get_temperature",
|
||||||
|
args={"location": "New York"},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_edit_responses
|
||||||
|
):
|
||||||
|
result = middleware.after_model(state)
|
||||||
|
assert result is not None
|
||||||
|
assert "messages" in result
|
||||||
|
assert len(result["messages"]) == 1
|
||||||
|
|
||||||
|
updated_ai_message = result["messages"][0]
|
||||||
|
assert updated_ai_message.tool_calls[0]["args"] == {"location": "New York"}
|
||||||
|
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
|
||||||
|
assert updated_ai_message.tool_calls[1]["args"] == {"location": "New York"}
|
||||||
|
assert updated_ai_message.tool_calls[1]["id"] == "2" # ID preserved
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_in_the_loop_middleware_edit_with_modified_args() -> None:
|
||||||
|
"""Test HumanInTheLoopMiddleware with edit action that includes modified args."""
|
||||||
|
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
tool_configs={
|
||||||
|
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="I'll help you",
|
||||||
|
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||||
|
|
||||||
|
def mock_edit_with_args(requests):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "edit",
|
||||||
|
"args": ActionRequest(
|
||||||
|
action="test_tool",
|
||||||
|
args={"input": "modified"},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||||
|
side_effect=mock_edit_with_args,
|
||||||
|
):
|
||||||
|
result = middleware.after_model(state)
|
||||||
|
assert result is not None
|
||||||
|
assert "messages" in result
|
||||||
|
assert len(result["messages"]) == 1
|
||||||
|
|
||||||
|
# Should have modified args
|
||||||
|
updated_ai_message = result["messages"][0]
|
||||||
|
assert updated_ai_message.tool_calls[0]["args"] == {"input": "modified"}
|
||||||
|
assert updated_ai_message.tool_calls[0]["id"] == "1" # ID preserved
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_in_the_loop_middleware_unknown_response_type() -> None:
|
||||||
|
"""Test HumanInTheLoopMiddleware with unknown response type."""
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
tool_configs={
|
||||||
|
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="I'll help you",
|
||||||
|
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||||
|
|
||||||
# Test unknown response type
|
|
||||||
def mock_unknown(requests):
|
def mock_unknown(requests):
|
||||||
return [{"type": "unknown", "args": None}]
|
return [{"type": "unknown", "args": None}]
|
||||||
|
|
||||||
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown):
|
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_unknown):
|
||||||
with pytest.raises(ValueError, match="Unknown response type: unknown"):
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r"Unexpected human response: {'type': 'unknown', 'args': None}. Response action 'unknown' is not allowed for tool 'test_tool'. Expected one of \['accept', 'edit', 'response'\] based on the tool's configuration.",
|
||||||
|
):
|
||||||
|
middleware.after_model(state)
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_in_the_loop_middleware_disallowed_action() -> None:
|
||||||
|
"""Test HumanInTheLoopMiddleware with action not allowed by tool config."""
|
||||||
|
|
||||||
|
# edit is not allowed by tool config
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
tool_configs={
|
||||||
|
"test_tool": {"allow_respond": True, "allow_edit": False, "allow_accept": True}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="I'll help you",
|
||||||
|
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||||
|
|
||||||
|
def mock_disallowed_action(requests):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "edit",
|
||||||
|
"args": ActionRequest(
|
||||||
|
action="test_tool",
|
||||||
|
args={"input": "modified"},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||||
|
side_effect=mock_disallowed_action,
|
||||||
|
):
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r"Unexpected human response: {'type': 'edit', 'args': {'action': 'test_tool', 'args': {'input': 'modified'}}}. Response action 'edit' is not allowed for tool 'test_tool'. Expected one of \['accept', 'response'\] based on the tool's configuration.",
|
||||||
|
):
|
||||||
|
middleware.after_model(state)
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_in_the_loop_middleware_mixed_auto_approved_and_interrupt() -> None:
|
||||||
|
"""Test HumanInTheLoopMiddleware with mix of auto-approved and interrupt tools."""
|
||||||
|
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
tool_configs={
|
||||||
|
"interrupt_tool": {"allow_respond": True, "allow_edit": True, "allow_accept": True}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="I'll help you",
|
||||||
|
tool_calls=[
|
||||||
|
{"name": "auto_tool", "args": {"input": "auto"}, "id": "1"},
|
||||||
|
{"name": "interrupt_tool", "args": {"input": "interrupt"}, "id": "2"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||||
|
|
||||||
|
def mock_accept(requests):
|
||||||
|
return [{"type": "accept", "args": None}]
|
||||||
|
|
||||||
|
with patch("langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_accept):
|
||||||
|
result = middleware.after_model(state)
|
||||||
|
assert result is not None
|
||||||
|
assert "messages" in result
|
||||||
|
assert len(result["messages"]) == 1
|
||||||
|
|
||||||
|
updated_ai_message = result["messages"][0]
|
||||||
|
# Should have both tools: auto-approved first, then interrupt tool
|
||||||
|
assert len(updated_ai_message.tool_calls) == 2
|
||||||
|
assert updated_ai_message.tool_calls[0]["name"] == "auto_tool"
|
||||||
|
assert updated_ai_message.tool_calls[1]["name"] == "interrupt_tool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_in_the_loop_middleware_interrupt_request_structure() -> None:
|
||||||
|
"""Test that interrupt requests are structured correctly."""
|
||||||
|
|
||||||
|
middleware = HumanInTheLoopMiddleware(
|
||||||
|
tool_configs={
|
||||||
|
"test_tool": {"allow_accept": True, "allow_edit": True, "allow_respond": True}
|
||||||
|
},
|
||||||
|
description_prefix="Custom prefix",
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="I'll help you",
|
||||||
|
tool_calls=[{"name": "test_tool", "args": {"input": "test", "location": "SF"}, "id": "1"}],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||||
|
|
||||||
|
captured_requests = []
|
||||||
|
|
||||||
|
def mock_capture_requests(requests):
|
||||||
|
captured_requests.extend(requests)
|
||||||
|
return [{"type": "accept", "args": None}]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt", side_effect=mock_capture_requests
|
||||||
|
):
|
||||||
|
middleware.after_model(state)
|
||||||
|
|
||||||
|
assert len(captured_requests) == 1
|
||||||
|
request = captured_requests[0]
|
||||||
|
|
||||||
|
assert "action_request" in request
|
||||||
|
assert "config" in request
|
||||||
|
assert "description" in request
|
||||||
|
|
||||||
|
assert request["action_request"]["action"] == "test_tool"
|
||||||
|
assert request["action_request"]["args"] == {"input": "test", "location": "SF"}
|
||||||
|
expected_config = {"allow_accept": True, "allow_edit": True, "allow_respond": True}
|
||||||
|
assert request["config"] == expected_config
|
||||||
|
assert "Custom prefix" in request["description"]
|
||||||
|
assert "Tool: test_tool" in request["description"]
|
||||||
|
assert "Args: {'input': 'test', 'location': 'SF'}" in request["description"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_in_the_loop_middleware_boolean_configs() -> None:
|
||||||
|
"""Test HITL middleware with boolean tool configs."""
|
||||||
|
middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True})
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="I'll help you",
|
||||||
|
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||||
|
|
||||||
|
# Test accept
|
||||||
|
with patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||||
|
return_value=[{"type": "accept", "args": None}],
|
||||||
|
):
|
||||||
|
result = middleware.after_model(state)
|
||||||
|
assert result is not None
|
||||||
|
assert "messages" in result
|
||||||
|
assert len(result["messages"]) == 1
|
||||||
|
assert result["messages"][0].tool_calls == ai_message.tool_calls
|
||||||
|
|
||||||
|
# Test edit
|
||||||
|
with patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||||
|
return_value=[
|
||||||
|
{
|
||||||
|
"type": "edit",
|
||||||
|
"args": ActionRequest(
|
||||||
|
action="test_tool",
|
||||||
|
args={"input": "edited"},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
):
|
||||||
|
result = middleware.after_model(state)
|
||||||
|
assert result is not None
|
||||||
|
assert "messages" in result
|
||||||
|
assert len(result["messages"]) == 1
|
||||||
|
assert result["messages"][0].tool_calls[0]["args"] == {"input": "edited"}
|
||||||
|
|
||||||
|
middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": False})
|
||||||
|
|
||||||
|
result = middleware.after_model(state)
|
||||||
|
# No interruption should occur
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_in_the_loop_middleware_sequence_mismatch() -> None:
|
||||||
|
"""Test that sequence mismatch in resume raises an error."""
|
||||||
|
middleware = HumanInTheLoopMiddleware(tool_configs={"test_tool": True})
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content="I'll help you",
|
||||||
|
tool_calls=[{"name": "test_tool", "args": {"input": "test"}, "id": "1"}],
|
||||||
|
)
|
||||||
|
state = {"messages": [HumanMessage(content="Hello"), ai_message]}
|
||||||
|
|
||||||
|
# Test with too few responses
|
||||||
|
with patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||||
|
return_value=[], # No responses for 1 tool call
|
||||||
|
):
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r"Number of human responses \(0\) does not match number of hanging tool calls \(1\)\.",
|
||||||
|
):
|
||||||
|
middleware.after_model(state)
|
||||||
|
|
||||||
|
# Test with too many responses
|
||||||
|
with patch(
|
||||||
|
"langchain.agents.middleware.human_in_the_loop.interrupt",
|
||||||
|
return_value=[
|
||||||
|
{"type": "accept", "args": None},
|
||||||
|
{"type": "accept", "args": None},
|
||||||
|
], # 2 responses for 1 tool call
|
||||||
|
):
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r"Number of human responses \(2\) does not match number of hanging tool calls \(1\)\.",
|
||||||
|
):
|
||||||
middleware.after_model(state)
|
middleware.after_model(state)
|
||||||
|
|
||||||
|
|
||||||
@@ -733,3 +1126,78 @@ def test_modify_model_request() -> None:
|
|||||||
assert (
|
assert (
|
||||||
result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!"
|
result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
|
||||||
|
"""Test that when there are both structured and regular tool calls, we execute regular and jump to END."""
|
||||||
|
|
||||||
|
class WeatherResponse(BaseModel):
|
||||||
|
"""Weather response."""
|
||||||
|
|
||||||
|
temperature: float = Field(description="Temperature in fahrenheit")
|
||||||
|
condition: str = Field(description="Weather condition")
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def regular_tool(query: str) -> str:
|
||||||
|
"""A regular tool that returns a string."""
|
||||||
|
return f"Regular tool result for: {query}"
|
||||||
|
|
||||||
|
# Create a fake model that returns both structured and regular tool calls
|
||||||
|
class FakeModelWithBothToolCalls(FakeToolCallingModel):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.tool_calls = [
|
||||||
|
[
|
||||||
|
ToolCall(
|
||||||
|
name="WeatherResponse",
|
||||||
|
args={"temperature": 72.0, "condition": "sunny"},
|
||||||
|
id="structured_call_1",
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
name="regular_tool", args={"query": "test query"}, id="regular_call_1"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create agent with both structured output and regular tools
|
||||||
|
agent = create_agent(
|
||||||
|
model=FakeModelWithBothToolCalls(),
|
||||||
|
tools=[regular_tool],
|
||||||
|
response_format=ToolStrategy(schema=WeatherResponse),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compile and invoke the agent
|
||||||
|
compiled_agent = agent.compile()
|
||||||
|
result = compiled_agent.invoke(
|
||||||
|
{"messages": [HumanMessage("What's the weather and help me with a query?")]}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that we have the expected messages:
|
||||||
|
# 1. Human message
|
||||||
|
# 2. AI message with both tool calls
|
||||||
|
# 3. Tool message from structured tool call
|
||||||
|
# 4. Tool message from regular tool call
|
||||||
|
|
||||||
|
messages = result["messages"]
|
||||||
|
assert len(messages) >= 4
|
||||||
|
|
||||||
|
# Check that we have the AI message with both tool calls
|
||||||
|
ai_message = messages[1]
|
||||||
|
assert isinstance(ai_message, AIMessage)
|
||||||
|
assert len(ai_message.tool_calls) == 2
|
||||||
|
|
||||||
|
# Check that we have a tool message from the regular tool
|
||||||
|
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
||||||
|
assert len(tool_messages) >= 1
|
||||||
|
|
||||||
|
# The regular tool should have been executed
|
||||||
|
regular_tool_message = next((m for m in tool_messages if m.name == "regular_tool"), None)
|
||||||
|
assert regular_tool_message is not None
|
||||||
|
assert "Regular tool result for: test query" in regular_tool_message.content
|
||||||
|
|
||||||
|
# Verify that the structured response is available in the result
|
||||||
|
assert "response" in result
|
||||||
|
assert result["response"] is not None
|
||||||
|
assert hasattr(result["response"], "temperature")
|
||||||
|
assert result["response"].temperature == 72.0
|
||||||
|
assert result["response"].condition == "sunny"
|
||||||
|
Reference in New Issue
Block a user