From c599ba47d5fa9df90b5cd74deb6bc0aca1ff11e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Panella?= Date: Tue, 4 Mar 2025 12:27:49 -0600 Subject: [PATCH] core(mermaid): fix error when 3+ subgraph levels (#29970) --- .../langchain_core/runnables/graph_mermaid.py | 3 + .../runnables/__snapshots__/test_graph.ambr | 31 ++++++++ .../tests/unit_tests/runnables/test_graph.py | 73 +++++++++++++++++++ 3 files changed, 107 insertions(+) diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index a331b1b0582..4693558cc0d 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -145,6 +145,9 @@ def draw_mermaid( for nested_prefix in edge_groups: if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix: continue + # only go to first level subgraphs + if ":" in nested_prefix[len(prefix) + 1 :]: + continue add_subgraph(edge_groups[nested_prefix], nested_prefix) if prefix and not self_loop: diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index 5a3b7126d99..4b4568571b3 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -26,6 +26,37 @@ ''' # --- +# name: test_triple_nested_subgraph_mermaid[mermaid] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([

__start__

]):::first + parent_1(parent_1) + child_child_1_grandchild_1(grandchild_1) + child_child_1_grandchild_1_greatgrandchild(greatgrandchild) + child_child_1_grandchild_2(grandchild_2
__interrupt = before) + child_child_2(child_2) + parent_2(parent_2) + __end__([

__end__

]):::last + __start__ --> parent_1; + child_child_2 --> parent_2; + parent_1 --> child_child_1_grandchild_1; + parent_2 --> __end__; + subgraph child + child_child_1_grandchild_2 --> child_child_2; + subgraph child_1 + child_child_1_grandchild_1_greatgrandchild --> child_child_1_grandchild_2; + subgraph grandchild_1 + child_child_1_grandchild_1 --> child_child_1_grandchild_1_greatgrandchild; + end + end + end + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_graph_mermaid_duplicate_nodes[mermaid] ''' graph TD; diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 58e5749a0d5..c2f7ef9b7dc 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -375,6 +375,79 @@ def test_double_nested_subgraph_mermaid(snapshot: SnapshotAssertion) -> None: assert graph.draw_mermaid() == snapshot(name="mermaid") +def test_triple_nested_subgraph_mermaid(snapshot: SnapshotAssertion) -> None: + empty_data = BaseModel + nodes = { + "__start__": Node( + id="__start__", name="__start__", data=empty_data, metadata=None + ), + "parent_1": Node( + id="parent_1", name="parent_1", data=empty_data, metadata=None + ), + "child:child_1:grandchild_1": Node( + id="child:child_1:grandchild_1", + name="grandchild_1", + data=empty_data, + metadata=None, + ), + "child:child_1:grandchild_1:greatgrandchild": Node( + id="child:child_1:grandchild_1:greatgrandchild", + name="greatgrandchild", + data=empty_data, + metadata=None, + ), + "child:child_1:grandchild_2": Node( + id="child:child_1:grandchild_2", + name="grandchild_2", + data=empty_data, + metadata={"__interrupt": "before"}, + ), + "child:child_2": Node( + id="child:child_2", name="child_2", data=empty_data, metadata=None + ), + "parent_2": Node( + id="parent_2", name="parent_2", data=empty_data, metadata=None + ), + "__end__": Node(id="__end__", name="__end__", data=empty_data, metadata=None), + } + edges = [ + Edge( + source="child:child_1:grandchild_1", + target="child:child_1:grandchild_1:greatgrandchild", + data=None, + conditional=False, + ), + Edge( + source="child:child_1:grandchild_1:greatgrandchild", + target="child:child_1:grandchild_2", + data=None, + conditional=False, + ), + Edge( + source="child:child_1:grandchild_2", + target="child:child_2", + data=None, + conditional=False, + ), + Edge(source="__start__", target="parent_1", data=None, conditional=False), + Edge( + source="child:child_2", + target="parent_2", + data=None, + conditional=False, + ), + Edge( + source="parent_1", + target="child:child_1:grandchild_1", + data=None, + conditional=False, + ), + Edge(source="parent_2", target="__end__", data=None, conditional=False), + ] + graph = Graph(nodes, edges) + assert graph.draw_mermaid() == snapshot(name="mermaid") + + def test_runnable_get_graph_with_invalid_input_type() -> None: """Test that error isn't raised when getting graph with invalid input type."""