From 477eb1745c1ad69d755330f508e7f8b46359e2d4 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 24 Apr 2024 12:32:52 -0700 Subject: [PATCH] Better support for subgraphs in graph viz (#20840) --- libs/core/langchain_core/runnables/base.py | 11 ++---- libs/core/langchain_core/runnables/graph.py | 35 +++++++++++++++++-- .../langchain_core/runnables/graph_mermaid.py | 23 ++++++++---- 3 files changed, 51 insertions(+), 18 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index ee48234df3a..3484e8bba3c 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -2409,8 +2409,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): step_graph.trim_first_node() if step is not self.last: step_graph.trim_last_node() - graph.extend(step_graph) - step_first_node = step_graph.first_node() + step_first_node, _ = graph.extend(step_graph) if not step_first_node: raise ValueError(f"Runnable {step} has no first node") if current_last_node: @@ -3082,11 +3081,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): if not step_graph: graph.add_edge(input_node, output_node) else: - graph.extend(step_graph) - step_first_node = step_graph.first_node() + step_first_node, step_last_node = graph.extend(step_graph) if not step_first_node: raise ValueError(f"Runnable {step} has no first node") - step_last_node = step_graph.last_node() if not step_last_node: raise ValueError(f"Runnable {step} has no last node") graph.add_edge(input_node, step_first_node) @@ -3779,11 +3776,9 @@ class RunnableLambda(Runnable[Input, Output]): if not dep_graph: graph.add_edge(input_node, output_node) else: - graph.extend(dep_graph) - dep_first_node = dep_graph.first_node() + dep_first_node, dep_last_node = graph.extend(dep_graph) if not dep_first_node: raise ValueError(f"Runnable {dep} has no first node") - dep_last_node = dep_graph.last_node() if not dep_last_node: raise ValueError(f"Runnable {dep} has no last node") graph.add_edge(input_node, dep_first_node) diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index f92b4b064a2..4486fca5252 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -11,6 +11,7 @@ from typing import ( List, NamedTuple, Optional, + Tuple, Type, TypedDict, Union, @@ -236,11 +237,39 @@ class Graph: self.edges.append(edge) return edge - def extend(self, graph: Graph) -> None: + def extend( + self, graph: Graph, *, prefix: str = "" + ) -> Tuple[Optional[Node], Optional[Node]]: """Add all nodes and edges from another graph. Note this doesn't check for duplicates, nor does it connect the graphs.""" - self.nodes.update(graph.nodes) - self.edges.extend(graph.edges) + if all(is_uuid(node.id) for node in graph.nodes.values()): + prefix = "" + + def prefixed(id: str) -> str: + return f"{prefix}:{id}" if prefix else id + + # prefix each node + self.nodes.update( + {prefixed(k): Node(prefixed(k), v.data) for k, v in graph.nodes.items()} + ) + # prefix each edge's source and target + self.edges.extend( + [ + Edge( + prefixed(edge.source), + prefixed(edge.target), + edge.data, + edge.conditional, + ) + for edge in graph.edges + ] + ) + # return (prefixed) first and last nodes of the subgraph + first, last = graph.first_node(), graph.last_node() + return ( + Node(prefixed(first.id), first.data) if first else None, + Node(prefixed(last.id), last.data) if last else None, + ) def first_node(self) -> Optional[Node]: """Find the single node that is not a target of any edge. diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 93e052d8919..61b922ae68f 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -48,7 +48,7 @@ def draw_mermaid( if with_styles: # Node formatting templates default_class_label = "default" - format_dict = {default_class_label: "{0}([{0}]):::otherclass"} + format_dict = {default_class_label: "{0}([{1}]):::otherclass"} if first_node_label is not None: format_dict[first_node_label] = "{0}[{0}]:::startclass" if last_node_label is not None: @@ -57,17 +57,24 @@ def draw_mermaid( # Add nodes to the graph for node in nodes.values(): node_label = format_dict.get(node, format_dict[default_class_label]).format( - _escape_node_label(node) + _escape_node_label(node), _escape_node_label(node.split(":", 1)[-1]) ) mermaid_graph += f"\t{node_label};\n" + subgraph = "" # Add edges to the graph for edge in edges: + src_prefix = edge.source.split(":")[0] + tgt_prefix = edge.target.split(":")[0] + # exit subgraph if source or target is not in the same subgraph + if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix): + mermaid_graph += "\tend\n" + subgraph = "" + # enter subgraph if source and target are in the same subgraph + if not subgraph and src_prefix and src_prefix == tgt_prefix: + mermaid_graph += f"\tsubgraph {src_prefix}\n" + subgraph = src_prefix adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes) - if ( - adjusted_edge is None - ): # Ignore if it is connection between source and intermediate node - continue source, target = adjusted_edge @@ -96,6 +103,8 @@ def draw_mermaid( f"\t{_escape_node_label(source)}{edge_label}" f"{_escape_node_label(target)};\n" ) + if subgraph: + mermaid_graph += "end\n" # Add custom styles for nodes if with_styles: @@ -111,7 +120,7 @@ def _escape_node_label(node_label: str) -> str: def _adjust_mermaid_edge( edge: Edge, nodes: Dict[str, str], -) -> Optional[Tuple[str, str]]: +) -> Tuple[str, str]: """Adjusts Mermaid edge to map conditional nodes to pure nodes.""" source_node_label = nodes.get(edge.source, edge.source) target_node_label = nodes.get(edge.target, edge.target)