mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
Prevent dangling tool calls from breaking react agents
This commit is contained in:
@@ -2,7 +2,8 @@
|
||||
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
||||
from langchain_core.messages import AIMessage, RemoveMessage, ToolCall, ToolMessage
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
@@ -269,6 +270,42 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""Before the agent runs, handle dangling tool calls from the most recent AIMessage."""
|
||||
messages = state["messages"]
|
||||
if not messages or len(messages) == 0:
|
||||
return None
|
||||
|
||||
patched_messages = []
|
||||
# Iterate over the messages and add any dangling tool calls
|
||||
for i, msg in enumerate(messages):
|
||||
patched_messages.append(msg)
|
||||
if msg.type == "ai" and msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
corresponding_tool_msg = next(
|
||||
(
|
||||
msg
|
||||
for msg in messages[i:]
|
||||
if msg.type == "tool" and msg.tool_call_id == tool_call["id"]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if corresponding_tool_msg is None:
|
||||
# We have a dangling tool call which needs a ToolMessage
|
||||
tool_msg = (
|
||||
f"Tool call {tool_call['name']} with id {tool_call['id']} was "
|
||||
"cancelled - another message came in before it could be completed."
|
||||
)
|
||||
patched_messages.append(
|
||||
ToolMessage(
|
||||
content=tool_msg,
|
||||
name=tool_call["name"],
|
||||
tool_call_id=tool_call["id"],
|
||||
)
|
||||
)
|
||||
|
||||
return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *patched_messages]}
|
||||
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""Trigger interrupt flows for relevant tool calls after an AIMessage."""
|
||||
messages = state["messages"]
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
|
||||
from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
HumanMessage,
|
||||
AIMessage,
|
||||
ToolMessage,
|
||||
ToolCall,
|
||||
RemoveMessage,
|
||||
)
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
|
||||
class TestHumanInTheLoopMiddlewareBeforeModel:
|
||||
"""Test HumanInTheLoopMiddleware before_model behavior."""
|
||||
|
||||
def test_first_message(self) -> None:
|
||||
input_messages = [
|
||||
SystemMessage(content="You are a helpful assistant.", id="1"),
|
||||
HumanMessage(content="Hello, how are you?", id="2"),
|
||||
]
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={})
|
||||
state_update = middleware.before_agent({"messages": input_messages}, None)
|
||||
assert state_update is not None
|
||||
assert len(state_update["messages"]) == 3
|
||||
assert state_update["messages"][0].type == "remove"
|
||||
assert state_update["messages"][1].type == "system"
|
||||
assert state_update["messages"][1].content == "You are a helpful assistant."
|
||||
assert state_update["messages"][2].type == "human"
|
||||
assert state_update["messages"][2].content == "Hello, how are you?"
|
||||
assert state_update["messages"][2].id == "2"
|
||||
|
||||
def test_missing_tool_call(self) -> None:
|
||||
input_messages = [
|
||||
SystemMessage(content="You are a helpful assistant.", id="1"),
|
||||
HumanMessage(content="Hello, how are you?", id="2"),
|
||||
AIMessage(
|
||||
content="I'm doing well, thank you!",
|
||||
tool_calls=[
|
||||
ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"})
|
||||
],
|
||||
id="3",
|
||||
),
|
||||
HumanMessage(content="What is the weather in Tokyo?", id="4"),
|
||||
]
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={})
|
||||
state_update = middleware.before_agent({"messages": input_messages}, None)
|
||||
assert state_update is not None
|
||||
assert len(state_update["messages"]) == 6
|
||||
assert state_update["messages"][0].type == "remove"
|
||||
assert state_update["messages"][1] == input_messages[0]
|
||||
assert state_update["messages"][2] == input_messages[1]
|
||||
assert state_update["messages"][3] == input_messages[2]
|
||||
assert state_update["messages"][4].type == "tool"
|
||||
assert state_update["messages"][4].tool_call_id == "123"
|
||||
assert state_update["messages"][4].name == "get_events_for_days"
|
||||
assert state_update["messages"][5] == input_messages[3]
|
||||
updated_messages = add_messages(input_messages, state_update["messages"])
|
||||
assert len(updated_messages) == 5
|
||||
assert updated_messages[0] == input_messages[0]
|
||||
assert updated_messages[1] == input_messages[1]
|
||||
assert updated_messages[2] == input_messages[2]
|
||||
assert updated_messages[3].type == "tool"
|
||||
assert updated_messages[3].tool_call_id == "123"
|
||||
assert updated_messages[3].name == "get_events_for_days"
|
||||
assert updated_messages[4] == input_messages[3]
|
||||
|
||||
def test_no_missing_tool_calls(self) -> None:
|
||||
input_messages = [
|
||||
SystemMessage(content="You are a helpful assistant.", id="1"),
|
||||
HumanMessage(content="Hello, how are you?", id="2"),
|
||||
AIMessage(
|
||||
content="I'm doing well, thank you!",
|
||||
tool_calls=[
|
||||
ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"})
|
||||
],
|
||||
id="3",
|
||||
),
|
||||
ToolMessage(content="I have no events for that date.", tool_call_id="123", id="4"),
|
||||
HumanMessage(content="What is the weather in Tokyo?", id="5"),
|
||||
]
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={})
|
||||
state_update = middleware.before_agent({"messages": input_messages}, None)
|
||||
assert state_update is not None
|
||||
assert len(state_update["messages"]) == 6
|
||||
assert state_update["messages"][0].type == "remove"
|
||||
assert state_update["messages"][1:] == input_messages
|
||||
updated_messages = add_messages(input_messages, state_update["messages"])
|
||||
assert len(updated_messages) == 5
|
||||
assert updated_messages == input_messages
|
||||
|
||||
def test_two_missing_tool_calls(self) -> None:
|
||||
input_messages = [
|
||||
SystemMessage(content="You are a helpful assistant.", id="1"),
|
||||
HumanMessage(content="Hello, how are you?", id="2"),
|
||||
AIMessage(
|
||||
content="I'm doing well, thank you!",
|
||||
tool_calls=[
|
||||
ToolCall(id="123", name="get_events_for_days", args={"date_str": "2025-01-01"})
|
||||
],
|
||||
id="3",
|
||||
),
|
||||
HumanMessage(content="What is the weather in Tokyo?", id="4"),
|
||||
AIMessage(
|
||||
content="I'm doing well, thank you!",
|
||||
tool_calls=[
|
||||
ToolCall(id="456", name="get_events_for_days", args={"date_str": "2025-01-01"})
|
||||
],
|
||||
id="5",
|
||||
),
|
||||
HumanMessage(content="What is the weather in Tokyo?", id="6"),
|
||||
]
|
||||
middleware = HumanInTheLoopMiddleware(interrupt_on={})
|
||||
state_update = middleware.before_agent({"messages": input_messages}, None)
|
||||
assert state_update is not None
|
||||
assert len(state_update["messages"]) == 9
|
||||
assert state_update["messages"][0].type == "remove"
|
||||
assert state_update["messages"][1] == input_messages[0]
|
||||
assert state_update["messages"][2] == input_messages[1]
|
||||
assert state_update["messages"][3] == input_messages[2]
|
||||
assert state_update["messages"][4].type == "tool"
|
||||
assert state_update["messages"][4].tool_call_id == "123"
|
||||
assert state_update["messages"][4].name == "get_events_for_days"
|
||||
assert state_update["messages"][5] == input_messages[3]
|
||||
assert state_update["messages"][6] == input_messages[4]
|
||||
assert state_update["messages"][7].type == "tool"
|
||||
assert state_update["messages"][7].tool_call_id == "456"
|
||||
assert state_update["messages"][7].name == "get_events_for_days"
|
||||
assert state_update["messages"][8] == input_messages[5]
|
||||
updated_messages = add_messages(input_messages, state_update["messages"])
|
||||
assert len(updated_messages) == 8
|
||||
assert updated_messages[0] == input_messages[0]
|
||||
assert updated_messages[1] == input_messages[1]
|
||||
assert updated_messages[2] == input_messages[2]
|
||||
assert updated_messages[3].type == "tool"
|
||||
assert updated_messages[3].tool_call_id == "123"
|
||||
assert updated_messages[3].name == "get_events_for_days"
|
||||
assert updated_messages[4] == input_messages[3]
|
||||
assert updated_messages[5] == input_messages[4]
|
||||
assert updated_messages[6].type == "tool"
|
||||
assert updated_messages[6].tool_call_id == "456"
|
||||
assert updated_messages[6].name == "get_events_for_days"
|
||||
assert updated_messages[7] == input_messages[5]
|
||||
Reference in New Issue
Block a user