mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
Better support for subgraphs in graph viz (#20840)
This commit is contained in:
parent
a9c7d47c03
commit
477eb1745c
@ -2409,8 +2409,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
step_graph.trim_first_node()
|
step_graph.trim_first_node()
|
||||||
if step is not self.last:
|
if step is not self.last:
|
||||||
step_graph.trim_last_node()
|
step_graph.trim_last_node()
|
||||||
graph.extend(step_graph)
|
step_first_node, _ = graph.extend(step_graph)
|
||||||
step_first_node = step_graph.first_node()
|
|
||||||
if not step_first_node:
|
if not step_first_node:
|
||||||
raise ValueError(f"Runnable {step} has no first node")
|
raise ValueError(f"Runnable {step} has no first node")
|
||||||
if current_last_node:
|
if current_last_node:
|
||||||
@ -3082,11 +3081,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
|||||||
if not step_graph:
|
if not step_graph:
|
||||||
graph.add_edge(input_node, output_node)
|
graph.add_edge(input_node, output_node)
|
||||||
else:
|
else:
|
||||||
graph.extend(step_graph)
|
step_first_node, step_last_node = graph.extend(step_graph)
|
||||||
step_first_node = step_graph.first_node()
|
|
||||||
if not step_first_node:
|
if not step_first_node:
|
||||||
raise ValueError(f"Runnable {step} has no first node")
|
raise ValueError(f"Runnable {step} has no first node")
|
||||||
step_last_node = step_graph.last_node()
|
|
||||||
if not step_last_node:
|
if not step_last_node:
|
||||||
raise ValueError(f"Runnable {step} has no last node")
|
raise ValueError(f"Runnable {step} has no last node")
|
||||||
graph.add_edge(input_node, step_first_node)
|
graph.add_edge(input_node, step_first_node)
|
||||||
@ -3779,11 +3776,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
if not dep_graph:
|
if not dep_graph:
|
||||||
graph.add_edge(input_node, output_node)
|
graph.add_edge(input_node, output_node)
|
||||||
else:
|
else:
|
||||||
graph.extend(dep_graph)
|
dep_first_node, dep_last_node = graph.extend(dep_graph)
|
||||||
dep_first_node = dep_graph.first_node()
|
|
||||||
if not dep_first_node:
|
if not dep_first_node:
|
||||||
raise ValueError(f"Runnable {dep} has no first node")
|
raise ValueError(f"Runnable {dep} has no first node")
|
||||||
dep_last_node = dep_graph.last_node()
|
|
||||||
if not dep_last_node:
|
if not dep_last_node:
|
||||||
raise ValueError(f"Runnable {dep} has no last node")
|
raise ValueError(f"Runnable {dep} has no last node")
|
||||||
graph.add_edge(input_node, dep_first_node)
|
graph.add_edge(input_node, dep_first_node)
|
||||||
|
@ -11,6 +11,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
Optional,
|
Optional,
|
||||||
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
Union,
|
Union,
|
||||||
@ -236,11 +237,39 @@ class Graph:
|
|||||||
self.edges.append(edge)
|
self.edges.append(edge)
|
||||||
return edge
|
return edge
|
||||||
|
|
||||||
def extend(self, graph: Graph) -> None:
|
def extend(
|
||||||
|
self, graph: Graph, *, prefix: str = ""
|
||||||
|
) -> Tuple[Optional[Node], Optional[Node]]:
|
||||||
"""Add all nodes and edges from another graph.
|
"""Add all nodes and edges from another graph.
|
||||||
Note this doesn't check for duplicates, nor does it connect the graphs."""
|
Note this doesn't check for duplicates, nor does it connect the graphs."""
|
||||||
self.nodes.update(graph.nodes)
|
if all(is_uuid(node.id) for node in graph.nodes.values()):
|
||||||
self.edges.extend(graph.edges)
|
prefix = ""
|
||||||
|
|
||||||
|
def prefixed(id: str) -> str:
|
||||||
|
return f"{prefix}:{id}" if prefix else id
|
||||||
|
|
||||||
|
# prefix each node
|
||||||
|
self.nodes.update(
|
||||||
|
{prefixed(k): Node(prefixed(k), v.data) for k, v in graph.nodes.items()}
|
||||||
|
)
|
||||||
|
# prefix each edge's source and target
|
||||||
|
self.edges.extend(
|
||||||
|
[
|
||||||
|
Edge(
|
||||||
|
prefixed(edge.source),
|
||||||
|
prefixed(edge.target),
|
||||||
|
edge.data,
|
||||||
|
edge.conditional,
|
||||||
|
)
|
||||||
|
for edge in graph.edges
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# return (prefixed) first and last nodes of the subgraph
|
||||||
|
first, last = graph.first_node(), graph.last_node()
|
||||||
|
return (
|
||||||
|
Node(prefixed(first.id), first.data) if first else None,
|
||||||
|
Node(prefixed(last.id), last.data) if last else None,
|
||||||
|
)
|
||||||
|
|
||||||
def first_node(self) -> Optional[Node]:
|
def first_node(self) -> Optional[Node]:
|
||||||
"""Find the single node that is not a target of any edge.
|
"""Find the single node that is not a target of any edge.
|
||||||
|
@ -48,7 +48,7 @@ def draw_mermaid(
|
|||||||
if with_styles:
|
if with_styles:
|
||||||
# Node formatting templates
|
# Node formatting templates
|
||||||
default_class_label = "default"
|
default_class_label = "default"
|
||||||
format_dict = {default_class_label: "{0}([{0}]):::otherclass"}
|
format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
|
||||||
if first_node_label is not None:
|
if first_node_label is not None:
|
||||||
format_dict[first_node_label] = "{0}[{0}]:::startclass"
|
format_dict[first_node_label] = "{0}[{0}]:::startclass"
|
||||||
if last_node_label is not None:
|
if last_node_label is not None:
|
||||||
@ -57,17 +57,24 @@ def draw_mermaid(
|
|||||||
# Add nodes to the graph
|
# Add nodes to the graph
|
||||||
for node in nodes.values():
|
for node in nodes.values():
|
||||||
node_label = format_dict.get(node, format_dict[default_class_label]).format(
|
node_label = format_dict.get(node, format_dict[default_class_label]).format(
|
||||||
_escape_node_label(node)
|
_escape_node_label(node), _escape_node_label(node.split(":", 1)[-1])
|
||||||
)
|
)
|
||||||
mermaid_graph += f"\t{node_label};\n"
|
mermaid_graph += f"\t{node_label};\n"
|
||||||
|
|
||||||
|
subgraph = ""
|
||||||
# Add edges to the graph
|
# Add edges to the graph
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
|
src_prefix = edge.source.split(":")[0]
|
||||||
|
tgt_prefix = edge.target.split(":")[0]
|
||||||
|
# exit subgraph if source or target is not in the same subgraph
|
||||||
|
if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix):
|
||||||
|
mermaid_graph += "\tend\n"
|
||||||
|
subgraph = ""
|
||||||
|
# enter subgraph if source and target are in the same subgraph
|
||||||
|
if not subgraph and src_prefix and src_prefix == tgt_prefix:
|
||||||
|
mermaid_graph += f"\tsubgraph {src_prefix}\n"
|
||||||
|
subgraph = src_prefix
|
||||||
adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes)
|
adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes)
|
||||||
if (
|
|
||||||
adjusted_edge is None
|
|
||||||
): # Ignore if it is connection between source and intermediate node
|
|
||||||
continue
|
|
||||||
|
|
||||||
source, target = adjusted_edge
|
source, target = adjusted_edge
|
||||||
|
|
||||||
@ -96,6 +103,8 @@ def draw_mermaid(
|
|||||||
f"\t{_escape_node_label(source)}{edge_label}"
|
f"\t{_escape_node_label(source)}{edge_label}"
|
||||||
f"{_escape_node_label(target)};\n"
|
f"{_escape_node_label(target)};\n"
|
||||||
)
|
)
|
||||||
|
if subgraph:
|
||||||
|
mermaid_graph += "end\n"
|
||||||
|
|
||||||
# Add custom styles for nodes
|
# Add custom styles for nodes
|
||||||
if with_styles:
|
if with_styles:
|
||||||
@ -111,7 +120,7 @@ def _escape_node_label(node_label: str) -> str:
|
|||||||
def _adjust_mermaid_edge(
|
def _adjust_mermaid_edge(
|
||||||
edge: Edge,
|
edge: Edge,
|
||||||
nodes: Dict[str, str],
|
nodes: Dict[str, str],
|
||||||
) -> Optional[Tuple[str, str]]:
|
) -> Tuple[str, str]:
|
||||||
"""Adjusts Mermaid edge to map conditional nodes to pure nodes."""
|
"""Adjusts Mermaid edge to map conditional nodes to pure nodes."""
|
||||||
source_node_label = nodes.get(edge.source, edge.source)
|
source_node_label = nodes.get(edge.source, edge.source)
|
||||||
target_node_label = nodes.get(edge.target, edge.target)
|
target_node_label = nodes.get(edge.target, edge.target)
|
||||||
|
Loading…
Reference in New Issue
Block a user