From 3933a4abc3fa7d7c179cb88ac42dbb112fad3ba9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A1n=20Panella?= Date: Fri, 21 Mar 2025 16:25:26 -0600 Subject: [PATCH] core(mermaid): allow greater customization (#29939) Adds greater style customization by allowing a custom frontmatter config. This allows to set a `theme` and `look` or to adjust theme by setting `themeVariables` Example: ```python node_colors = NodeStyles( default="fill:#e2e2e2,line-height:1.2,stroke:#616161", first="fill:#cfeab8,fill-opacity:0", last="fill:#eac3b8", ) frontmatter_config = { "config": { "theme": "neutral", "look": "handDrawn" } } graph.get_graph().draw_mermaid_png(node_colors=node_colors, frontmatter_config=frontmatter_config) ``` ![image](https://github.com/user-attachments/assets/11b56d30-3be2-482f-8432-3ce704a09552) --------- Co-authored-by: vbarda --- libs/core/langchain_core/runnables/graph.py | 36 +++++++++++ .../langchain_core/runnables/graph_mermaid.py | 38 ++++++++++- .../runnables/__snapshots__/test_graph.ambr | 63 ++++++++++++++++--- .../tests/unit_tests/runnables/test_graph.py | 25 ++++++++ 4 files changed, 152 insertions(+), 10 deletions(-) diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 99bcae5abf3..384c19d2385 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -563,6 +563,7 @@ class Graph: curve_style: CurveStyle = CurveStyle.LINEAR, node_colors: Optional[NodeStyles] = None, wrap_label_n_words: int = 9, + frontmatter_config: Optional[dict[str, Any]] = None, ) -> str: """Draw the graph as a Mermaid syntax string. @@ -572,6 +573,22 @@ class Graph: node_colors: The colors of the nodes. Defaults to NodeStyles(). wrap_label_n_words: The number of words to wrap the node labels at. Defaults to 9. + frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config. + Can be used to customize theme and styles. Will be converted to YAML and + added to the beginning of the mermaid graph. Defaults to None. + + See more here: https://mermaid.js.org/config/configuration.html. + + Example: + ```python + { + "config": { + "theme": "neutral", + "look": "handDrawn", + "themeVariables": { "primaryColor": "#e2e2e2"}, + } + } + ``` Returns: The Mermaid syntax string. @@ -591,6 +608,7 @@ class Graph: curve_style=curve_style, node_styles=node_colors, wrap_label_n_words=wrap_label_n_words, + frontmatter_config=frontmatter_config, ) def draw_mermaid_png( @@ -603,6 +621,7 @@ class Graph: draw_method: MermaidDrawMethod = MermaidDrawMethod.API, background_color: str = "white", padding: int = 10, + frontmatter_config: Optional[dict[str, Any]] = None, ) -> bytes: """Draw the graph as a PNG image using Mermaid. @@ -617,6 +636,22 @@ class Graph: Defaults to MermaidDrawMethod.API. background_color: The color of the background. Defaults to "white". padding: The padding around the graph. Defaults to 10. + frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config. + Can be used to customize theme and styles. Will be converted to YAML and + added to the beginning of the mermaid graph. Defaults to None. + + See more here: https://mermaid.js.org/config/configuration.html. + + Example: + ```python + { + "config": { + "theme": "neutral", + "look": "handDrawn", + "themeVariables": { "primaryColor": "#e2e2e2"}, + } + } + ``` Returns: The PNG image as bytes. @@ -627,6 +662,7 @@ class Graph: curve_style=curve_style, node_colors=node_colors, wrap_label_n_words=wrap_label_n_words, + frontmatter_config=frontmatter_config, ) return draw_mermaid_png( mermaid_syntax=mermaid_syntax, diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index af4f806c1de..8c2ff5faf47 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -3,7 +3,9 @@ import base64 import re from dataclasses import asdict from pathlib import Path -from typing import Literal, Optional +from typing import Any, Literal, Optional + +import yaml from langchain_core.runnables.graph import ( CurveStyle, @@ -26,6 +28,7 @@ def draw_mermaid( curve_style: CurveStyle = CurveStyle.LINEAR, node_styles: Optional[NodeStyles] = None, wrap_label_n_words: int = 9, + frontmatter_config: Optional[dict[str, Any]] = None, ) -> str: """Draws a Mermaid graph using the provided graph data. @@ -43,15 +46,44 @@ def draw_mermaid( Defaults to NodeStyles(). wrap_label_n_words (int, optional): Words to wrap the edge labels. Defaults to 9. + frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config. + Can be used to customize theme and styles. Will be converted to YAML and + added to the beginning of the mermaid graph. Defaults to None. + + See more here: https://mermaid.js.org/config/configuration.html. + + Example: + ```python + { + "config": { + "theme": "neutral", + "look": "handDrawn", + "themeVariables": { "primaryColor": "#e2e2e2"}, + } + } + ``` Returns: str: Mermaid graph syntax. """ # Initialize Mermaid graph configuration + original_frontmatter_config = frontmatter_config or {} + original_flowchart_config = original_frontmatter_config.get("config", {}).get( + "flowchart", {} + ) + frontmatter_config = { + **original_frontmatter_config, + "config": { + **original_frontmatter_config.get("config", {}), + "flowchart": {**original_flowchart_config, "curve": curve_style.value}, + }, + } + mermaid_graph = ( ( - f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'" - f"}}}}}}%%\ngraph TD;\n" + "---\n" + + yaml.dump(frontmatter_config, default_flow_style=False) + + "---\ngraph TD;\n" ) if with_styles else "graph TD;\n" diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index 2e4a19ce5c2..245b1a5b871 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -1,7 +1,11 @@ # serializer version: 1 # name: test_double_nested_subgraph_mermaid[mermaid] ''' - %%{init: {'flowchart': {'curve': 'linear'}}}%% + --- + config: + flowchart: + curve: linear + --- graph TD; __start__([

__start__

]):::first parent_1(parent_1) @@ -28,7 +32,11 @@ # --- # name: test_triple_nested_subgraph_mermaid[mermaid] ''' - %%{init: {'flowchart': {'curve': 'linear'}}}%% + --- + config: + flowchart: + curve: linear + --- graph TD; __start__([

__start__

]):::first parent_1(parent_1) @@ -71,6 +79,27 @@ ''' # --- +# name: test_graph_mermaid_frontmatter_config[mermaid] + ''' + --- + config: + flowchart: + curve: linear + look: handDrawn + theme: neutral + themeVariables: + primaryColor: '#e2e2e2' + --- + graph TD; + __start__([

__start__

]):::first + my_node([my_node]):::last + __start__ --> my_node; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_graph_sequence[ascii] ''' +-------------+ @@ -104,7 +133,11 @@ # --- # name: test_graph_sequence[mermaid] ''' - %%{init: {'flowchart': {'curve': 'linear'}}}%% + --- + config: + flowchart: + curve: linear + --- graph TD; PromptInput([PromptInput]):::first PromptTemplate(PromptTemplate) @@ -1927,7 +1960,11 @@ # --- # name: test_graph_sequence_map[mermaid] ''' - %%{init: {'flowchart': {'curve': 'linear'}}}%% + --- + config: + flowchart: + curve: linear + --- graph TD; PromptInput([PromptInput]):::first PromptTemplate(PromptTemplate) @@ -1977,7 +2014,11 @@ # --- # name: test_graph_single_runnable[mermaid] ''' - %%{init: {'flowchart': {'curve': 'linear'}}}%% + --- + config: + flowchart: + curve: linear + --- graph TD; StrOutputParserInput([StrOutputParserInput]):::first StrOutputParser(StrOutputParser) @@ -1992,7 +2033,11 @@ # --- # name: test_parallel_subgraph_mermaid[mermaid] ''' - %%{init: {'flowchart': {'curve': 'linear'}}}%% + --- + config: + flowchart: + curve: linear + --- graph TD; __start__([

__start__

]):::first outer_1(outer_1) @@ -2022,7 +2067,11 @@ # --- # name: test_single_node_subgraph_mermaid[mermaid] ''' - %%{init: {'flowchart': {'curve': 'linear'}}}%% + --- + config: + flowchart: + curve: linear + --- graph TD; __start__([

__start__

]):::first __end__([

__end__

]):::last diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 6f822c1e7c2..870c4c76e90 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -535,3 +535,28 @@ def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None: ) graph = sequence.get_graph() assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid") + + +def test_graph_mermaid_frontmatter_config(snapshot: SnapshotAssertion) -> None: + graph = Graph( + nodes={ + "__start__": Node( + id="__start__", name="__start__", data=BaseModel, metadata=None + ), + "my_node": Node( + id="my_node", name="my_node", data=BaseModel, metadata=None + ), + }, + edges=[ + Edge(source="__start__", target="my_node", data=None, conditional=False) + ], + ) + assert graph.draw_mermaid( + frontmatter_config={ + "config": { + "theme": "neutral", + "look": "handDrawn", + "themeVariables": {"primaryColor": "#e2e2e2"}, + } + } + ) == snapshot(name="mermaid")