fix(langchain): conditional edge from tools to end (#33520)

This commit is contained in:
Sydney Runkle
2025-10-16 11:56:45 -04:00
committed by GitHub
parent c9018f81ec
commit e10d99b728
4 changed files with 152 additions and 8 deletions

View File

@@ -1225,6 +1225,15 @@ def create_agent( # noqa: PLR0915
graph.add_edge(START, entry_node)
# add conditional edges only if tools exist
if tool_node is not None:
# Only include exit_node in destinations if any tool has return_direct=True
# or if there are structured output tools
tools_to_model_destinations = [loop_entry_node]
if (
any(tool.return_direct for tool in tool_node.tools_by_name.values())
or structured_output_tools
):
tools_to_model_destinations.append(exit_node)
graph.add_conditional_edges(
"tools",
_make_tools_to_model_edge(
@@ -1233,7 +1242,7 @@ def create_agent( # noqa: PLR0915
structured_output_tools=structured_output_tools,
end_destination=exit_node,
),
[loop_entry_node, exit_node],
tools_to_model_destinations,
)
# base destinations are tools and exit_node

View File

@@ -20,7 +20,6 @@
__start__ --> NoopZero\2ebefore_agent;
model -.-> NoopTwo\2eafter_agent;
model -.-> tools;
tools -.-> NoopTwo\2eafter_agent;
tools -.-> model;
NoopOne\2eafter_agent --> __end__;
classDef default fill:#f2f0ff,line-height:1.2
@@ -343,7 +342,6 @@
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
tools -.-> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
@@ -376,7 +374,6 @@
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
tools -.-> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
@@ -409,7 +406,6 @@
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
tools -.-> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
@@ -442,7 +438,6 @@
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
tools -.-> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
@@ -475,7 +470,6 @@
__start__ --> NoopSeven\2ebefore_model;
model --> NoopEight\2eafter_model;
tools -.-> NoopSeven\2ebefore_model;
tools -.-> __end__;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
@@ -497,7 +491,6 @@
__start__ --> model;
model -.-> __end__;
model -.-> tools;
tools -.-> __end__;
tools -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0

View File

@@ -0,0 +1,69 @@
# serializer version: 1
# name: test_agent_graph_with_mixed_tools
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model -.-> __end__;
model -.-> tools;
tools -.-> __end__;
tools -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_agent_graph_with_return_direct_tool
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model -.-> __end__;
model -.-> tools;
tools -.-> __end__;
tools -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_agent_graph_without_return_direct_tools
'''
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
model(model)
tools(tools)
__end__([<p>__end__</p>]):::last
__start__ --> model;
model -.-> __end__;
model -.-> tools;
tools -.-> model;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---

View File

@@ -0,0 +1,73 @@
"""Tests for return_direct tool graph structure."""
from langchain_core.tools import tool
from syrupy.assertion import SnapshotAssertion
from langchain.agents.factory import create_agent
from .model import FakeToolCallingModel
def test_agent_graph_without_return_direct_tools(snapshot: SnapshotAssertion) -> None:
"""Test that graph WITHOUT return_direct tools does NOT have edge from tools to end."""
@tool
def normal_tool(input_string: str) -> str:
"""A normal tool without return_direct."""
return input_string
agent = create_agent(
model=FakeToolCallingModel(),
tools=[normal_tool],
system_prompt="You are a helpful assistant.",
)
# The mermaid diagram should NOT include an edge from tools to __end__
# when no tools have return_direct=True
mermaid_diagram = agent.get_graph().draw_mermaid()
assert mermaid_diagram == snapshot
def test_agent_graph_with_return_direct_tool(snapshot: SnapshotAssertion) -> None:
"""Test that graph WITH return_direct tools has correct edge from tools to end."""
@tool(return_direct=True)
def return_direct_tool(input_string: str) -> str:
"""A tool with return_direct=True."""
return input_string
agent = create_agent(
model=FakeToolCallingModel(),
tools=[return_direct_tool],
system_prompt="You are a helpful assistant.",
)
# The mermaid diagram SHOULD include an edge from tools to __end__
# when at least one tool has return_direct=True
mermaid_diagram = agent.get_graph().draw_mermaid()
assert mermaid_diagram == snapshot
def test_agent_graph_with_mixed_tools(snapshot: SnapshotAssertion) -> None:
"""Test that graph with mixed tools (some return_direct, some not) has correct edges."""
@tool(return_direct=True)
def return_direct_tool(input_string: str) -> str:
"""A tool with return_direct=True."""
return input_string
@tool
def normal_tool(input_string: str) -> str:
"""A normal tool without return_direct."""
return input_string
agent = create_agent(
model=FakeToolCallingModel(),
tools=[return_direct_tool, normal_tool],
system_prompt="You are a helpful assistant.",
)
# The mermaid diagram SHOULD include an edge from tools to __end__
# because at least one tool has return_direct=True
mermaid_diagram = agent.get_graph().draw_mermaid()
assert mermaid_diagram == snapshot