diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 4f35205e853..bb4559f35a2 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -11,6 +11,7 @@ from typing import ( List, NamedTuple, Optional, + Protocol, Tuple, Type, TypedDict, @@ -25,6 +26,11 @@ if TYPE_CHECKING: from langchain_core.runnables.base import Runnable as RunnableType +class Stringifiable(Protocol): + def __str__(self) -> str: + ... + + class LabelsDict(TypedDict): """Dictionary of labels for nodes and edges in a graph.""" @@ -55,7 +61,7 @@ class Edge(NamedTuple): source: str target: str - data: Optional[str] = None + data: Optional[Stringifiable] = None conditional: bool = False @@ -253,7 +259,7 @@ class Graph: self, source: Node, target: Node, - data: Optional[str] = None, + data: Optional[Stringifiable] = None, conditional: bool = False, ) -> Edge: """Add an edge to the graph and return it.""" diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index b99c90f9a70..402b0ee4f31 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -81,7 +81,7 @@ def draw_mermaid( # Add BR every wrap_label_n_words words if edge.data is not None: edge_data = edge.data - words = edge_data.split() # Split the string into words + words = str(edge_data).split() # Split the string into words # Group words into chunks of wrap_label_n_words size if len(words) > wrap_label_n_words: edge_data = "
".join( diff --git a/libs/core/langchain_core/runnables/graph_png.py b/libs/core/langchain_core/runnables/graph_png.py index 75983ad279c..0fbeec7e354 100644 --- a/libs/core/langchain_core/runnables/graph_png.py +++ b/libs/core/langchain_core/runnables/graph_png.py @@ -160,8 +160,8 @@ class PngDrawer: self.add_node(viz, node) def add_edges(self, viz: Any, graph: Graph) -> None: - for start, end, label, cond in graph.edges: - self.add_edge(viz, start, end, label, cond) + for start, end, data, cond in graph.edges: + self.add_edge(viz, start, end, str(data), cond) def update_styles(self, viz: Any, graph: Graph) -> None: if first := graph.first_node():