mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-21 12:01:47 +00:00
core: Add concept of conditional edge to graph rendering (#20480)
- implement for mermaid, graphviz and ascii - this is to be used in langgraph
This commit is contained in:
parent
30b00090ef
commit
97b2191e99
@ -19,7 +19,6 @@ from typing import (
|
|||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.runnables.graph_ascii import draw_ascii
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.runnables.base import Runnable as RunnableType
|
from langchain_core.runnables.base import Runnable as RunnableType
|
||||||
@ -44,6 +43,7 @@ class Edge(NamedTuple):
|
|||||||
source: str
|
source: str
|
||||||
target: str
|
target: str
|
||||||
data: Optional[str] = None
|
data: Optional[str] = None
|
||||||
|
conditional: bool = False
|
||||||
|
|
||||||
|
|
||||||
class Node(NamedTuple):
|
class Node(NamedTuple):
|
||||||
@ -219,13 +219,21 @@ class Graph:
|
|||||||
if edge.source != node.id and edge.target != node.id
|
if edge.source != node.id and edge.target != node.id
|
||||||
]
|
]
|
||||||
|
|
||||||
def add_edge(self, source: Node, target: Node, data: Optional[str] = None) -> Edge:
|
def add_edge(
|
||||||
|
self,
|
||||||
|
source: Node,
|
||||||
|
target: Node,
|
||||||
|
data: Optional[str] = None,
|
||||||
|
conditional: bool = False,
|
||||||
|
) -> Edge:
|
||||||
"""Add an edge to the graph and return it."""
|
"""Add an edge to the graph and return it."""
|
||||||
if source.id not in self.nodes:
|
if source.id not in self.nodes:
|
||||||
raise ValueError(f"Source node {source.id} not in graph")
|
raise ValueError(f"Source node {source.id} not in graph")
|
||||||
if target.id not in self.nodes:
|
if target.id not in self.nodes:
|
||||||
raise ValueError(f"Target node {target.id} not in graph")
|
raise ValueError(f"Target node {target.id} not in graph")
|
||||||
edge = Edge(source=source.id, target=target.id, data=data)
|
edge = Edge(
|
||||||
|
source=source.id, target=target.id, data=data, conditional=conditional
|
||||||
|
)
|
||||||
self.edges.append(edge)
|
self.edges.append(edge)
|
||||||
return edge
|
return edge
|
||||||
|
|
||||||
@ -283,9 +291,11 @@ class Graph:
|
|||||||
self.remove_node(last_node)
|
self.remove_node(last_node)
|
||||||
|
|
||||||
def draw_ascii(self) -> str:
|
def draw_ascii(self) -> str:
|
||||||
|
from langchain_core.runnables.graph_ascii import draw_ascii
|
||||||
|
|
||||||
return draw_ascii(
|
return draw_ascii(
|
||||||
{node.id: node_data_str(node) for node in self.nodes.values()},
|
{node.id: node_data_str(node) for node in self.nodes.values()},
|
||||||
[(edge.source, edge.target) for edge in self.edges],
|
self.edges,
|
||||||
)
|
)
|
||||||
|
|
||||||
def print_ascii(self) -> None:
|
def print_ascii(self) -> None:
|
||||||
|
@ -3,7 +3,9 @@ Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py"""
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing import Any, Mapping, Sequence, Tuple
|
from typing import Any, Mapping, Sequence
|
||||||
|
|
||||||
|
from langchain_core.runnables.graph import Edge as LangEdge
|
||||||
|
|
||||||
|
|
||||||
class VertexViewer:
|
class VertexViewer:
|
||||||
@ -156,7 +158,7 @@ class AsciiCanvas:
|
|||||||
|
|
||||||
|
|
||||||
def _build_sugiyama_layout(
|
def _build_sugiyama_layout(
|
||||||
vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]
|
vertices: Mapping[str, str], edges: Sequence[LangEdge]
|
||||||
) -> Any:
|
) -> Any:
|
||||||
try:
|
try:
|
||||||
from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import]
|
from grandalf.graphs import Edge, Graph, Vertex # type: ignore[import]
|
||||||
@ -181,7 +183,7 @@ def _build_sugiyama_layout(
|
|||||||
#
|
#
|
||||||
|
|
||||||
vertices_ = {id: Vertex(f" {data} ") for id, data in vertices.items()}
|
vertices_ = {id: Vertex(f" {data} ") for id, data in vertices.items()}
|
||||||
edges_ = [Edge(vertices_[s], vertices_[e]) for s, e in edges]
|
edges_ = [Edge(vertices_[s], vertices_[e], data=cond) for s, e, _, cond in edges]
|
||||||
vertices_list = vertices_.values()
|
vertices_list = vertices_.values()
|
||||||
graph = Graph(vertices_list, edges_)
|
graph = Graph(vertices_list, edges_)
|
||||||
|
|
||||||
@ -209,7 +211,7 @@ def _build_sugiyama_layout(
|
|||||||
return sug
|
return sug
|
||||||
|
|
||||||
|
|
||||||
def draw_ascii(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) -> str:
|
def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
|
||||||
"""Build a DAG and draw it in ASCII.
|
"""Build a DAG and draw it in ASCII.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -220,7 +222,6 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) ->
|
|||||||
str: ASCII representation
|
str: ASCII representation
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from dvc.dagascii import draw
|
|
||||||
>>> vertices = [1, 2, 3, 4]
|
>>> vertices = [1, 2, 3, 4]
|
||||||
>>> edges = [(1, 2), (2, 3), (2, 4), (1, 4)]
|
>>> edges = [(1, 2), (2, 3), (2, 4), (1, 4)]
|
||||||
>>> print(draw(vertices, edges))
|
>>> print(draw(vertices, edges))
|
||||||
@ -287,7 +288,7 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[Tuple[str, str]]) ->
|
|||||||
assert end_x >= 0
|
assert end_x >= 0
|
||||||
assert end_y >= 0
|
assert end_y >= 0
|
||||||
|
|
||||||
canvas.line(start_x, start_y, end_x, end_y, "*")
|
canvas.line(start_x, start_y, end_x, end_y, "." if edge.data else "*")
|
||||||
|
|
||||||
for vertex in sug.g.sV:
|
for vertex in sug.g.sV:
|
||||||
# NOTE: moving boxes w/2 to the left
|
# NOTE: moving boxes w/2 to the left
|
||||||
|
@ -76,7 +76,13 @@ def draw_mermaid(
|
|||||||
for i in range(0, len(words), wrap_label_n_words)
|
for i in range(0, len(words), wrap_label_n_words)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
if edge.conditional:
|
||||||
|
edge_label = f" -. {edge_data} .-> "
|
||||||
|
else:
|
||||||
edge_label = f" -- {edge_data} --> "
|
edge_label = f" -- {edge_data} --> "
|
||||||
|
else:
|
||||||
|
if edge.conditional:
|
||||||
|
edge_label = " -.-> "
|
||||||
else:
|
else:
|
||||||
edge_label = " --> "
|
edge_label = " --> "
|
||||||
mermaid_graph += (
|
mermaid_graph += (
|
||||||
|
@ -52,7 +52,12 @@ class PngDrawer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def add_edge(
|
def add_edge(
|
||||||
self, viz: Any, source: str, target: str, label: Optional[str] = None
|
self,
|
||||||
|
viz: Any,
|
||||||
|
source: str,
|
||||||
|
target: str,
|
||||||
|
label: Optional[str] = None,
|
||||||
|
conditional: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
viz.add_edge(
|
viz.add_edge(
|
||||||
source,
|
source,
|
||||||
@ -60,6 +65,7 @@ class PngDrawer:
|
|||||||
label=self.get_edge_label(label) if label else "",
|
label=self.get_edge_label(label) if label else "",
|
||||||
fontsize=12,
|
fontsize=12,
|
||||||
fontname=self.fontname,
|
fontname=self.fontname,
|
||||||
|
style="dotted" if conditional else "solid",
|
||||||
)
|
)
|
||||||
|
|
||||||
def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
|
def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
|
||||||
@ -98,8 +104,8 @@ class PngDrawer:
|
|||||||
self.add_node(viz, node)
|
self.add_node(viz, node)
|
||||||
|
|
||||||
def add_edges(self, viz: Any, graph: Graph) -> None:
|
def add_edges(self, viz: Any, graph: Graph) -> None:
|
||||||
for start, end, label in graph.edges:
|
for start, end, label, cond in graph.edges:
|
||||||
self.add_edge(viz, start, end, label)
|
self.add_edge(viz, start, end, label, cond)
|
||||||
|
|
||||||
def update_styles(self, viz: Any, graph: Graph) -> None:
|
def update_styles(self, viz: Any, graph: Graph) -> None:
|
||||||
if first := graph.first_node():
|
if first := graph.first_node():
|
||||||
|
Loading…
Reference in New Issue
Block a user