core[patch]: support single-node subgraphs and put subgraph nodes under the respective subgraphs (#30234)

This commit is contained in:
Vadym Barda
2025-03-11 18:55:45 -04:00
committed by GitHub
parent 81d1653a30
commit c7842730ef
3 changed files with 105 additions and 39 deletions

View File

@@ -56,37 +56,50 @@ def draw_mermaid(
if with_styles
else "graph TD;\n"
)
# Group nodes by subgraph
subgraph_nodes: dict[str, dict[str, Node]] = {}
regular_nodes: dict[str, Node] = {}
if with_styles:
# Node formatting templates
default_class_label = "default"
format_dict = {default_class_label: "{0}({1})"}
if first_node is not None:
format_dict[first_node] = "{0}([{1}]):::first"
if last_node is not None:
format_dict[last_node] = "{0}([{1}]):::last"
for key, node in nodes.items():
if ":" in key:
# For nodes with colons, add them only to their deepest subgraph level
prefix = ":".join(key.split(":")[:-1])
subgraph_nodes.setdefault(prefix, {})[key] = node
else:
regular_nodes[key] = node
# Add nodes to the graph
for key, node in nodes.items():
node_name = node.name.split(":")[-1]
# Node formatting templates
default_class_label = "default"
format_dict = {default_class_label: "{0}({1})"}
if first_node is not None:
format_dict[first_node] = "{0}([{1}]):::first"
if last_node is not None:
format_dict[last_node] = "{0}([{1}]):::last"
def render_node(key: str, node: Node, indent: str = "\t") -> str:
"""Helper function to render a node with consistent formatting."""
node_name = node.name.split(":")[-1]
label = (
f"<p>{node_name}</p>"
if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS))
and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS))
else node_name
)
if node.metadata:
label = (
f"<p>{node_name}</p>"
if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS))
and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS))
else node_name
f"{label}<hr/><small><em>"
+ "\n".join(f"{k} = {value}" for k, value in node.metadata.items())
+ "</em></small>"
)
if node.metadata:
label = (
f"{label}<hr/><small><em>"
+ "\n".join(
f"{key} = {value}" for key, value in node.metadata.items()
)
+ "</em></small>"
)
node_label = format_dict.get(key, format_dict[default_class_label]).format(
_escape_node_label(key), label
)
mermaid_graph += f"\t{node_label}\n"
node_label = format_dict.get(key, format_dict[default_class_label]).format(
_escape_node_label(key), label
)
return f"{indent}{node_label}\n"
# Add non-subgraph nodes to the graph
if with_styles:
for key, node in regular_nodes.items():
mermaid_graph += render_node(key, node)
# Group edges by their common prefixes
edge_groups: dict[str, list[Edge]] = {}
@@ -116,6 +129,11 @@ def draw_mermaid(
seen_subgraphs.add(subgraph)
mermaid_graph += f"\tsubgraph {subgraph}\n"
# Add nodes that belong to this subgraph
if with_styles and prefix in subgraph_nodes:
for key, node in subgraph_nodes[prefix].items():
mermaid_graph += render_node(key, node)
for edge in edges:
source, target = edge.source, edge.target
@@ -156,11 +174,25 @@ def draw_mermaid(
# Start with the top-level edges (no common prefix)
add_subgraph(edge_groups.get("", []), "")
# Add remaining subgraphs
# Add remaining subgraphs with edges
for prefix in edge_groups:
if ":" in prefix or prefix == "":
continue
add_subgraph(edge_groups[prefix], prefix)
seen_subgraphs.add(prefix)
# Add empty subgraphs (subgraphs with no internal edges)
if with_styles:
for prefix in subgraph_nodes:
if ":" not in prefix and prefix not in seen_subgraphs:
mermaid_graph += f"\tsubgraph {prefix}\n"
# Add nodes that belong to this subgraph
for key, node in subgraph_nodes[prefix].items():
mermaid_graph += render_node(key, node)
mermaid_graph += "\tend\n"
seen_subgraphs.add(prefix)
# Add custom styles for nodes
if with_styles: