Better support for subgraphs in graph viz (#20840)

This commit is contained in:
Nuno Campos 2024-04-24 12:32:52 -07:00 committed by GitHub
parent a9c7d47c03
commit 477eb1745c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 51 additions and 18 deletions

View File

@ -2409,8 +2409,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
step_graph.trim_first_node()
if step is not self.last:
step_graph.trim_last_node()
graph.extend(step_graph)
step_first_node = step_graph.first_node()
step_first_node, _ = graph.extend(step_graph)
if not step_first_node:
raise ValueError(f"Runnable {step} has no first node")
if current_last_node:
@ -3082,11 +3081,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
if not step_graph:
graph.add_edge(input_node, output_node)
else:
graph.extend(step_graph)
step_first_node = step_graph.first_node()
step_first_node, step_last_node = graph.extend(step_graph)
if not step_first_node:
raise ValueError(f"Runnable {step} has no first node")
step_last_node = step_graph.last_node()
if not step_last_node:
raise ValueError(f"Runnable {step} has no last node")
graph.add_edge(input_node, step_first_node)
@ -3779,11 +3776,9 @@ class RunnableLambda(Runnable[Input, Output]):
if not dep_graph:
graph.add_edge(input_node, output_node)
else:
graph.extend(dep_graph)
dep_first_node = dep_graph.first_node()
dep_first_node, dep_last_node = graph.extend(dep_graph)
if not dep_first_node:
raise ValueError(f"Runnable {dep} has no first node")
dep_last_node = dep_graph.last_node()
if not dep_last_node:
raise ValueError(f"Runnable {dep} has no last node")
graph.add_edge(input_node, dep_first_node)

View File

@ -11,6 +11,7 @@ from typing import (
List,
NamedTuple,
Optional,
Tuple,
Type,
TypedDict,
Union,
@ -236,11 +237,39 @@ class Graph:
self.edges.append(edge)
return edge
def extend(self, graph: Graph) -> None:
def extend(
self, graph: Graph, *, prefix: str = ""
) -> Tuple[Optional[Node], Optional[Node]]:
"""Add all nodes and edges from another graph.
Note this doesn't check for duplicates, nor does it connect the graphs."""
self.nodes.update(graph.nodes)
self.edges.extend(graph.edges)
if all(is_uuid(node.id) for node in graph.nodes.values()):
prefix = ""
def prefixed(id: str) -> str:
return f"{prefix}:{id}" if prefix else id
# prefix each node
self.nodes.update(
{prefixed(k): Node(prefixed(k), v.data) for k, v in graph.nodes.items()}
)
# prefix each edge's source and target
self.edges.extend(
[
Edge(
prefixed(edge.source),
prefixed(edge.target),
edge.data,
edge.conditional,
)
for edge in graph.edges
]
)
# return (prefixed) first and last nodes of the subgraph
first, last = graph.first_node(), graph.last_node()
return (
Node(prefixed(first.id), first.data) if first else None,
Node(prefixed(last.id), last.data) if last else None,
)
def first_node(self) -> Optional[Node]:
"""Find the single node that is not a target of any edge.

View File

@ -48,7 +48,7 @@ def draw_mermaid(
if with_styles:
# Node formatting templates
default_class_label = "default"
format_dict = {default_class_label: "{0}([{0}]):::otherclass"}
format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
if first_node_label is not None:
format_dict[first_node_label] = "{0}[{0}]:::startclass"
if last_node_label is not None:
@ -57,17 +57,24 @@ def draw_mermaid(
# Add nodes to the graph
for node in nodes.values():
node_label = format_dict.get(node, format_dict[default_class_label]).format(
_escape_node_label(node)
_escape_node_label(node), _escape_node_label(node.split(":", 1)[-1])
)
mermaid_graph += f"\t{node_label};\n"
subgraph = ""
# Add edges to the graph
for edge in edges:
src_prefix = edge.source.split(":")[0]
tgt_prefix = edge.target.split(":")[0]
# exit subgraph if source or target is not in the same subgraph
if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix):
mermaid_graph += "\tend\n"
subgraph = ""
# enter subgraph if source and target are in the same subgraph
if not subgraph and src_prefix and src_prefix == tgt_prefix:
mermaid_graph += f"\tsubgraph {src_prefix}\n"
subgraph = src_prefix
adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes)
if (
adjusted_edge is None
): # Ignore if it is connection between source and intermediate node
continue
source, target = adjusted_edge
@ -96,6 +103,8 @@ def draw_mermaid(
f"\t{_escape_node_label(source)}{edge_label}"
f"{_escape_node_label(target)};\n"
)
if subgraph:
mermaid_graph += "end\n"
# Add custom styles for nodes
if with_styles:
@ -111,7 +120,7 @@ def _escape_node_label(node_label: str) -> str:
def _adjust_mermaid_edge(
edge: Edge,
nodes: Dict[str, str],
) -> Optional[Tuple[str, str]]:
) -> Tuple[str, str]:
"""Adjusts Mermaid edge to map conditional nodes to pure nodes."""
source_node_label = nodes.get(edge.source, edge.source)
target_node_label = nodes.get(edge.target, edge.target)