core(mermaid): allow greater customization (#29939)

Adds greater style customization by allowing a custom frontmatter
config. This allows to set a `theme` and `look` or to adjust theme by
setting `themeVariables`

Example:

```python

node_colors = NodeStyles(
    default="fill:#e2e2e2,line-height:1.2,stroke:#616161",
    first="fill:#cfeab8,fill-opacity:0",
    last="fill:#eac3b8",
)

frontmatter_config = {
    "config": {
        "theme": "neutral",
        "look": "handDrawn"
    }
}

graph.get_graph().draw_mermaid_png(node_colors=node_colors, frontmatter_config=frontmatter_config)
```


![image](https://github.com/user-attachments/assets/11b56d30-3be2-482f-8432-3ce704a09552)

---------

Co-authored-by: vbarda <vadym@langchain.dev>
This commit is contained in:
Adrián Panella 2025-03-21 16:25:26 -06:00 committed by GitHub
parent 07823cd41c
commit 3933a4abc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 152 additions and 10 deletions

View File

@ -563,6 +563,7 @@ class Graph:
curve_style: CurveStyle = CurveStyle.LINEAR, curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: Optional[NodeStyles] = None, node_colors: Optional[NodeStyles] = None,
wrap_label_n_words: int = 9, wrap_label_n_words: int = 9,
frontmatter_config: Optional[dict[str, Any]] = None,
) -> str: ) -> str:
"""Draw the graph as a Mermaid syntax string. """Draw the graph as a Mermaid syntax string.
@ -572,6 +573,22 @@ class Graph:
node_colors: The colors of the nodes. Defaults to NodeStyles(). node_colors: The colors of the nodes. Defaults to NodeStyles().
wrap_label_n_words: The number of words to wrap the node labels at. wrap_label_n_words: The number of words to wrap the node labels at.
Defaults to 9. Defaults to 9.
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
Can be used to customize theme and styles. Will be converted to YAML and
added to the beginning of the mermaid graph. Defaults to None.
See more here: https://mermaid.js.org/config/configuration.html.
Example:
```python
{
"config": {
"theme": "neutral",
"look": "handDrawn",
"themeVariables": { "primaryColor": "#e2e2e2"},
}
}
```
Returns: Returns:
The Mermaid syntax string. The Mermaid syntax string.
@ -591,6 +608,7 @@ class Graph:
curve_style=curve_style, curve_style=curve_style,
node_styles=node_colors, node_styles=node_colors,
wrap_label_n_words=wrap_label_n_words, wrap_label_n_words=wrap_label_n_words,
frontmatter_config=frontmatter_config,
) )
def draw_mermaid_png( def draw_mermaid_png(
@ -603,6 +621,7 @@ class Graph:
draw_method: MermaidDrawMethod = MermaidDrawMethod.API, draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
background_color: str = "white", background_color: str = "white",
padding: int = 10, padding: int = 10,
frontmatter_config: Optional[dict[str, Any]] = None,
) -> bytes: ) -> bytes:
"""Draw the graph as a PNG image using Mermaid. """Draw the graph as a PNG image using Mermaid.
@ -617,6 +636,22 @@ class Graph:
Defaults to MermaidDrawMethod.API. Defaults to MermaidDrawMethod.API.
background_color: The color of the background. Defaults to "white". background_color: The color of the background. Defaults to "white".
padding: The padding around the graph. Defaults to 10. padding: The padding around the graph. Defaults to 10.
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
Can be used to customize theme and styles. Will be converted to YAML and
added to the beginning of the mermaid graph. Defaults to None.
See more here: https://mermaid.js.org/config/configuration.html.
Example:
```python
{
"config": {
"theme": "neutral",
"look": "handDrawn",
"themeVariables": { "primaryColor": "#e2e2e2"},
}
}
```
Returns: Returns:
The PNG image as bytes. The PNG image as bytes.
@ -627,6 +662,7 @@ class Graph:
curve_style=curve_style, curve_style=curve_style,
node_colors=node_colors, node_colors=node_colors,
wrap_label_n_words=wrap_label_n_words, wrap_label_n_words=wrap_label_n_words,
frontmatter_config=frontmatter_config,
) )
return draw_mermaid_png( return draw_mermaid_png(
mermaid_syntax=mermaid_syntax, mermaid_syntax=mermaid_syntax,

View File

@ -3,7 +3,9 @@ import base64
import re import re
from dataclasses import asdict from dataclasses import asdict
from pathlib import Path from pathlib import Path
from typing import Literal, Optional from typing import Any, Literal, Optional
import yaml
from langchain_core.runnables.graph import ( from langchain_core.runnables.graph import (
CurveStyle, CurveStyle,
@ -26,6 +28,7 @@ def draw_mermaid(
curve_style: CurveStyle = CurveStyle.LINEAR, curve_style: CurveStyle = CurveStyle.LINEAR,
node_styles: Optional[NodeStyles] = None, node_styles: Optional[NodeStyles] = None,
wrap_label_n_words: int = 9, wrap_label_n_words: int = 9,
frontmatter_config: Optional[dict[str, Any]] = None,
) -> str: ) -> str:
"""Draws a Mermaid graph using the provided graph data. """Draws a Mermaid graph using the provided graph data.
@ -43,15 +46,44 @@ def draw_mermaid(
Defaults to NodeStyles(). Defaults to NodeStyles().
wrap_label_n_words (int, optional): Words to wrap the edge labels. wrap_label_n_words (int, optional): Words to wrap the edge labels.
Defaults to 9. Defaults to 9.
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
Can be used to customize theme and styles. Will be converted to YAML and
added to the beginning of the mermaid graph. Defaults to None.
See more here: https://mermaid.js.org/config/configuration.html.
Example:
```python
{
"config": {
"theme": "neutral",
"look": "handDrawn",
"themeVariables": { "primaryColor": "#e2e2e2"},
}
}
```
Returns: Returns:
str: Mermaid graph syntax. str: Mermaid graph syntax.
""" """
# Initialize Mermaid graph configuration # Initialize Mermaid graph configuration
original_frontmatter_config = frontmatter_config or {}
original_flowchart_config = original_frontmatter_config.get("config", {}).get(
"flowchart", {}
)
frontmatter_config = {
**original_frontmatter_config,
"config": {
**original_frontmatter_config.get("config", {}),
"flowchart": {**original_flowchart_config, "curve": curve_style.value},
},
}
mermaid_graph = ( mermaid_graph = (
( (
f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'" "---\n"
f"}}}}}}%%\ngraph TD;\n" + yaml.dump(frontmatter_config, default_flow_style=False)
+ "---\ngraph TD;\n"
) )
if with_styles if with_styles
else "graph TD;\n" else "graph TD;\n"

View File

@ -1,7 +1,11 @@
# serializer version: 1 # serializer version: 1
# name: test_double_nested_subgraph_mermaid[mermaid] # name: test_double_nested_subgraph_mermaid[mermaid]
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% ---
config:
flowchart:
curve: linear
---
graph TD; graph TD;
__start__([<p>__start__</p>]):::first __start__([<p>__start__</p>]):::first
parent_1(parent_1) parent_1(parent_1)
@ -28,7 +32,11 @@
# --- # ---
# name: test_triple_nested_subgraph_mermaid[mermaid] # name: test_triple_nested_subgraph_mermaid[mermaid]
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% ---
config:
flowchart:
curve: linear
---
graph TD; graph TD;
__start__([<p>__start__</p>]):::first __start__([<p>__start__</p>]):::first
parent_1(parent_1) parent_1(parent_1)
@ -71,6 +79,27 @@
''' '''
# --- # ---
# name: test_graph_mermaid_frontmatter_config[mermaid]
'''
---
config:
flowchart:
curve: linear
look: handDrawn
theme: neutral
themeVariables:
primaryColor: '#e2e2e2'
---
graph TD;
__start__([<p>__start__</p>]):::first
my_node([my_node]):::last
__start__ --> my_node;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_graph_sequence[ascii] # name: test_graph_sequence[ascii]
''' '''
+-------------+ +-------------+
@ -104,7 +133,11 @@
# --- # ---
# name: test_graph_sequence[mermaid] # name: test_graph_sequence[mermaid]
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% ---
config:
flowchart:
curve: linear
---
graph TD; graph TD;
PromptInput([PromptInput]):::first PromptInput([PromptInput]):::first
PromptTemplate(PromptTemplate) PromptTemplate(PromptTemplate)
@ -1927,7 +1960,11 @@
# --- # ---
# name: test_graph_sequence_map[mermaid] # name: test_graph_sequence_map[mermaid]
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% ---
config:
flowchart:
curve: linear
---
graph TD; graph TD;
PromptInput([PromptInput]):::first PromptInput([PromptInput]):::first
PromptTemplate(PromptTemplate) PromptTemplate(PromptTemplate)
@ -1977,7 +2014,11 @@
# --- # ---
# name: test_graph_single_runnable[mermaid] # name: test_graph_single_runnable[mermaid]
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% ---
config:
flowchart:
curve: linear
---
graph TD; graph TD;
StrOutputParserInput([StrOutputParserInput]):::first StrOutputParserInput([StrOutputParserInput]):::first
StrOutputParser(StrOutputParser) StrOutputParser(StrOutputParser)
@ -1992,7 +2033,11 @@
# --- # ---
# name: test_parallel_subgraph_mermaid[mermaid] # name: test_parallel_subgraph_mermaid[mermaid]
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% ---
config:
flowchart:
curve: linear
---
graph TD; graph TD;
__start__([<p>__start__</p>]):::first __start__([<p>__start__</p>]):::first
outer_1(outer_1) outer_1(outer_1)
@ -2022,7 +2067,11 @@
# --- # ---
# name: test_single_node_subgraph_mermaid[mermaid] # name: test_single_node_subgraph_mermaid[mermaid]
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% ---
config:
flowchart:
curve: linear
---
graph TD; graph TD;
__start__([<p>__start__</p>]):::first __start__([<p>__start__</p>]):::first
__end__([<p>__end__</p>]):::last __end__([<p>__end__</p>]):::last

View File

@ -535,3 +535,28 @@ def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None:
) )
graph = sequence.get_graph() graph = sequence.get_graph()
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid") assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid")
def test_graph_mermaid_frontmatter_config(snapshot: SnapshotAssertion) -> None:
graph = Graph(
nodes={
"__start__": Node(
id="__start__", name="__start__", data=BaseModel, metadata=None
),
"my_node": Node(
id="my_node", name="my_node", data=BaseModel, metadata=None
),
},
edges=[
Edge(source="__start__", target="my_node", data=None, conditional=False)
],
)
assert graph.draw_mermaid(
frontmatter_config={
"config": {
"theme": "neutral",
"look": "handDrawn",
"themeVariables": {"primaryColor": "#e2e2e2"},
}
}
) == snapshot(name="mermaid")