Better support for subgraphs in graph viz (#20840)

This commit is contained in:
Nuno Campos 2024-04-24 12:32:52 -07:00 committed by GitHub
parent a9c7d47c03
commit 477eb1745c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 51 additions and 18 deletions

View File

@ -2409,8 +2409,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
step_graph.trim_first_node() step_graph.trim_first_node()
if step is not self.last: if step is not self.last:
step_graph.trim_last_node() step_graph.trim_last_node()
graph.extend(step_graph) step_first_node, _ = graph.extend(step_graph)
step_first_node = step_graph.first_node()
if not step_first_node: if not step_first_node:
raise ValueError(f"Runnable {step} has no first node") raise ValueError(f"Runnable {step} has no first node")
if current_last_node: if current_last_node:
@ -3082,11 +3081,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
if not step_graph: if not step_graph:
graph.add_edge(input_node, output_node) graph.add_edge(input_node, output_node)
else: else:
graph.extend(step_graph) step_first_node, step_last_node = graph.extend(step_graph)
step_first_node = step_graph.first_node()
if not step_first_node: if not step_first_node:
raise ValueError(f"Runnable {step} has no first node") raise ValueError(f"Runnable {step} has no first node")
step_last_node = step_graph.last_node()
if not step_last_node: if not step_last_node:
raise ValueError(f"Runnable {step} has no last node") raise ValueError(f"Runnable {step} has no last node")
graph.add_edge(input_node, step_first_node) graph.add_edge(input_node, step_first_node)
@ -3779,11 +3776,9 @@ class RunnableLambda(Runnable[Input, Output]):
if not dep_graph: if not dep_graph:
graph.add_edge(input_node, output_node) graph.add_edge(input_node, output_node)
else: else:
graph.extend(dep_graph) dep_first_node, dep_last_node = graph.extend(dep_graph)
dep_first_node = dep_graph.first_node()
if not dep_first_node: if not dep_first_node:
raise ValueError(f"Runnable {dep} has no first node") raise ValueError(f"Runnable {dep} has no first node")
dep_last_node = dep_graph.last_node()
if not dep_last_node: if not dep_last_node:
raise ValueError(f"Runnable {dep} has no last node") raise ValueError(f"Runnable {dep} has no last node")
graph.add_edge(input_node, dep_first_node) graph.add_edge(input_node, dep_first_node)

View File

@ -11,6 +11,7 @@ from typing import (
List, List,
NamedTuple, NamedTuple,
Optional, Optional,
Tuple,
Type, Type,
TypedDict, TypedDict,
Union, Union,
@ -236,11 +237,39 @@ class Graph:
self.edges.append(edge) self.edges.append(edge)
return edge return edge
def extend(self, graph: Graph) -> None: def extend(
self, graph: Graph, *, prefix: str = ""
) -> Tuple[Optional[Node], Optional[Node]]:
"""Add all nodes and edges from another graph. """Add all nodes and edges from another graph.
Note this doesn't check for duplicates, nor does it connect the graphs.""" Note this doesn't check for duplicates, nor does it connect the graphs."""
self.nodes.update(graph.nodes) if all(is_uuid(node.id) for node in graph.nodes.values()):
self.edges.extend(graph.edges) prefix = ""
def prefixed(id: str) -> str:
return f"{prefix}:{id}" if prefix else id
# prefix each node
self.nodes.update(
{prefixed(k): Node(prefixed(k), v.data) for k, v in graph.nodes.items()}
)
# prefix each edge's source and target
self.edges.extend(
[
Edge(
prefixed(edge.source),
prefixed(edge.target),
edge.data,
edge.conditional,
)
for edge in graph.edges
]
)
# return (prefixed) first and last nodes of the subgraph
first, last = graph.first_node(), graph.last_node()
return (
Node(prefixed(first.id), first.data) if first else None,
Node(prefixed(last.id), last.data) if last else None,
)
def first_node(self) -> Optional[Node]: def first_node(self) -> Optional[Node]:
"""Find the single node that is not a target of any edge. """Find the single node that is not a target of any edge.

View File

@ -48,7 +48,7 @@ def draw_mermaid(
if with_styles: if with_styles:
# Node formatting templates # Node formatting templates
default_class_label = "default" default_class_label = "default"
format_dict = {default_class_label: "{0}([{0}]):::otherclass"} format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
if first_node_label is not None: if first_node_label is not None:
format_dict[first_node_label] = "{0}[{0}]:::startclass" format_dict[first_node_label] = "{0}[{0}]:::startclass"
if last_node_label is not None: if last_node_label is not None:
@ -57,17 +57,24 @@ def draw_mermaid(
# Add nodes to the graph # Add nodes to the graph
for node in nodes.values(): for node in nodes.values():
node_label = format_dict.get(node, format_dict[default_class_label]).format( node_label = format_dict.get(node, format_dict[default_class_label]).format(
_escape_node_label(node) _escape_node_label(node), _escape_node_label(node.split(":", 1)[-1])
) )
mermaid_graph += f"\t{node_label};\n" mermaid_graph += f"\t{node_label};\n"
subgraph = ""
# Add edges to the graph # Add edges to the graph
for edge in edges: for edge in edges:
src_prefix = edge.source.split(":")[0]
tgt_prefix = edge.target.split(":")[0]
# exit subgraph if source or target is not in the same subgraph
if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix):
mermaid_graph += "\tend\n"
subgraph = ""
# enter subgraph if source and target are in the same subgraph
if not subgraph and src_prefix and src_prefix == tgt_prefix:
mermaid_graph += f"\tsubgraph {src_prefix}\n"
subgraph = src_prefix
adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes) adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes)
if (
adjusted_edge is None
): # Ignore if it is connection between source and intermediate node
continue
source, target = adjusted_edge source, target = adjusted_edge
@ -96,6 +103,8 @@ def draw_mermaid(
f"\t{_escape_node_label(source)}{edge_label}" f"\t{_escape_node_label(source)}{edge_label}"
f"{_escape_node_label(target)};\n" f"{_escape_node_label(target)};\n"
) )
if subgraph:
mermaid_graph += "end\n"
# Add custom styles for nodes # Add custom styles for nodes
if with_styles: if with_styles:
@ -111,7 +120,7 @@ def _escape_node_label(node_label: str) -> str:
def _adjust_mermaid_edge( def _adjust_mermaid_edge(
edge: Edge, edge: Edge,
nodes: Dict[str, str], nodes: Dict[str, str],
) -> Optional[Tuple[str, str]]: ) -> Tuple[str, str]:
"""Adjusts Mermaid edge to map conditional nodes to pure nodes.""" """Adjusts Mermaid edge to map conditional nodes to pure nodes."""
source_node_label = nodes.get(edge.source, edge.source) source_node_label = nodes.get(edge.source, edge.source)
target_node_label = nodes.get(edge.target, edge.target) target_node_label = nodes.get(edge.target, edge.target)