From b483bf509520fafb3b5618717de70e1edf5f2ba7 Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Tue, 18 Jun 2024 16:15:42 -0400 Subject: [PATCH] core[minor]: handle boolean data in draw_mermaid (#23135) This change should address graph rendering issues for edges with boolean data Example from langgraph: ```python from typing import Annotated, TypedDict from langchain_core.messages import AnyMessage from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages class State(TypedDict): messages: Annotated[list[AnyMessage], add_messages] def branch(state: State) -> bool: return 1 + 1 == 3 graph_builder = StateGraph(State) graph_builder.add_node("foo", lambda state: {"messages": [("ai", "foo")]}) graph_builder.add_node("bar", lambda state: {"messages": [("ai", "bar")]}) graph_builder.add_conditional_edges( START, branch, path_map={True: "foo", False: "bar"}, then=END, ) app = graph_builder.compile() print(app.get_graph().draw_mermaid()) ``` Previous behavior: ```python AttributeError: 'bool' object has no attribute 'split' ``` Current behavior: ```python %%{init: {'flowchart': {'curve': 'linear'}}}%% graph TD; __start__[__start__]:::startclass; __end__[__end__]:::endclass; foo([foo]):::otherclass; bar([bar]):::otherclass; __start__ -. ('a',) .-> foo; foo --> __end__; __start__ -. ('b',) .-> bar; bar --> __end__; classDef startclass fill:#ffdfba; classDef endclass fill:#baffc9; classDef otherclass fill:#fad7de; ``` --- libs/core/langchain_core/runnables/graph.py | 10 ++++++++-- libs/core/langchain_core/runnables/graph_mermaid.py | 2 +- libs/core/langchain_core/runnables/graph_png.py | 4 ++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 4f35205e853..bb4559f35a2 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -11,6 +11,7 @@ from typing import ( List, NamedTuple, Optional, + Protocol, Tuple, Type, TypedDict, @@ -25,6 +26,11 @@ if TYPE_CHECKING: from langchain_core.runnables.base import Runnable as RunnableType +class Stringifiable(Protocol): + def __str__(self) -> str: + ... + + class LabelsDict(TypedDict): """Dictionary of labels for nodes and edges in a graph.""" @@ -55,7 +61,7 @@ class Edge(NamedTuple): source: str target: str - data: Optional[str] = None + data: Optional[Stringifiable] = None conditional: bool = False @@ -253,7 +259,7 @@ class Graph: self, source: Node, target: Node, - data: Optional[str] = None, + data: Optional[Stringifiable] = None, conditional: bool = False, ) -> Edge: """Add an edge to the graph and return it.""" diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index b99c90f9a70..402b0ee4f31 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -81,7 +81,7 @@ def draw_mermaid( # Add BR every wrap_label_n_words words if edge.data is not None: edge_data = edge.data - words = edge_data.split() # Split the string into words + words = str(edge_data).split() # Split the string into words # Group words into chunks of wrap_label_n_words size if len(words) > wrap_label_n_words: edge_data = "
".join( diff --git a/libs/core/langchain_core/runnables/graph_png.py b/libs/core/langchain_core/runnables/graph_png.py index 75983ad279c..0fbeec7e354 100644 --- a/libs/core/langchain_core/runnables/graph_png.py +++ b/libs/core/langchain_core/runnables/graph_png.py @@ -160,8 +160,8 @@ class PngDrawer: self.add_node(viz, node) def add_edges(self, viz: Any, graph: Graph) -> None: - for start, end, label, cond in graph.edges: - self.add_edge(viz, start, end, label, cond) + for start, end, data, cond in graph.edges: + self.add_edge(viz, start, end, str(data), cond) def update_styles(self, viz: Any, graph: Graph) -> None: if first := graph.first_node():