mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
Better support for subgraphs in graph viz (#20840)
This commit is contained in:
parent
a9c7d47c03
commit
477eb1745c
@ -2409,8 +2409,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
step_graph.trim_first_node()
|
||||
if step is not self.last:
|
||||
step_graph.trim_last_node()
|
||||
graph.extend(step_graph)
|
||||
step_first_node = step_graph.first_node()
|
||||
step_first_node, _ = graph.extend(step_graph)
|
||||
if not step_first_node:
|
||||
raise ValueError(f"Runnable {step} has no first node")
|
||||
if current_last_node:
|
||||
@ -3082,11 +3081,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
if not step_graph:
|
||||
graph.add_edge(input_node, output_node)
|
||||
else:
|
||||
graph.extend(step_graph)
|
||||
step_first_node = step_graph.first_node()
|
||||
step_first_node, step_last_node = graph.extend(step_graph)
|
||||
if not step_first_node:
|
||||
raise ValueError(f"Runnable {step} has no first node")
|
||||
step_last_node = step_graph.last_node()
|
||||
if not step_last_node:
|
||||
raise ValueError(f"Runnable {step} has no last node")
|
||||
graph.add_edge(input_node, step_first_node)
|
||||
@ -3779,11 +3776,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
if not dep_graph:
|
||||
graph.add_edge(input_node, output_node)
|
||||
else:
|
||||
graph.extend(dep_graph)
|
||||
dep_first_node = dep_graph.first_node()
|
||||
dep_first_node, dep_last_node = graph.extend(dep_graph)
|
||||
if not dep_first_node:
|
||||
raise ValueError(f"Runnable {dep} has no first node")
|
||||
dep_last_node = dep_graph.last_node()
|
||||
if not dep_last_node:
|
||||
raise ValueError(f"Runnable {dep} has no last node")
|
||||
graph.add_edge(input_node, dep_first_node)
|
||||
|
@ -11,6 +11,7 @@ from typing import (
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
@ -236,11 +237,39 @@ class Graph:
|
||||
self.edges.append(edge)
|
||||
return edge
|
||||
|
||||
def extend(self, graph: Graph) -> None:
|
||||
def extend(
|
||||
self, graph: Graph, *, prefix: str = ""
|
||||
) -> Tuple[Optional[Node], Optional[Node]]:
|
||||
"""Add all nodes and edges from another graph.
|
||||
Note this doesn't check for duplicates, nor does it connect the graphs."""
|
||||
self.nodes.update(graph.nodes)
|
||||
self.edges.extend(graph.edges)
|
||||
if all(is_uuid(node.id) for node in graph.nodes.values()):
|
||||
prefix = ""
|
||||
|
||||
def prefixed(id: str) -> str:
|
||||
return f"{prefix}:{id}" if prefix else id
|
||||
|
||||
# prefix each node
|
||||
self.nodes.update(
|
||||
{prefixed(k): Node(prefixed(k), v.data) for k, v in graph.nodes.items()}
|
||||
)
|
||||
# prefix each edge's source and target
|
||||
self.edges.extend(
|
||||
[
|
||||
Edge(
|
||||
prefixed(edge.source),
|
||||
prefixed(edge.target),
|
||||
edge.data,
|
||||
edge.conditional,
|
||||
)
|
||||
for edge in graph.edges
|
||||
]
|
||||
)
|
||||
# return (prefixed) first and last nodes of the subgraph
|
||||
first, last = graph.first_node(), graph.last_node()
|
||||
return (
|
||||
Node(prefixed(first.id), first.data) if first else None,
|
||||
Node(prefixed(last.id), last.data) if last else None,
|
||||
)
|
||||
|
||||
def first_node(self) -> Optional[Node]:
|
||||
"""Find the single node that is not a target of any edge.
|
||||
|
@ -48,7 +48,7 @@ def draw_mermaid(
|
||||
if with_styles:
|
||||
# Node formatting templates
|
||||
default_class_label = "default"
|
||||
format_dict = {default_class_label: "{0}([{0}]):::otherclass"}
|
||||
format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
|
||||
if first_node_label is not None:
|
||||
format_dict[first_node_label] = "{0}[{0}]:::startclass"
|
||||
if last_node_label is not None:
|
||||
@ -57,17 +57,24 @@ def draw_mermaid(
|
||||
# Add nodes to the graph
|
||||
for node in nodes.values():
|
||||
node_label = format_dict.get(node, format_dict[default_class_label]).format(
|
||||
_escape_node_label(node)
|
||||
_escape_node_label(node), _escape_node_label(node.split(":", 1)[-1])
|
||||
)
|
||||
mermaid_graph += f"\t{node_label};\n"
|
||||
|
||||
subgraph = ""
|
||||
# Add edges to the graph
|
||||
for edge in edges:
|
||||
src_prefix = edge.source.split(":")[0]
|
||||
tgt_prefix = edge.target.split(":")[0]
|
||||
# exit subgraph if source or target is not in the same subgraph
|
||||
if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix):
|
||||
mermaid_graph += "\tend\n"
|
||||
subgraph = ""
|
||||
# enter subgraph if source and target are in the same subgraph
|
||||
if not subgraph and src_prefix and src_prefix == tgt_prefix:
|
||||
mermaid_graph += f"\tsubgraph {src_prefix}\n"
|
||||
subgraph = src_prefix
|
||||
adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes)
|
||||
if (
|
||||
adjusted_edge is None
|
||||
): # Ignore if it is connection between source and intermediate node
|
||||
continue
|
||||
|
||||
source, target = adjusted_edge
|
||||
|
||||
@ -96,6 +103,8 @@ def draw_mermaid(
|
||||
f"\t{_escape_node_label(source)}{edge_label}"
|
||||
f"{_escape_node_label(target)};\n"
|
||||
)
|
||||
if subgraph:
|
||||
mermaid_graph += "end\n"
|
||||
|
||||
# Add custom styles for nodes
|
||||
if with_styles:
|
||||
@ -111,7 +120,7 @@ def _escape_node_label(node_label: str) -> str:
|
||||
def _adjust_mermaid_edge(
|
||||
edge: Edge,
|
||||
nodes: Dict[str, str],
|
||||
) -> Optional[Tuple[str, str]]:
|
||||
) -> Tuple[str, str]:
|
||||
"""Adjusts Mermaid edge to map conditional nodes to pure nodes."""
|
||||
source_node_label = nodes.get(edge.source, edge.source)
|
||||
target_node_label = nodes.get(edge.target, edge.target)
|
||||
|
Loading…
Reference in New Issue
Block a user