mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-23 20:23:59 +00:00
fix(langchain): conditional edge from tools to end (#33520)
This commit is contained in:
@@ -1225,6 +1225,15 @@ def create_agent( # noqa: PLR0915
|
||||
graph.add_edge(START, entry_node)
|
||||
# add conditional edges only if tools exist
|
||||
if tool_node is not None:
|
||||
# Only include exit_node in destinations if any tool has return_direct=True
|
||||
# or if there are structured output tools
|
||||
tools_to_model_destinations = [loop_entry_node]
|
||||
if (
|
||||
any(tool.return_direct for tool in tool_node.tools_by_name.values())
|
||||
or structured_output_tools
|
||||
):
|
||||
tools_to_model_destinations.append(exit_node)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
"tools",
|
||||
_make_tools_to_model_edge(
|
||||
@@ -1233,7 +1242,7 @@ def create_agent( # noqa: PLR0915
|
||||
structured_output_tools=structured_output_tools,
|
||||
end_destination=exit_node,
|
||||
),
|
||||
[loop_entry_node, exit_node],
|
||||
tools_to_model_destinations,
|
||||
)
|
||||
|
||||
# base destinations are tools and exit_node
|
||||
|
||||
@@ -20,7 +20,6 @@
|
||||
__start__ --> NoopZero\2ebefore_agent;
|
||||
model -.-> NoopTwo\2eafter_agent;
|
||||
model -.-> tools;
|
||||
tools -.-> NoopTwo\2eafter_agent;
|
||||
tools -.-> model;
|
||||
NoopOne\2eafter_agent --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
@@ -343,7 +342,6 @@
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
tools -.-> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
@@ -376,7 +374,6 @@
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
tools -.-> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
@@ -409,7 +406,6 @@
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
tools -.-> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
@@ -442,7 +438,6 @@
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
tools -.-> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
@@ -475,7 +470,6 @@
|
||||
__start__ --> NoopSeven\2ebefore_model;
|
||||
model --> NoopEight\2eafter_model;
|
||||
tools -.-> NoopSeven\2ebefore_model;
|
||||
tools -.-> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
@@ -497,7 +491,6 @@
|
||||
__start__ --> model;
|
||||
model -.-> __end__;
|
||||
model -.-> tools;
|
||||
tools -.-> __end__;
|
||||
tools -.-> model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
# serializer version: 1
|
||||
# name: test_agent_graph_with_mixed_tools
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
model -.-> __end__;
|
||||
model -.-> tools;
|
||||
tools -.-> __end__;
|
||||
tools -.-> model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_agent_graph_with_return_direct_tool
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
model -.-> __end__;
|
||||
model -.-> tools;
|
||||
tools -.-> __end__;
|
||||
tools -.-> model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_agent_graph_without_return_direct_tools
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
model -.-> __end__;
|
||||
model -.-> tools;
|
||||
tools -.-> model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Tests for return_direct tool graph structure."""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from langchain.agents.factory import create_agent
|
||||
|
||||
from .model import FakeToolCallingModel
|
||||
|
||||
|
||||
def test_agent_graph_without_return_direct_tools(snapshot: SnapshotAssertion) -> None:
|
||||
"""Test that graph WITHOUT return_direct tools does NOT have edge from tools to end."""
|
||||
|
||||
@tool
|
||||
def normal_tool(input_string: str) -> str:
|
||||
"""A normal tool without return_direct."""
|
||||
return input_string
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[normal_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
# The mermaid diagram should NOT include an edge from tools to __end__
|
||||
# when no tools have return_direct=True
|
||||
mermaid_diagram = agent.get_graph().draw_mermaid()
|
||||
assert mermaid_diagram == snapshot
|
||||
|
||||
|
||||
def test_agent_graph_with_return_direct_tool(snapshot: SnapshotAssertion) -> None:
|
||||
"""Test that graph WITH return_direct tools has correct edge from tools to end."""
|
||||
|
||||
@tool(return_direct=True)
|
||||
def return_direct_tool(input_string: str) -> str:
|
||||
"""A tool with return_direct=True."""
|
||||
return input_string
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[return_direct_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
# The mermaid diagram SHOULD include an edge from tools to __end__
|
||||
# when at least one tool has return_direct=True
|
||||
mermaid_diagram = agent.get_graph().draw_mermaid()
|
||||
assert mermaid_diagram == snapshot
|
||||
|
||||
|
||||
def test_agent_graph_with_mixed_tools(snapshot: SnapshotAssertion) -> None:
|
||||
"""Test that graph with mixed tools (some return_direct, some not) has correct edges."""
|
||||
|
||||
@tool(return_direct=True)
|
||||
def return_direct_tool(input_string: str) -> str:
|
||||
"""A tool with return_direct=True."""
|
||||
return input_string
|
||||
|
||||
@tool
|
||||
def normal_tool(input_string: str) -> str:
|
||||
"""A normal tool without return_direct."""
|
||||
return input_string
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[return_direct_tool, normal_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
# The mermaid diagram SHOULD include an edge from tools to __end__
|
||||
# because at least one tool has return_direct=True
|
||||
mermaid_diagram = agent.get_graph().draw_mermaid()
|
||||
assert mermaid_diagram == snapshot
|
||||
Reference in New Issue
Block a user