mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
core[patch]: support single-node subgraphs and put subgraph nodes under the respective subgraphs (#30234)
This commit is contained in:
parent
81d1653a30
commit
c7842730ef
@ -56,37 +56,50 @@ def draw_mermaid(
|
||||
if with_styles
|
||||
else "graph TD;\n"
|
||||
)
|
||||
# Group nodes by subgraph
|
||||
subgraph_nodes: dict[str, dict[str, Node]] = {}
|
||||
regular_nodes: dict[str, Node] = {}
|
||||
|
||||
if with_styles:
|
||||
# Node formatting templates
|
||||
default_class_label = "default"
|
||||
format_dict = {default_class_label: "{0}({1})"}
|
||||
if first_node is not None:
|
||||
format_dict[first_node] = "{0}([{1}]):::first"
|
||||
if last_node is not None:
|
||||
format_dict[last_node] = "{0}([{1}]):::last"
|
||||
for key, node in nodes.items():
|
||||
if ":" in key:
|
||||
# For nodes with colons, add them only to their deepest subgraph level
|
||||
prefix = ":".join(key.split(":")[:-1])
|
||||
subgraph_nodes.setdefault(prefix, {})[key] = node
|
||||
else:
|
||||
regular_nodes[key] = node
|
||||
|
||||
# Add nodes to the graph
|
||||
for key, node in nodes.items():
|
||||
node_name = node.name.split(":")[-1]
|
||||
# Node formatting templates
|
||||
default_class_label = "default"
|
||||
format_dict = {default_class_label: "{0}({1})"}
|
||||
if first_node is not None:
|
||||
format_dict[first_node] = "{0}([{1}]):::first"
|
||||
if last_node is not None:
|
||||
format_dict[last_node] = "{0}([{1}]):::last"
|
||||
|
||||
def render_node(key: str, node: Node, indent: str = "\t") -> str:
|
||||
"""Helper function to render a node with consistent formatting."""
|
||||
node_name = node.name.split(":")[-1]
|
||||
label = (
|
||||
f"<p>{node_name}</p>"
|
||||
if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS))
|
||||
and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS))
|
||||
else node_name
|
||||
)
|
||||
if node.metadata:
|
||||
label = (
|
||||
f"<p>{node_name}</p>"
|
||||
if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS))
|
||||
and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS))
|
||||
else node_name
|
||||
f"{label}<hr/><small><em>"
|
||||
+ "\n".join(f"{k} = {value}" for k, value in node.metadata.items())
|
||||
+ "</em></small>"
|
||||
)
|
||||
if node.metadata:
|
||||
label = (
|
||||
f"{label}<hr/><small><em>"
|
||||
+ "\n".join(
|
||||
f"{key} = {value}" for key, value in node.metadata.items()
|
||||
)
|
||||
+ "</em></small>"
|
||||
)
|
||||
node_label = format_dict.get(key, format_dict[default_class_label]).format(
|
||||
_escape_node_label(key), label
|
||||
)
|
||||
mermaid_graph += f"\t{node_label}\n"
|
||||
node_label = format_dict.get(key, format_dict[default_class_label]).format(
|
||||
_escape_node_label(key), label
|
||||
)
|
||||
return f"{indent}{node_label}\n"
|
||||
|
||||
# Add non-subgraph nodes to the graph
|
||||
if with_styles:
|
||||
for key, node in regular_nodes.items():
|
||||
mermaid_graph += render_node(key, node)
|
||||
|
||||
# Group edges by their common prefixes
|
||||
edge_groups: dict[str, list[Edge]] = {}
|
||||
@ -116,6 +129,11 @@ def draw_mermaid(
|
||||
seen_subgraphs.add(subgraph)
|
||||
mermaid_graph += f"\tsubgraph {subgraph}\n"
|
||||
|
||||
# Add nodes that belong to this subgraph
|
||||
if with_styles and prefix in subgraph_nodes:
|
||||
for key, node in subgraph_nodes[prefix].items():
|
||||
mermaid_graph += render_node(key, node)
|
||||
|
||||
for edge in edges:
|
||||
source, target = edge.source, edge.target
|
||||
|
||||
@ -156,11 +174,25 @@ def draw_mermaid(
|
||||
# Start with the top-level edges (no common prefix)
|
||||
add_subgraph(edge_groups.get("", []), "")
|
||||
|
||||
# Add remaining subgraphs
|
||||
# Add remaining subgraphs with edges
|
||||
for prefix in edge_groups:
|
||||
if ":" in prefix or prefix == "":
|
||||
continue
|
||||
add_subgraph(edge_groups[prefix], prefix)
|
||||
seen_subgraphs.add(prefix)
|
||||
|
||||
# Add empty subgraphs (subgraphs with no internal edges)
|
||||
if with_styles:
|
||||
for prefix in subgraph_nodes:
|
||||
if ":" not in prefix and prefix not in seen_subgraphs:
|
||||
mermaid_graph += f"\tsubgraph {prefix}\n"
|
||||
|
||||
# Add nodes that belong to this subgraph
|
||||
for key, node in subgraph_nodes[prefix].items():
|
||||
mermaid_graph += render_node(key, node)
|
||||
|
||||
mermaid_graph += "\tend\n"
|
||||
seen_subgraphs.add(prefix)
|
||||
|
||||
# Add custom styles for nodes
|
||||
if with_styles:
|
||||
|
@ -5,9 +5,6 @@
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
parent_1(parent_1)
|
||||
child_child_1_grandchild_1(grandchild_1)
|
||||
child_child_1_grandchild_2(grandchild_2<hr/><small><em>__interrupt = before</em></small>)
|
||||
child_child_2(child_2)
|
||||
parent_2(parent_2)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> parent_1;
|
||||
@ -15,8 +12,11 @@
|
||||
parent_1 --> child_child_1_grandchild_1;
|
||||
parent_2 --> __end__;
|
||||
subgraph child
|
||||
child_child_2(child_2)
|
||||
child_child_1_grandchild_2 --> child_child_2;
|
||||
subgraph child_1
|
||||
child_child_1_grandchild_1(grandchild_1)
|
||||
child_child_1_grandchild_2(grandchild_2<hr/><small><em>__interrupt = before</em></small>)
|
||||
child_child_1_grandchild_1 --> child_child_1_grandchild_2;
|
||||
end
|
||||
end
|
||||
@ -32,10 +32,6 @@
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
parent_1(parent_1)
|
||||
child_child_1_grandchild_1(grandchild_1)
|
||||
child_child_1_grandchild_1_greatgrandchild(greatgrandchild)
|
||||
child_child_1_grandchild_2(grandchild_2<hr/><small><em>__interrupt = before</em></small>)
|
||||
child_child_2(child_2)
|
||||
parent_2(parent_2)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> parent_1;
|
||||
@ -43,10 +39,14 @@
|
||||
parent_1 --> child_child_1_grandchild_1;
|
||||
parent_2 --> __end__;
|
||||
subgraph child
|
||||
child_child_2(child_2)
|
||||
child_child_1_grandchild_2 --> child_child_2;
|
||||
subgraph child_1
|
||||
child_child_1_grandchild_1(grandchild_1)
|
||||
child_child_1_grandchild_2(grandchild_2<hr/><small><em>__interrupt = before</em></small>)
|
||||
child_child_1_grandchild_1_greatgrandchild --> child_child_1_grandchild_2;
|
||||
subgraph grandchild_1
|
||||
child_child_1_grandchild_1_greatgrandchild(greatgrandchild)
|
||||
child_child_1_grandchild_1 --> child_child_1_grandchild_1_greatgrandchild;
|
||||
end
|
||||
end
|
||||
@ -1996,10 +1996,6 @@
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
outer_1(outer_1)
|
||||
inner_1_inner_1(inner_1)
|
||||
inner_1_inner_2(inner_2<hr/><small><em>__interrupt = before</em></small>)
|
||||
inner_2_inner_1(inner_1)
|
||||
inner_2_inner_2(inner_2)
|
||||
outer_2(outer_2)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> outer_1;
|
||||
@ -2009,9 +2005,13 @@
|
||||
outer_1 --> inner_2_inner_1;
|
||||
outer_2 --> __end__;
|
||||
subgraph inner_1
|
||||
inner_1_inner_1(inner_1)
|
||||
inner_1_inner_2(inner_2<hr/><small><em>__interrupt = before</em></small>)
|
||||
inner_1_inner_1 --> inner_1_inner_2;
|
||||
end
|
||||
subgraph inner_2
|
||||
inner_2_inner_1(inner_1)
|
||||
inner_2_inner_2(inner_2)
|
||||
inner_2_inner_1 --> inner_2_inner_2;
|
||||
end
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
@ -2020,6 +2020,23 @@
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_single_node_subgraph_mermaid[mermaid]
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> sub_meow;
|
||||
sub_meow --> __end__;
|
||||
subgraph sub
|
||||
sub_meow(meow)
|
||||
end
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_trim
|
||||
dict({
|
||||
'edges': list([
|
||||
|
@ -448,6 +448,23 @@ def test_triple_nested_subgraph_mermaid(snapshot: SnapshotAssertion) -> None:
|
||||
assert graph.draw_mermaid() == snapshot(name="mermaid")
|
||||
|
||||
|
||||
def test_single_node_subgraph_mermaid(snapshot: SnapshotAssertion) -> None:
|
||||
empty_data = BaseModel
|
||||
nodes = {
|
||||
"__start__": Node(
|
||||
id="__start__", name="__start__", data=empty_data, metadata=None
|
||||
),
|
||||
"sub:meow": Node(id="sub:meow", name="meow", data=empty_data, metadata=None),
|
||||
"__end__": Node(id="__end__", name="__end__", data=empty_data, metadata=None),
|
||||
}
|
||||
edges = [
|
||||
Edge(source="__start__", target="sub:meow", data=None, conditional=False),
|
||||
Edge(source="sub:meow", target="__end__", data=None, conditional=False),
|
||||
]
|
||||
graph = Graph(nodes, edges)
|
||||
assert graph.draw_mermaid() == snapshot(name="mermaid")
|
||||
|
||||
|
||||
def test_runnable_get_graph_with_invalid_input_type() -> None:
|
||||
"""Test that error isn't raised when getting graph with invalid input type."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user