mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 19:11:33 +00:00
core[minor]: handle boolean data in draw_mermaid (#23135)
This change should address graph rendering issues for edges with boolean data Example from langgraph: ```python from typing import Annotated, TypedDict from langchain_core.messages import AnyMessage from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages class State(TypedDict): messages: Annotated[list[AnyMessage], add_messages] def branch(state: State) -> bool: return 1 + 1 == 3 graph_builder = StateGraph(State) graph_builder.add_node("foo", lambda state: {"messages": [("ai", "foo")]}) graph_builder.add_node("bar", lambda state: {"messages": [("ai", "bar")]}) graph_builder.add_conditional_edges( START, branch, path_map={True: "foo", False: "bar"}, then=END, ) app = graph_builder.compile() print(app.get_graph().draw_mermaid()) ``` Previous behavior: ```python AttributeError: 'bool' object has no attribute 'split' ``` Current behavior: ```python %%{init: {'flowchart': {'curve': 'linear'}}}%% graph TD; __start__[__start__]:::startclass; __end__[__end__]:::endclass; foo([foo]):::otherclass; bar([bar]):::otherclass; __start__ -. ('a',) .-> foo; foo --> __end__; __start__ -. ('b',) .-> bar; bar --> __end__; classDef startclass fill:#ffdfba; classDef endclass fill:#baffc9; classDef otherclass fill:#fad7de; ```
This commit is contained in:
parent
093ae04d58
commit
b483bf5095
@ -11,6 +11,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
Optional,
|
Optional,
|
||||||
|
Protocol,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
@ -25,6 +26,11 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.runnables.base import Runnable as RunnableType
|
from langchain_core.runnables.base import Runnable as RunnableType
|
||||||
|
|
||||||
|
|
||||||
|
class Stringifiable(Protocol):
|
||||||
|
def __str__(self) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class LabelsDict(TypedDict):
|
class LabelsDict(TypedDict):
|
||||||
"""Dictionary of labels for nodes and edges in a graph."""
|
"""Dictionary of labels for nodes and edges in a graph."""
|
||||||
|
|
||||||
@ -55,7 +61,7 @@ class Edge(NamedTuple):
|
|||||||
|
|
||||||
source: str
|
source: str
|
||||||
target: str
|
target: str
|
||||||
data: Optional[str] = None
|
data: Optional[Stringifiable] = None
|
||||||
conditional: bool = False
|
conditional: bool = False
|
||||||
|
|
||||||
|
|
||||||
@ -253,7 +259,7 @@ class Graph:
|
|||||||
self,
|
self,
|
||||||
source: Node,
|
source: Node,
|
||||||
target: Node,
|
target: Node,
|
||||||
data: Optional[str] = None,
|
data: Optional[Stringifiable] = None,
|
||||||
conditional: bool = False,
|
conditional: bool = False,
|
||||||
) -> Edge:
|
) -> Edge:
|
||||||
"""Add an edge to the graph and return it."""
|
"""Add an edge to the graph and return it."""
|
||||||
|
@ -81,7 +81,7 @@ def draw_mermaid(
|
|||||||
# Add BR every wrap_label_n_words words
|
# Add BR every wrap_label_n_words words
|
||||||
if edge.data is not None:
|
if edge.data is not None:
|
||||||
edge_data = edge.data
|
edge_data = edge.data
|
||||||
words = edge_data.split() # Split the string into words
|
words = str(edge_data).split() # Split the string into words
|
||||||
# Group words into chunks of wrap_label_n_words size
|
# Group words into chunks of wrap_label_n_words size
|
||||||
if len(words) > wrap_label_n_words:
|
if len(words) > wrap_label_n_words:
|
||||||
edge_data = "<br>".join(
|
edge_data = "<br>".join(
|
||||||
|
@ -160,8 +160,8 @@ class PngDrawer:
|
|||||||
self.add_node(viz, node)
|
self.add_node(viz, node)
|
||||||
|
|
||||||
def add_edges(self, viz: Any, graph: Graph) -> None:
|
def add_edges(self, viz: Any, graph: Graph) -> None:
|
||||||
for start, end, label, cond in graph.edges:
|
for start, end, data, cond in graph.edges:
|
||||||
self.add_edge(viz, start, end, label, cond)
|
self.add_edge(viz, start, end, str(data), cond)
|
||||||
|
|
||||||
def update_styles(self, viz: Any, graph: Graph) -> None:
|
def update_styles(self, viz: Any, graph: Graph) -> None:
|
||||||
if first := graph.first_node():
|
if first := graph.first_node():
|
||||||
|
Loading…
Reference in New Issue
Block a user