diff --git a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py
index c35fdff884d..7cf13d42c4c 100644
--- a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py
+++ b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py
@@ -138,16 +138,15 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         if not messages:
             return None
 
-        last_message = messages[-1]
-
-        if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
+        last_ai_msg = next((msg for msg in reversed(messages) if isinstance(msg, AIMessage)), None)
+        if not last_ai_msg or not last_ai_msg.tool_calls:
             return None
 
         # Separate tool calls that need interrupts from those that don't
         interrupt_tool_calls: dict[str, ToolCall] = {}
         auto_approved_tool_calls = []
 
-        for tool_call in last_message.tool_calls:
+        for tool_call in last_ai_msg.tool_calls:
             tool_name = tool_call["name"]
             if tool_name in self.tool_configs:
                 # fix: id should not be typed as Optional on `langchain_core.messages.tool.ToolCall`
@@ -259,9 +258,9 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
                 )
                 raise ValueError(msg)
 
-        last_message.tool_calls = [*approved_tool_calls, *rejected_tool_calls]
+        last_ai_msg.tool_calls = [*approved_tool_calls, *rejected_tool_calls]
 
         if len(approved_tool_calls) > 0:
-            return {"messages": [last_message, *artificial_tool_messages]}
+            return {"messages": [last_ai_msg, *artificial_tool_messages]}
 
         return {"jump_to": "model", "messages": artificial_tool_messages}
diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py
index 495375c1d72..1437dbce36e 100644
--- a/libs/langchain_v1/langchain/agents/middleware_agent.py
+++ b/libs/langchain_v1/langchain/agents/middleware_agent.py
@@ -5,7 +5,7 @@ from collections.abc import Callable, Sequence
 from typing import Any, cast
 
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
+from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langgraph.constants import END, START
@@ -219,8 +219,6 @@ def create_agent(  # noqa: PLR0915
             if not output.tool_calls and native_output_binding:
                 structured_response = native_output_binding.parse(output)
                 return {"messages": [output], "response": structured_response}
-            if state.get("response") is not None:
-                return {"messages": [output], "response": None}
             return {"messages": [output]}
 
         # Handle structured output with tools strategy
@@ -418,7 +416,7 @@ def create_agent(  # noqa: PLR0915
     if tool_node is not None:
         graph.add_conditional_edges(
             "tools",
-            _make_tools_to_model_edge(tool_node, first_node),
+            _make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
             [first_node, END],
         )
         graph.add_conditional_edges(
@@ -482,6 +480,22 @@ def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
     return None
 
 
+def _fetch_last_ai_and_tool_messages(
+    messages: list[AnyMessage],
+) -> tuple[AIMessage, list[ToolMessage]]:
+    last_ai_index: int
+    last_ai_message: AIMessage
+
+    for i in range(len(messages) - 1, -1, -1):
+        if isinstance(messages[i], AIMessage):
+            last_ai_index = i
+            last_ai_message = cast("AIMessage", messages[i])
+            break
+
+    tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
+    return last_ai_message, tool_messages
+
+
 def _make_model_to_tools_edge(
     first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
 ) -> Callable[[AgentState], str | list[Send] | None]:
@@ -489,31 +503,7 @@ def _make_model_to_tools_edge(
         if jump_to := state.get("jump_to"):
             return _resolve_jump(jump_to, first_node)
 
-        last_message = state["messages"][-1]
-
-        # Check if this is a ToolMessage from structured output - if so, end
-        # interesting, should we be auto ending here? should we execute other tools?
-        if isinstance(last_message, ToolMessage) and last_message.name in structured_output_tools:
-            return END
-
-        # Find the last AI message and all tool messages since said AI message
-        last_ai_index = None
-        last_ai_message: AIMessage
-        for i in range(len(state["messages"]) - 1, -1, -1):
-            if isinstance(state["messages"][i], AIMessage):
-                last_ai_index = i
-                last_ai_message = cast("AIMessage", state["messages"][i])
-                break
-
-        tool_messages = (
-            [
-                m.tool_call_id
-                for m in state["messages"][last_ai_index + 1 :]
-                if isinstance(m, ToolMessage)
-            ]
-            if last_ai_index is not None
-            else []
-        )
+        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
 
         pending_tool_calls = [
             c
@@ -538,17 +528,21 @@ def _make_model_to_tools_edge(
 
 
 def _make_tools_to_model_edge(
-    tool_node: ToolNode, next_node: str
+    tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
 ) -> Callable[[AgentState], str | None]:
     def tools_to_model(state: AgentState) -> str | None:
-        ai_message = [m for m in state["messages"] if isinstance(m, AIMessage)][-1]
+        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
+
         if all(
             tool_node.tools_by_name[c["name"]].return_direct
-            for c in ai_message.tool_calls
+            for c in last_ai_message.tool_calls
             if c["name"] in tool_node.tools_by_name
         ):
             return END
 
+        if any(t.name in structured_output_tools for t in tool_messages):
+            return END
+
         return next_node
 
     return tools_to_model
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py
index 297a3587192..b6ee655fdde 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py
@@ -4,6 +4,7 @@ from unittest.mock import patch
 
 from syrupy.assertion import SnapshotAssertion
 
+from pydantic import BaseModel, Field
 from langchain_core.language_models import BaseChatModel
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
@@ -14,6 +15,7 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.tools import tool
+from langgraph.types import Command
 
 from langchain.agents.middleware_agent import create_agent
 from langchain.agents.middleware.human_in_the_loop import (
@@ -28,7 +30,7 @@ from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.checkpoint.memory import InMemorySaver
 from langgraph.constants import END
 from langgraph.graph.message import REMOVE_ALL_MESSAGES
-from langgraph.prebuilt.interrupt import ActionRequest
+from langchain.agents.structured_output import ToolStrategy
 
 from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
 from .model import FakeToolCallingModel
@@ -1275,3 +1277,78 @@ def test_modify_model_request() -> None:
     assert (
         result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!"
     )
+
+
+def test_tools_to_model_edge_with_structured_and_regular_tool_calls() -> None:
+    """Test that with both structured and regular tool calls, regular tools run and we jump to END."""
+
+    class WeatherResponse(BaseModel):
+        """Weather response."""
+
+        temperature: float = Field(description="Temperature in fahrenheit")
+        condition: str = Field(description="Weather condition")
+
+    @tool
+    def regular_tool(query: str) -> str:
+        """A regular tool that returns a string."""
+        return f"Regular tool result for: {query}"
+
+    # Create a fake model that returns both structured and regular tool calls
+    class FakeModelWithBothToolCalls(FakeToolCallingModel):
+        def __init__(self) -> None:
+            super().__init__()
+            self.tool_calls = [
+                [
+                    ToolCall(
+                        name="WeatherResponse",
+                        args={"temperature": 72.0, "condition": "sunny"},
+                        id="structured_call_1",
+                    ),
+                    ToolCall(
+                        name="regular_tool", args={"query": "test query"}, id="regular_call_1"
+                    ),
+                ]
+            ]
+
+    # Create agent with both structured output and regular tools
+    agent = create_agent(
+        model=FakeModelWithBothToolCalls(),
+        tools=[regular_tool],
+        response_format=ToolStrategy(schema=WeatherResponse),
+    )
+
+    # Compile and invoke the agent
+    compiled_agent = agent.compile()
+    result = compiled_agent.invoke(
+        {"messages": [HumanMessage("What's the weather and help me with a query?")]}
+    )
+
+    # Verify that we have the expected messages:
+    # 1. Human message
+    # 2. AI message with both tool calls
+    # 3. Tool message from structured tool call
+    # 4. Tool message from regular tool call
+
+    messages = result["messages"]
+    assert len(messages) >= 4
+
+    # Check that we have the AI message with both tool calls
+    ai_message = messages[1]
+    assert isinstance(ai_message, AIMessage)
+    assert len(ai_message.tool_calls) == 2
+
+    # Check that we have a tool message from the regular tool
+    tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
+    assert len(tool_messages) >= 1
+
+    # The regular tool should have been executed
+    regular_tool_message = next((m for m in tool_messages if m.name == "regular_tool"), None)
+    assert regular_tool_message is not None
+    assert "Regular tool result for: test query" in regular_tool_message.content
+
+    # Verify that the structured response is available in the result
+    assert "response" in result
+    assert result["response"] is not None
+    assert hasattr(result["response"], "temperature")
+    assert result["response"].temperature == 72.0
+    assert result["response"].condition == "sunny"
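
To make the behavior change easier to see outside the diff context: previously the tools-to-model edge only inspected `state["messages"][-1]`, so when a regular tool's ToolMessage landed after the structured-output ToolMessage, the run looped back to the model and the structured response was lost. Below is a minimal, self-contained sketch of the new routing rule. The `AIMessage`/`ToolMessage` dataclasses are stand-ins, not the real `langchain_core` types, and the function names only mirror the diff; nothing here is the actual module API.

```python
# Sketch only: stand-in message types; the control flow mirrors
# _fetch_last_ai_and_tool_messages and the updated tools_to_model edge.
from dataclasses import dataclass, field


@dataclass
class AIMessage:
    tool_calls: list[dict] = field(default_factory=list)


@dataclass
class ToolMessage:
    name: str


def fetch_last_ai_and_tool_messages(messages):
    # Walk backwards to the most recent AIMessage (assumed to exist, as it
    # does whenever the tools node has just run), then collect every
    # ToolMessage that follows it.
    for i in range(len(messages) - 1, -1, -1):
        if isinstance(messages[i], AIMessage):
            last_ai_index = i
            break
    tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
    return messages[last_ai_index], tool_messages


def tools_to_model(messages, structured_output_tools):
    # New rule in this PR: if any structured-output tool has already produced
    # a ToolMessage, end the run instead of looping back to the model --
    # regardless of what other ToolMessages arrived after it.
    _, tool_messages = fetch_last_ai_and_tool_messages(messages)
    if any(t.name in structured_output_tools for t in tool_messages):
        return "END"
    return "model"


# Both a structured and a regular tool ran; the edge now ends the run,
# so the parsed structured response is returned to the caller.
history = [
    AIMessage(tool_calls=[{"name": "WeatherResponse"}, {"name": "regular_tool"}]),
    ToolMessage(name="WeatherResponse"),
    ToolMessage(name="regular_tool"),
]
assert tools_to_model(history, {"WeatherResponse"}) == "END"
```

Scanning all ToolMessages since the last AIMessage, rather than just the final message, is what makes the ordering of structured and regular tool results irrelevant.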