From 46d344c33d5d04f09849ed20b45fc5e1f5975596 Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Thu, 22 Aug 2024 19:08:49 -0400 Subject: [PATCH] core[patch]: support drawing nested subgraphs in draw_mermaid (#25581) Previously the code was able to only handle a single level of nesting for subgraphs in mermaid. This change adds support for arbitrary nesting of subgraphs. --- .../langchain_core/runnables/graph_mermaid.py | 109 +++++++++------ .../runnables/__snapshots__/test_graph.ambr | 57 ++++++++ .../tests/unit_tests/runnables/test_graph.py | 132 +++++++++++++++++- 3 files changed, 258 insertions(+), 40 deletions(-) diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index cef9092cebb..f8c184fd53a 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -78,47 +78,78 @@ def draw_mermaid( ) mermaid_graph += f"\t{node_label}\n" - subgraph = "" - # Add edges to the graph + # Group edges by their common prefixes + edge_groups: Dict[str, List[Edge]] = {} for edge in edges: - src_prefix = edge.source.split(":")[0] if ":" in edge.source else None - tgt_prefix = edge.target.split(":")[0] if ":" in edge.target else None - # exit subgraph if source or target is not in the same subgraph - if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix): - mermaid_graph += "\tend\n" - subgraph = "" - # enter subgraph if source and target are in the same subgraph - if not subgraph and src_prefix and src_prefix == tgt_prefix: - mermaid_graph += f"\tsubgraph {src_prefix}\n" - subgraph = src_prefix - - source, target = edge.source, edge.target - - # Add BR every wrap_label_n_words words - if edge.data is not None: - edge_data = edge.data - words = str(edge_data).split() # Split the string into words - # Group words into chunks of wrap_label_n_words size - if len(words) > wrap_label_n_words: - edge_data = " 
 ".join( - " ".join(words[i : i + wrap_label_n_words]) - for i in range(0, len(words), wrap_label_n_words) - ) - if edge.conditional: - edge_label = f" -.  {edge_data}  .-> " - else: - edge_label = f" --  {edge_data}  --> " - else: - if edge.conditional: - edge_label = " -.-> " - else: - edge_label = " --> " - mermaid_graph += ( - f"\t{_escape_node_label(source)}{edge_label}" - f"{_escape_node_label(target)};\n" + src_parts = edge.source.split(":") + tgt_parts = edge.target.split(":") + common_prefix = ":".join( + src for src, tgt in zip(src_parts, tgt_parts) if src == tgt ) - if subgraph: - mermaid_graph += "end\n" + edge_groups.setdefault(common_prefix, []).append(edge) + + seen_subgraphs = set() + + def add_subgraph(edges: List[Edge], prefix: str) -> None: + nonlocal mermaid_graph + self_loop = len(edges) == 1 and edges[0].source == edges[0].target + if prefix and not self_loop: + subgraph = prefix.split(":")[-1] + if subgraph in seen_subgraphs: + raise ValueError( + f"Found duplicate subgraph '{subgraph}' -- this likely means that " + "you're reusing a subgraph node with the same name. " + "Please adjust your graph to have subgraph nodes with unique names." + ) + + seen_subgraphs.add(subgraph) + mermaid_graph += f"\tsubgraph {subgraph}\n" + + for edge in edges: + source, target = edge.source, edge.target + + # Add BR every wrap_label_n_words words + if edge.data is not None: + edge_data = edge.data + words = str(edge_data).split() # Split the string into words + # Group words into chunks of wrap_label_n_words size + if len(words) > wrap_label_n_words: + edge_data = " 
 ".join( + " ".join(words[i : i + wrap_label_n_words]) + for i in range(0, len(words), wrap_label_n_words) + ) + if edge.conditional: + edge_label = f" -.  {edge_data}  .-> " + else: + edge_label = f" --  {edge_data}  --> " + else: + if edge.conditional: + edge_label = " -.-> " + else: + edge_label = " --> " + + mermaid_graph += ( + f"\t{_escape_node_label(source)}{edge_label}" + f"{_escape_node_label(target)};\n" + ) + + # Recursively add nested subgraphs + for nested_prefix in edge_groups.keys(): + if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix: + continue + add_subgraph(edge_groups[nested_prefix], nested_prefix) + + if prefix and not self_loop: + mermaid_graph += "\tend\n" + + # Start with the top-level edges (no common prefix) + add_subgraph(edge_groups.get("", []), "") + + # Add remaining subgraphs + for prefix in edge_groups.keys(): + if ":" in prefix or prefix == "": + continue + add_subgraph(edge_groups[prefix], prefix) # Add custom styles for nodes if with_styles: diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index 73b19f29e75..d1f1b20a49c 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -1063,6 +1063,63 @@ ''' # --- +# name: test_parallel_subgraph_mermaid[mermaid] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([__start__]):::first + outer_1(outer_1) + inner_1_inner_1(inner_1) + inner_1_inner_2(inner_2
__interrupt = before) + inner_2_inner_1(inner_1) + inner_2_inner_2(inner_2) + outer_2(outer_2) + __end__([__end__]):::last + __start__ --> outer_1; + inner_1_inner_2 --> outer_2; + inner_2_inner_2 --> outer_2; + outer_1 --> inner_1_inner_1; + outer_1 --> inner_2_inner_1; + outer_2 --> __end__; + subgraph inner_1 + inner_1_inner_1 --> inner_1_inner_2; + end + subgraph inner_2 + inner_2_inner_1 --> inner_2_inner_2; + end + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- +# name: test_double_nested_subgraph_mermaid[mermaid] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([__start__]):::first + parent_1(parent_1) + child_child_1_grandchild_1(grandchild_1) + child_child_1_grandchild_2(grandchild_2
__interrupt = before) + child_child_2(child_2) + parent_2(parent_2) + __end__([__end__]):::last + __start__ --> parent_1; + child_child_2 --> parent_2; + parent_1 --> child_child_1_grandchild_1; + parent_2 --> __end__; + subgraph child + child_child_1_grandchild_2 --> child_child_2; + subgraph child_1 + child_child_1_grandchild_1 --> child_child_1_grandchild_2; + end + end + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_graph_single_runnable[ascii] ''' +----------------------+ diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 3d18a7f8f94..3e909040fc2 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -9,7 +9,7 @@ from langchain_core.output_parsers.xml import XMLOutputParser from langchain_core.prompts.prompt import PromptTemplate from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables.base import Runnable, RunnableConfig -from langchain_core.runnables.graph import Graph +from langchain_core.runnables.graph import Edge, Graph, Node from langchain_core.runnables.graph_mermaid import _escape_node_label from tests.unit_tests.pydantic_utils import _schema @@ -216,6 +216,136 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid-simple") +def test_parallel_subgraph_mermaid(snapshot: SnapshotAssertion) -> None: + empty_data = BaseModel + nodes = { + "__start__": Node( + id="__start__", name="__start__", data=empty_data, metadata=None + ), + "outer_1": Node(id="outer_1", name="outer_1", data=empty_data, metadata=None), + "inner_1:inner_1": Node( + id="inner_1:inner_1", name="inner_1", data=empty_data, metadata=None + ), + "inner_1:inner_2": Node( + id="inner_1:inner_2", + name="inner_2", + data=empty_data, + metadata={"__interrupt": "before"}, + ), + "inner_2:inner_1": Node( + id="inner_2:inner_1", name="inner_1", data=empty_data, metadata=None + ), + "inner_2:inner_2": Node( + id="inner_2:inner_2", name="inner_2", data=empty_data, metadata=None + ), + "outer_2": Node(id="outer_2", name="outer_2", data=empty_data, metadata=None), + "__end__": Node(id="__end__", name="__end__", data=empty_data, metadata=None), + } + edges = [ + Edge( + source="inner_1:inner_1", + target="inner_1:inner_2", + data=None, + conditional=False, + ), + Edge( + source="inner_2:inner_1", + target="inner_2:inner_2", + data=None, + conditional=False, + ), + Edge(source="__start__", target="outer_1", data=None, conditional=False), + Edge( + source="inner_1:inner_2", + target="outer_2", + data=None, + conditional=False, + ), + Edge( + source="inner_2:inner_2", + target="outer_2", + data=None, + conditional=False, + ), + Edge( + source="outer_1", + target="inner_1:inner_1", + data=None, + conditional=False, + ), + Edge( + source="outer_1", + target="inner_2:inner_1", + data=None, + conditional=False, + ), + Edge(source="outer_2", target="__end__", data=None, conditional=False), + ] + graph = Graph(nodes, edges) + assert graph.draw_mermaid() == snapshot(name="mermaid") + + +def test_double_nested_subgraph_mermaid(snapshot: SnapshotAssertion) -> None: + empty_data = BaseModel + nodes = { + "__start__": Node( + id="__start__", name="__start__", data=empty_data, metadata=None + ), + "parent_1": Node( + id="parent_1", name="parent_1", data=empty_data, metadata=None + ), + "child:child_1:grandchild_1": Node( + id="child:child_1:grandchild_1", + name="grandchild_1", + data=empty_data, + metadata=None, + ), + "child:child_1:grandchild_2": Node( + id="child:child_1:grandchild_2", + name="grandchild_2", + data=empty_data, + metadata={"__interrupt": "before"}, + ), + "child:child_2": Node( + id="child:child_2", name="child_2", data=empty_data, metadata=None + ), + "parent_2": Node( + id="parent_2", name="parent_2", data=empty_data, metadata=None + ), + "__end__": Node(id="__end__", name="__end__", data=empty_data, metadata=None), + } + edges = [ + Edge( + source="child:child_1:grandchild_1", + target="child:child_1:grandchild_2", + data=None, + conditional=False, + ), + Edge( + source="child:child_1:grandchild_2", + target="child:child_2", + data=None, + conditional=False, + ), + Edge(source="__start__", target="parent_1", data=None, conditional=False), + Edge( + source="child:child_2", + target="parent_2", + data=None, + conditional=False, + ), + Edge( + source="parent_1", + target="child:child_1:grandchild_1", + data=None, + conditional=False, + ), + Edge(source="parent_2", target="__end__", data=None, conditional=False), + ] + graph = Graph(nodes, edges) + assert graph.draw_mermaid() == snapshot(name="mermaid") + + def test_runnable_get_graph_with_invalid_input_type() -> None: """Test that error isn't raised when getting graph with invalid input type."""