core: Add concept of conditional edge to graph rendering (#20480)

- implement for mermaid, graphviz and ascii
- this is to be used in langgraph
This commit is contained in:
Nuno Campos 2024-04-15 13:49:06 -07:00 committed by GitHub
parent 30b00090ef
commit 97b2191e99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 38 additions and 15 deletions

View File

@ -19,7 +19,6 @@ from typing import (
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.graph_ascii import draw_ascii
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable as RunnableType from langchain_core.runnables.base import Runnable as RunnableType
@ -44,6 +43,7 @@ class Edge(NamedTuple):
source: str source: str
target: str target: str
data: Optional[str] = None data: Optional[str] = None
conditional: bool = False
class Node(NamedTuple): class Node(NamedTuple):
@ -219,13 +219,21 @@ class Graph:
if edge.source != node.id and edge.target != node.id if edge.source != node.id and edge.target != node.id
] ]
def add_edge(self, source: Node, target: Node, data: Optional[str] = None) -> Edge: def add_edge(
self,
source: Node,
target: Node,
data: Optional[str] = None,
conditional: bool = False,
) -> Edge:
"""Add an edge to the graph and return it.""" """Add an edge to the graph and return it."""
if source.id not in self.nodes: if source.id not in self.nodes:
raise ValueError(f"Source node {source.id} not in graph") raise ValueError(f"Source node {source.id} not in graph")
if target.id not in self.nodes: if target.id not in self.nodes:
raise ValueError(f"Target node {target.id} not in graph") raise ValueError(f"Target node {target.id} not in graph")
edge = Edge(source=source.id, target=target.id, data=data) edge = Edge(
source=source.id, target=target.id, data=data, conditional=conditional
)
self.edges.append(edge) self.edges.append(edge)
return edge return edge
@ -283,9 +291,11 @@ class Graph:
self.remove_node(last_node) self.remove_node(last_node)
def draw_ascii(self) -> str: def draw_ascii(self) -> str:
from langchain_core.runnables.graph_ascii import draw_ascii
return draw_ascii( return draw_ascii(
{node.id: node_data_str(node) for node in self.nodes.values()}, {node.id: node_data_str(node) for node in self.nodes.values()},
[(edge.source, edge.target) for edge in self.edges], self.edges,
) )
def print_ascii(self) -> None: def print_ascii(self) -> None:

View File

@ -3,7 +3,9 @@ Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py"""
import math import math
import os import os
from typing import Any, Mapping, Sequence, Tuple from typing import Any, Mapping, Sequence
from langchain_core.runnables.graph import Edge as LangEdge
class VertexViewer: class VertexViewer:
@ -156,7 +158,7 @@ class AsciiCanvas:
def _build_sugiyama_layout( def _build_sugiyama_layout(
vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]] vertices: Mapping[str, str], edges: Sequence[LangEdge]
) -> Any: ) -> Any:
try: try:
from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import] from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import]
@ -181,7 +183,7 @@ def _build_sugiyama_layout(
# #
vertices_ = {id: Vertex(f" {data} ") for id, data in vertices.items()} vertices_ = {id: Vertex(f" {data} ") for id, data in vertices.items()}
edges_ = [Edge(vertices_[s], vertices_[e]) for s, e in edges] edges_ = [Edge(vertices_[s], vertices_[e], data=cond) for s, e, _, cond in edges]
vertices_list = vertices_.values() vertices_list = vertices_.values()
graph = Graph(vertices_list, edges_) graph = Graph(vertices_list, edges_)
@ -209,7 +211,7 @@ def _build_sugiyama_layout(
return sug return sug
def draw_ascii(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) -> str: def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
"""Build a DAG and draw it in ASCII. """Build a DAG and draw it in ASCII.
Args: Args:
@ -220,7 +222,6 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) ->
str: ASCII representation str: ASCII representation
Example: Example:
>>> from dvc.dagascii import draw
>>> vertices = [1, 2, 3, 4] >>> vertices = [1, 2, 3, 4]
>>> edges = [(1, 2), (2, 3), (2, 4), (1, 4)] >>> edges = [(1, 2), (2, 3), (2, 4), (1, 4)]
>>> print(draw(vertices, edges)) >>> print(draw(vertices, edges))
@ -287,7 +288,7 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) ->
assert end_x >= 0 assert end_x >= 0
assert end_y >= 0 assert end_y >= 0
canvas.line(start_x, start_y, end_x, end_y, "*") canvas.line(start_x, start_y, end_x, end_y, "." if edge.data else "*")
for vertex in sug.g.sV: for vertex in sug.g.sV:
# NOTE: moving boxes w/2 to the left # NOTE: moving boxes w/2 to the left

View File

@ -76,7 +76,13 @@ def draw_mermaid(
for i in range(0, len(words), wrap_label_n_words) for i in range(0, len(words), wrap_label_n_words)
] ]
) )
if edge.conditional:
edge_label = f" -. {edge_data} .-> "
else:
edge_label = f" -- {edge_data} --> " edge_label = f" -- {edge_data} --> "
else:
if edge.conditional:
edge_label = " -.-> "
else: else:
edge_label = " --> " edge_label = " --> "
mermaid_graph += ( mermaid_graph += (

View File

@ -52,7 +52,12 @@ class PngDrawer:
) )
def add_edge( def add_edge(
self, viz: Any, source: str, target: str, label: Optional[str] = None self,
viz: Any,
source: str,
target: str,
label: Optional[str] = None,
conditional: bool = False,
) -> None: ) -> None:
viz.add_edge( viz.add_edge(
source, source,
@ -60,6 +65,7 @@ class PngDrawer:
label=self.get_edge_label(label) if label else "", label=self.get_edge_label(label) if label else "",
fontsize=12, fontsize=12,
fontname=self.fontname, fontname=self.fontname,
style="dotted" if conditional else "solid",
) )
def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]: def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
@ -98,8 +104,8 @@ class PngDrawer:
self.add_node(viz, node) self.add_node(viz, node)
def add_edges(self, viz: Any, graph: Graph) -> None: def add_edges(self, viz: Any, graph: Graph) -> None:
for start, end, label in graph.edges: for start, end, label, cond in graph.edges:
self.add_edge(viz, start, end, label) self.add_edge(viz, start, end, label, cond)
def update_styles(self, viz: Any, graph: Graph) -> None: def update_styles(self, viz: Any, graph: Graph) -> None:
if first := graph.first_node(): if first := graph.first_node():