core[patch]: support drawing nested subgraphs in draw_mermaid (#25581)

Previously the code was able to only handle a single level of nesting
for subgraphs in mermaid. This change adds support for arbitrary nesting
of subgraphs.
This commit is contained in:
Vadym Barda
2024-08-22 19:08:49 -04:00
committed by GitHub
parent 1c31234eed
commit 46d344c33d
3 changed files with 258 additions and 40 deletions

View File

@@ -1063,6 +1063,63 @@
'''
# ---
# name: test_parallel_subgraph_mermaid[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
__start__([__start__]):::first
outer_1(outer_1)
inner_1_inner_1(inner_1)
inner_1_inner_2(inner_2<hr/><small><em>__interrupt = before</em></small>)
inner_2_inner_1(inner_1)
inner_2_inner_2(inner_2)
outer_2(outer_2)
__end__([__end__]):::last
__start__ --> outer_1;
inner_1_inner_2 --> outer_2;
inner_2_inner_2 --> outer_2;
outer_1 --> inner_1_inner_1;
outer_1 --> inner_2_inner_1;
outer_2 --> __end__;
subgraph inner_1
inner_1_inner_1 --> inner_1_inner_2;
end
subgraph inner_2
inner_2_inner_1 --> inner_2_inner_2;
end
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_double_nested_subgraph_mermaid[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
__start__([__start__]):::first
parent_1(parent_1)
child_child_1_grandchild_1(grandchild_1)
child_child_1_grandchild_2(grandchild_2<hr/><small><em>__interrupt = before</em></small>)
child_child_2(child_2)
parent_2(parent_2)
__end__([__end__]):::last
__start__ --> parent_1;
child_child_2 --> parent_2;
parent_1 --> child_child_1_grandchild_1;
parent_2 --> __end__;
subgraph child
child_child_1_grandchild_2 --> child_child_2;
subgraph child_1
child_child_1_grandchild_1 --> child_child_1_grandchild_2;
end
end
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_graph_single_runnable[ascii]
'''
+----------------------+

View File

@@ -9,7 +9,7 @@ from langchain_core.output_parsers.xml import XMLOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableConfig
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.graph import Edge, Graph, Node
from langchain_core.runnables.graph_mermaid import _escape_node_label
from tests.unit_tests.pydantic_utils import _schema
@@ -216,6 +216,136 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid-simple")
def test_parallel_subgraph_mermaid(snapshot: SnapshotAssertion) -> None:
empty_data = BaseModel
nodes = {
"__start__": Node(
id="__start__", name="__start__", data=empty_data, metadata=None
),
"outer_1": Node(id="outer_1", name="outer_1", data=empty_data, metadata=None),
"inner_1:inner_1": Node(
id="inner_1:inner_1", name="inner_1", data=empty_data, metadata=None
),
"inner_1:inner_2": Node(
id="inner_1:inner_2",
name="inner_2",
data=empty_data,
metadata={"__interrupt": "before"},
),
"inner_2:inner_1": Node(
id="inner_2:inner_1", name="inner_1", data=empty_data, metadata=None
),
"inner_2:inner_2": Node(
id="inner_2:inner_2", name="inner_2", data=empty_data, metadata=None
),
"outer_2": Node(id="outer_2", name="outer_2", data=empty_data, metadata=None),
"__end__": Node(id="__end__", name="__end__", data=empty_data, metadata=None),
}
edges = [
Edge(
source="inner_1:inner_1",
target="inner_1:inner_2",
data=None,
conditional=False,
),
Edge(
source="inner_2:inner_1",
target="inner_2:inner_2",
data=None,
conditional=False,
),
Edge(source="__start__", target="outer_1", data=None, conditional=False),
Edge(
source="inner_1:inner_2",
target="outer_2",
data=None,
conditional=False,
),
Edge(
source="inner_2:inner_2",
target="outer_2",
data=None,
conditional=False,
),
Edge(
source="outer_1",
target="inner_1:inner_1",
data=None,
conditional=False,
),
Edge(
source="outer_1",
target="inner_2:inner_1",
data=None,
conditional=False,
),
Edge(source="outer_2", target="__end__", data=None, conditional=False),
]
graph = Graph(nodes, edges)
assert graph.draw_mermaid() == snapshot(name="mermaid")
def test_double_nested_subgraph_mermaid(snapshot: SnapshotAssertion) -> None:
empty_data = BaseModel
nodes = {
"__start__": Node(
id="__start__", name="__start__", data=empty_data, metadata=None
),
"parent_1": Node(
id="parent_1", name="parent_1", data=empty_data, metadata=None
),
"child:child_1:grandchild_1": Node(
id="child:child_1:grandchild_1",
name="grandchild_1",
data=empty_data,
metadata=None,
),
"child:child_1:grandchild_2": Node(
id="child:child_1:grandchild_2",
name="grandchild_2",
data=empty_data,
metadata={"__interrupt": "before"},
),
"child:child_2": Node(
id="child:child_2", name="child_2", data=empty_data, metadata=None
),
"parent_2": Node(
id="parent_2", name="parent_2", data=empty_data, metadata=None
),
"__end__": Node(id="__end__", name="__end__", data=empty_data, metadata=None),
}
edges = [
Edge(
source="child:child_1:grandchild_1",
target="child:child_1:grandchild_2",
data=None,
conditional=False,
),
Edge(
source="child:child_1:grandchild_2",
target="child:child_2",
data=None,
conditional=False,
),
Edge(source="__start__", target="parent_1", data=None, conditional=False),
Edge(
source="child:child_2",
target="parent_2",
data=None,
conditional=False,
),
Edge(
source="parent_1",
target="child:child_1:grandchild_1",
data=None,
conditional=False,
),
Edge(source="parent_2", target="__end__", data=None, conditional=False),
]
graph = Graph(nodes, edges)
assert graph.draw_mermaid() == snapshot(name="mermaid")
def test_runnable_get_graph_with_invalid_input_type() -> None:
"""Test that error isn't raised when getting graph with invalid input type."""