From c7842730efbeddab76f0d04fb76c658258f7c61e Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Tue, 11 Mar 2025 18:55:45 -0400 Subject: [PATCH] core[patch]: support single-node subgraphs and put subgraph nodes under the respective subgraphs (#30234) --- .../langchain_core/runnables/graph_mermaid.py | 88 +++++++++++++------ .../runnables/__snapshots__/test_graph.ambr | 39 +++++--- .../tests/unit_tests/runnables/test_graph.py | 17 ++++ 3 files changed, 105 insertions(+), 39 deletions(-) diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 4693558cc0d..af4f806c1de 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -56,37 +56,50 @@ def draw_mermaid( if with_styles else "graph TD;\n" ) + # Group nodes by subgraph + subgraph_nodes: dict[str, dict[str, Node]] = {} + regular_nodes: dict[str, Node] = {} - if with_styles: - # Node formatting templates - default_class_label = "default" - format_dict = {default_class_label: "{0}({1})"} - if first_node is not None: - format_dict[first_node] = "{0}([{1}]):::first" - if last_node is not None: - format_dict[last_node] = "{0}([{1}]):::last" + for key, node in nodes.items(): + if ":" in key: + # For nodes with colons, add them only to their deepest subgraph level + prefix = ":".join(key.split(":")[:-1]) + subgraph_nodes.setdefault(prefix, {})[key] = node + else: + regular_nodes[key] = node - # Add nodes to the graph - for key, node in nodes.items(): - node_name = node.name.split(":")[-1] + # Node formatting templates + default_class_label = "default" + format_dict = {default_class_label: "{0}({1})"} + if first_node is not None: + format_dict[first_node] = "{0}([{1}]):::first" + if last_node is not None: + format_dict[last_node] = "{0}([{1}]):::last" + + def render_node(key: str, node: Node, indent: str = "\t") -> str: + """Helper function to render a node with consistent formatting.""" + node_name = node.name.split(":")[-1] + label = ( + f"

{node_name}

" + if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS)) + and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS)) + else node_name + ) + if node.metadata: label = ( - f"

{node_name}

" - if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS)) - and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS)) - else node_name + f"{label}
" + + "\n".join(f"{k} = {value}" for k, value in node.metadata.items()) + + "" ) - if node.metadata: - label = ( - f"{label}
" - + "\n".join( - f"{key} = {value}" for key, value in node.metadata.items() - ) - + "" - ) - node_label = format_dict.get(key, format_dict[default_class_label]).format( - _escape_node_label(key), label - ) - mermaid_graph += f"\t{node_label}\n" + node_label = format_dict.get(key, format_dict[default_class_label]).format( + _escape_node_label(key), label + ) + return f"{indent}{node_label}\n" + + # Add non-subgraph nodes to the graph + if with_styles: + for key, node in regular_nodes.items(): + mermaid_graph += render_node(key, node) # Group edges by their common prefixes edge_groups: dict[str, list[Edge]] = {} @@ -116,6 +129,11 @@ def draw_mermaid( seen_subgraphs.add(subgraph) mermaid_graph += f"\tsubgraph {subgraph}\n" + # Add nodes that belong to this subgraph + if with_styles and prefix in subgraph_nodes: + for key, node in subgraph_nodes[prefix].items(): + mermaid_graph += render_node(key, node) + for edge in edges: source, target = edge.source, edge.target @@ -156,11 +174,25 @@ def draw_mermaid( # Start with the top-level edges (no common prefix) add_subgraph(edge_groups.get("", []), "") - # Add remaining subgraphs + # Add remaining subgraphs with edges for prefix in edge_groups: if ":" in prefix or prefix == "": continue add_subgraph(edge_groups[prefix], prefix) + seen_subgraphs.add(prefix) + + # Add empty subgraphs (subgraphs with no internal edges) + if with_styles: + for prefix in subgraph_nodes: + if ":" not in prefix and prefix not in seen_subgraphs: + mermaid_graph += f"\tsubgraph {prefix}\n" + + # Add nodes that belong to this subgraph + for key, node in subgraph_nodes[prefix].items(): + mermaid_graph += render_node(key, node) + + mermaid_graph += "\tend\n" + seen_subgraphs.add(prefix) # Add custom styles for nodes if with_styles: diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index 4b4568571b3..2e4a19ce5c2 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -5,9 +5,6 @@ graph TD; __start__([

__start__

]):::first parent_1(parent_1) - child_child_1_grandchild_1(grandchild_1) - child_child_1_grandchild_2(grandchild_2
__interrupt = before) - child_child_2(child_2) parent_2(parent_2) __end__([

__end__

]):::last __start__ --> parent_1; @@ -15,8 +12,11 @@ parent_1 --> child_child_1_grandchild_1; parent_2 --> __end__; subgraph child + child_child_2(child_2) child_child_1_grandchild_2 --> child_child_2; subgraph child_1 + child_child_1_grandchild_1(grandchild_1) + child_child_1_grandchild_2(grandchild_2
__interrupt = before) child_child_1_grandchild_1 --> child_child_1_grandchild_2; end end @@ -32,10 +32,6 @@ graph TD; __start__([

__start__

]):::first parent_1(parent_1) - child_child_1_grandchild_1(grandchild_1) - child_child_1_grandchild_1_greatgrandchild(greatgrandchild) - child_child_1_grandchild_2(grandchild_2
__interrupt = before) - child_child_2(child_2) parent_2(parent_2) __end__([

__end__

]):::last __start__ --> parent_1; @@ -43,10 +39,14 @@ parent_1 --> child_child_1_grandchild_1; parent_2 --> __end__; subgraph child + child_child_2(child_2) child_child_1_grandchild_2 --> child_child_2; subgraph child_1 + child_child_1_grandchild_1(grandchild_1) + child_child_1_grandchild_2(grandchild_2
__interrupt = before) child_child_1_grandchild_1_greatgrandchild --> child_child_1_grandchild_2; subgraph grandchild_1 + child_child_1_grandchild_1_greatgrandchild(greatgrandchild) child_child_1_grandchild_1 --> child_child_1_grandchild_1_greatgrandchild; end end @@ -1996,10 +1996,6 @@ graph TD; __start__([

__start__

]):::first outer_1(outer_1) - inner_1_inner_1(inner_1) - inner_1_inner_2(inner_2
__interrupt = before) - inner_2_inner_1(inner_1) - inner_2_inner_2(inner_2) outer_2(outer_2) __end__([

__end__

]):::last __start__ --> outer_1; @@ -2009,9 +2005,13 @@ outer_1 --> inner_2_inner_1; outer_2 --> __end__; subgraph inner_1 + inner_1_inner_1(inner_1) + inner_1_inner_2(inner_2
__interrupt = before) inner_1_inner_1 --> inner_1_inner_2; end subgraph inner_2 + inner_2_inner_1(inner_1) + inner_2_inner_2(inner_2) inner_2_inner_1 --> inner_2_inner_2; end classDef default fill:#f2f0ff,line-height:1.2 @@ -2020,6 +2020,23 @@ ''' # --- +# name: test_single_node_subgraph_mermaid[mermaid] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([

__start__

]):::first + __end__([

__end__

]):::last + __start__ --> sub_meow; + sub_meow --> __end__; + subgraph sub + sub_meow(meow) + end + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_trim dict({ 'edges': list([ diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index c2f7ef9b7dc..6f822c1e7c2 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -448,6 +448,23 @@ def test_triple_nested_subgraph_mermaid(snapshot: SnapshotAssertion) -> None: assert graph.draw_mermaid() == snapshot(name="mermaid") +def test_single_node_subgraph_mermaid(snapshot: SnapshotAssertion) -> None: + empty_data = BaseModel + nodes = { + "__start__": Node( + id="__start__", name="__start__", data=empty_data, metadata=None + ), + "sub:meow": Node(id="sub:meow", name="meow", data=empty_data, metadata=None), + "__end__": Node(id="__end__", name="__end__", data=empty_data, metadata=None), + } + edges = [ + Edge(source="__start__", target="sub:meow", data=None, conditional=False), + Edge(source="sub:meow", target="__end__", data=None, conditional=False), + ] + graph = Graph(nodes, edges) + assert graph.draw_mermaid() == snapshot(name="mermaid") + + def test_runnable_get_graph_with_invalid_input_type() -> None: """Test that error isn't raised when getting graph with invalid input type."""