core(mermaid): allow greater customization (#29939)

Adds greater style customization by allowing a custom frontmatter
config. This allows to set a `theme` and `look` or to adjust theme by
setting `themeVariables`

Example:

```python

node_colors = NodeStyles(
    default="fill:#e2e2e2,line-height:1.2,stroke:#616161",
    first="fill:#cfeab8,fill-opacity:0",
    last="fill:#eac3b8",
)

frontmatter_config = {
    "config": {
        "theme": "neutral",
        "look": "handDrawn"
    }
}

graph.get_graph().draw_mermaid_png(node_colors=node_colors, frontmatter_config=frontmatter_config)
```


![image](https://github.com/user-attachments/assets/11b56d30-3be2-482f-8432-3ce704a09552)

---------

Co-authored-by: vbarda <vadym@langchain.dev>
This commit is contained in:
Adrián Panella 2025-03-21 16:25:26 -06:00 committed by GitHub
parent 07823cd41c
commit 3933a4abc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 152 additions and 10 deletions

View File

@ -563,6 +563,7 @@ class Graph:
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: Optional[NodeStyles] = None,
wrap_label_n_words: int = 9,
frontmatter_config: Optional[dict[str, Any]] = None,
) -> str:
"""Draw the graph as a Mermaid syntax string.
@ -572,6 +573,22 @@ class Graph:
node_colors: The colors of the nodes. Defaults to NodeStyles().
wrap_label_n_words: The number of words to wrap the node labels at.
Defaults to 9.
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
Can be used to customize theme and styles. Will be converted to YAML and
added to the beginning of the mermaid graph. Defaults to None.
See more here: https://mermaid.js.org/config/configuration.html.
Example:
```python
{
"config": {
"theme": "neutral",
"look": "handDrawn",
"themeVariables": { "primaryColor": "#e2e2e2"},
}
}
```
Returns:
The Mermaid syntax string.
@ -591,6 +608,7 @@ class Graph:
curve_style=curve_style,
node_styles=node_colors,
wrap_label_n_words=wrap_label_n_words,
frontmatter_config=frontmatter_config,
)
def draw_mermaid_png(
@ -603,6 +621,7 @@ class Graph:
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
background_color: str = "white",
padding: int = 10,
frontmatter_config: Optional[dict[str, Any]] = None,
) -> bytes:
"""Draw the graph as a PNG image using Mermaid.
@ -617,6 +636,22 @@ class Graph:
Defaults to MermaidDrawMethod.API.
background_color: The color of the background. Defaults to "white".
padding: The padding around the graph. Defaults to 10.
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
Can be used to customize theme and styles. Will be converted to YAML and
added to the beginning of the mermaid graph. Defaults to None.
See more here: https://mermaid.js.org/config/configuration.html.
Example:
```python
{
"config": {
"theme": "neutral",
"look": "handDrawn",
"themeVariables": { "primaryColor": "#e2e2e2"},
}
}
```
Returns:
The PNG image as bytes.
@ -627,6 +662,7 @@ class Graph:
curve_style=curve_style,
node_colors=node_colors,
wrap_label_n_words=wrap_label_n_words,
frontmatter_config=frontmatter_config,
)
return draw_mermaid_png(
mermaid_syntax=mermaid_syntax,

View File

@ -3,7 +3,9 @@ import base64
import re
from dataclasses import asdict
from pathlib import Path
from typing import Literal, Optional
from typing import Any, Literal, Optional
import yaml
from langchain_core.runnables.graph import (
CurveStyle,
@ -26,6 +28,7 @@ def draw_mermaid(
curve_style: CurveStyle = CurveStyle.LINEAR,
node_styles: Optional[NodeStyles] = None,
wrap_label_n_words: int = 9,
frontmatter_config: Optional[dict[str, Any]] = None,
) -> str:
"""Draws a Mermaid graph using the provided graph data.
@ -43,15 +46,44 @@ def draw_mermaid(
Defaults to NodeStyles().
wrap_label_n_words (int, optional): Words to wrap the edge labels.
Defaults to 9.
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
Can be used to customize theme and styles. Will be converted to YAML and
added to the beginning of the mermaid graph. Defaults to None.
See more here: https://mermaid.js.org/config/configuration.html.
Example:
```python
{
"config": {
"theme": "neutral",
"look": "handDrawn",
"themeVariables": { "primaryColor": "#e2e2e2"},
}
}
```
Returns:
str: Mermaid graph syntax.
"""
# Initialize Mermaid graph configuration
original_frontmatter_config = frontmatter_config or {}
original_flowchart_config = original_frontmatter_config.get("config", {}).get(
"flowchart", {}
)
frontmatter_config = {
**original_frontmatter_config,
"config": {
**original_frontmatter_config.get("config", {}),
"flowchart": {**original_flowchart_config, "curve": curve_style.value},
},
}
mermaid_graph = (
(
f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
f"}}}}}}%%\ngraph TD;\n"
"---\n"
+ yaml.dump(frontmatter_config, default_flow_style=False)
+ "---\ngraph TD;\n"
)
if with_styles
else "graph TD;\n"

View File

@ -1,7 +1,11 @@
# serializer version: 1
# name: test_double_nested_subgraph_mermaid[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
parent_1(parent_1)
@ -28,7 +32,11 @@
# ---
# name: test_triple_nested_subgraph_mermaid[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
parent_1(parent_1)
@ -71,6 +79,27 @@
'''
# ---
# name: test_graph_mermaid_frontmatter_config[mermaid]
'''
---
config:
flowchart:
curve: linear
look: handDrawn
theme: neutral
themeVariables:
primaryColor: '#e2e2e2'
---
graph TD;
__start__([<p>__start__</p>]):::first
my_node([my_node]):::last
__start__ --> my_node;
classDef default fill:#f2f0ff,line-height:1.2
classDef first fill-opacity:0
classDef last fill:#bfb6fc
'''
# ---
# name: test_graph_sequence[ascii]
'''
+-------------+
@ -104,7 +133,11 @@
# ---
# name: test_graph_sequence[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
---
config:
flowchart:
curve: linear
---
graph TD;
PromptInput([PromptInput]):::first
PromptTemplate(PromptTemplate)
@ -1927,7 +1960,11 @@
# ---
# name: test_graph_sequence_map[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
---
config:
flowchart:
curve: linear
---
graph TD;
PromptInput([PromptInput]):::first
PromptTemplate(PromptTemplate)
@ -1977,7 +2014,11 @@
# ---
# name: test_graph_single_runnable[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
---
config:
flowchart:
curve: linear
---
graph TD;
StrOutputParserInput([StrOutputParserInput]):::first
StrOutputParser(StrOutputParser)
@ -1992,7 +2033,11 @@
# ---
# name: test_parallel_subgraph_mermaid[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
outer_1(outer_1)
@ -2022,7 +2067,11 @@
# ---
# name: test_single_node_subgraph_mermaid[mermaid]
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
---
config:
flowchart:
curve: linear
---
graph TD;
__start__([<p>__start__</p>]):::first
__end__([<p>__end__</p>]):::last

View File

@ -535,3 +535,28 @@ def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None:
)
graph = sequence.get_graph()
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid")
def test_graph_mermaid_frontmatter_config(snapshot: SnapshotAssertion) -> None:
graph = Graph(
nodes={
"__start__": Node(
id="__start__", name="__start__", data=BaseModel, metadata=None
),
"my_node": Node(
id="my_node", name="my_node", data=BaseModel, metadata=None
),
},
edges=[
Edge(source="__start__", target="my_node", data=None, conditional=False)
],
)
assert graph.draw_mermaid(
frontmatter_config={
"config": {
"theme": "neutral",
"look": "handDrawn",
"themeVariables": {"primaryColor": "#e2e2e2"},
}
}
) == snapshot(name="mermaid")