mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 17:38:36 +00:00
core(mermaid): allow greater customization (#29939)
Adds greater style customization by allowing a custom frontmatter config. This allows to set a `theme` and `look` or to adjust theme by setting `themeVariables` Example: ```python node_colors = NodeStyles( default="fill:#e2e2e2,line-height:1.2,stroke:#616161", first="fill:#cfeab8,fill-opacity:0", last="fill:#eac3b8", ) frontmatter_config = { "config": { "theme": "neutral", "look": "handDrawn" } } graph.get_graph().draw_mermaid_png(node_colors=node_colors, frontmatter_config=frontmatter_config) ```  --------- Co-authored-by: vbarda <vadym@langchain.dev>
This commit is contained in:
parent
07823cd41c
commit
3933a4abc3
@ -563,6 +563,7 @@ class Graph:
|
|||||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||||
node_colors: Optional[NodeStyles] = None,
|
node_colors: Optional[NodeStyles] = None,
|
||||||
wrap_label_n_words: int = 9,
|
wrap_label_n_words: int = 9,
|
||||||
|
frontmatter_config: Optional[dict[str, Any]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Draw the graph as a Mermaid syntax string.
|
"""Draw the graph as a Mermaid syntax string.
|
||||||
|
|
||||||
@ -572,6 +573,22 @@ class Graph:
|
|||||||
node_colors: The colors of the nodes. Defaults to NodeStyles().
|
node_colors: The colors of the nodes. Defaults to NodeStyles().
|
||||||
wrap_label_n_words: The number of words to wrap the node labels at.
|
wrap_label_n_words: The number of words to wrap the node labels at.
|
||||||
Defaults to 9.
|
Defaults to 9.
|
||||||
|
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
|
||||||
|
Can be used to customize theme and styles. Will be converted to YAML and
|
||||||
|
added to the beginning of the mermaid graph. Defaults to None.
|
||||||
|
|
||||||
|
See more here: https://mermaid.js.org/config/configuration.html.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"config": {
|
||||||
|
"theme": "neutral",
|
||||||
|
"look": "handDrawn",
|
||||||
|
"themeVariables": { "primaryColor": "#e2e2e2"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The Mermaid syntax string.
|
The Mermaid syntax string.
|
||||||
@ -591,6 +608,7 @@ class Graph:
|
|||||||
curve_style=curve_style,
|
curve_style=curve_style,
|
||||||
node_styles=node_colors,
|
node_styles=node_colors,
|
||||||
wrap_label_n_words=wrap_label_n_words,
|
wrap_label_n_words=wrap_label_n_words,
|
||||||
|
frontmatter_config=frontmatter_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def draw_mermaid_png(
|
def draw_mermaid_png(
|
||||||
@ -603,6 +621,7 @@ class Graph:
|
|||||||
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
|
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
|
||||||
background_color: str = "white",
|
background_color: str = "white",
|
||||||
padding: int = 10,
|
padding: int = 10,
|
||||||
|
frontmatter_config: Optional[dict[str, Any]] = None,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Draw the graph as a PNG image using Mermaid.
|
"""Draw the graph as a PNG image using Mermaid.
|
||||||
|
|
||||||
@ -617,6 +636,22 @@ class Graph:
|
|||||||
Defaults to MermaidDrawMethod.API.
|
Defaults to MermaidDrawMethod.API.
|
||||||
background_color: The color of the background. Defaults to "white".
|
background_color: The color of the background. Defaults to "white".
|
||||||
padding: The padding around the graph. Defaults to 10.
|
padding: The padding around the graph. Defaults to 10.
|
||||||
|
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
|
||||||
|
Can be used to customize theme and styles. Will be converted to YAML and
|
||||||
|
added to the beginning of the mermaid graph. Defaults to None.
|
||||||
|
|
||||||
|
See more here: https://mermaid.js.org/config/configuration.html.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"config": {
|
||||||
|
"theme": "neutral",
|
||||||
|
"look": "handDrawn",
|
||||||
|
"themeVariables": { "primaryColor": "#e2e2e2"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The PNG image as bytes.
|
The PNG image as bytes.
|
||||||
@ -627,6 +662,7 @@ class Graph:
|
|||||||
curve_style=curve_style,
|
curve_style=curve_style,
|
||||||
node_colors=node_colors,
|
node_colors=node_colors,
|
||||||
wrap_label_n_words=wrap_label_n_words,
|
wrap_label_n_words=wrap_label_n_words,
|
||||||
|
frontmatter_config=frontmatter_config,
|
||||||
)
|
)
|
||||||
return draw_mermaid_png(
|
return draw_mermaid_png(
|
||||||
mermaid_syntax=mermaid_syntax,
|
mermaid_syntax=mermaid_syntax,
|
||||||
|
@ -3,7 +3,9 @@ import base64
|
|||||||
import re
|
import re
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
from langchain_core.runnables.graph import (
|
from langchain_core.runnables.graph import (
|
||||||
CurveStyle,
|
CurveStyle,
|
||||||
@ -26,6 +28,7 @@ def draw_mermaid(
|
|||||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||||
node_styles: Optional[NodeStyles] = None,
|
node_styles: Optional[NodeStyles] = None,
|
||||||
wrap_label_n_words: int = 9,
|
wrap_label_n_words: int = 9,
|
||||||
|
frontmatter_config: Optional[dict[str, Any]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Draws a Mermaid graph using the provided graph data.
|
"""Draws a Mermaid graph using the provided graph data.
|
||||||
|
|
||||||
@ -43,15 +46,44 @@ def draw_mermaid(
|
|||||||
Defaults to NodeStyles().
|
Defaults to NodeStyles().
|
||||||
wrap_label_n_words (int, optional): Words to wrap the edge labels.
|
wrap_label_n_words (int, optional): Words to wrap the edge labels.
|
||||||
Defaults to 9.
|
Defaults to 9.
|
||||||
|
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
|
||||||
|
Can be used to customize theme and styles. Will be converted to YAML and
|
||||||
|
added to the beginning of the mermaid graph. Defaults to None.
|
||||||
|
|
||||||
|
See more here: https://mermaid.js.org/config/configuration.html.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"config": {
|
||||||
|
"theme": "neutral",
|
||||||
|
"look": "handDrawn",
|
||||||
|
"themeVariables": { "primaryColor": "#e2e2e2"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Mermaid graph syntax.
|
str: Mermaid graph syntax.
|
||||||
"""
|
"""
|
||||||
# Initialize Mermaid graph configuration
|
# Initialize Mermaid graph configuration
|
||||||
|
original_frontmatter_config = frontmatter_config or {}
|
||||||
|
original_flowchart_config = original_frontmatter_config.get("config", {}).get(
|
||||||
|
"flowchart", {}
|
||||||
|
)
|
||||||
|
frontmatter_config = {
|
||||||
|
**original_frontmatter_config,
|
||||||
|
"config": {
|
||||||
|
**original_frontmatter_config.get("config", {}),
|
||||||
|
"flowchart": {**original_flowchart_config, "curve": curve_style.value},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
mermaid_graph = (
|
mermaid_graph = (
|
||||||
(
|
(
|
||||||
f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
|
"---\n"
|
||||||
f"}}}}}}%%\ngraph TD;\n"
|
+ yaml.dump(frontmatter_config, default_flow_style=False)
|
||||||
|
+ "---\ngraph TD;\n"
|
||||||
)
|
)
|
||||||
if with_styles
|
if with_styles
|
||||||
else "graph TD;\n"
|
else "graph TD;\n"
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_double_nested_subgraph_mermaid[mermaid]
|
# name: test_double_nested_subgraph_mermaid[mermaid]
|
||||||
'''
|
'''
|
||||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
---
|
||||||
|
config:
|
||||||
|
flowchart:
|
||||||
|
curve: linear
|
||||||
|
---
|
||||||
graph TD;
|
graph TD;
|
||||||
__start__([<p>__start__</p>]):::first
|
__start__([<p>__start__</p>]):::first
|
||||||
parent_1(parent_1)
|
parent_1(parent_1)
|
||||||
@ -28,7 +32,11 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_triple_nested_subgraph_mermaid[mermaid]
|
# name: test_triple_nested_subgraph_mermaid[mermaid]
|
||||||
'''
|
'''
|
||||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
---
|
||||||
|
config:
|
||||||
|
flowchart:
|
||||||
|
curve: linear
|
||||||
|
---
|
||||||
graph TD;
|
graph TD;
|
||||||
__start__([<p>__start__</p>]):::first
|
__start__([<p>__start__</p>]):::first
|
||||||
parent_1(parent_1)
|
parent_1(parent_1)
|
||||||
@ -71,6 +79,27 @@
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_graph_mermaid_frontmatter_config[mermaid]
|
||||||
|
'''
|
||||||
|
---
|
||||||
|
config:
|
||||||
|
flowchart:
|
||||||
|
curve: linear
|
||||||
|
look: handDrawn
|
||||||
|
theme: neutral
|
||||||
|
themeVariables:
|
||||||
|
primaryColor: '#e2e2e2'
|
||||||
|
---
|
||||||
|
graph TD;
|
||||||
|
__start__([<p>__start__</p>]):::first
|
||||||
|
my_node([my_node]):::last
|
||||||
|
__start__ --> my_node;
|
||||||
|
classDef default fill:#f2f0ff,line-height:1.2
|
||||||
|
classDef first fill-opacity:0
|
||||||
|
classDef last fill:#bfb6fc
|
||||||
|
|
||||||
|
'''
|
||||||
|
# ---
|
||||||
# name: test_graph_sequence[ascii]
|
# name: test_graph_sequence[ascii]
|
||||||
'''
|
'''
|
||||||
+-------------+
|
+-------------+
|
||||||
@ -104,7 +133,11 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_graph_sequence[mermaid]
|
# name: test_graph_sequence[mermaid]
|
||||||
'''
|
'''
|
||||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
---
|
||||||
|
config:
|
||||||
|
flowchart:
|
||||||
|
curve: linear
|
||||||
|
---
|
||||||
graph TD;
|
graph TD;
|
||||||
PromptInput([PromptInput]):::first
|
PromptInput([PromptInput]):::first
|
||||||
PromptTemplate(PromptTemplate)
|
PromptTemplate(PromptTemplate)
|
||||||
@ -1927,7 +1960,11 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_graph_sequence_map[mermaid]
|
# name: test_graph_sequence_map[mermaid]
|
||||||
'''
|
'''
|
||||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
---
|
||||||
|
config:
|
||||||
|
flowchart:
|
||||||
|
curve: linear
|
||||||
|
---
|
||||||
graph TD;
|
graph TD;
|
||||||
PromptInput([PromptInput]):::first
|
PromptInput([PromptInput]):::first
|
||||||
PromptTemplate(PromptTemplate)
|
PromptTemplate(PromptTemplate)
|
||||||
@ -1977,7 +2014,11 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_graph_single_runnable[mermaid]
|
# name: test_graph_single_runnable[mermaid]
|
||||||
'''
|
'''
|
||||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
---
|
||||||
|
config:
|
||||||
|
flowchart:
|
||||||
|
curve: linear
|
||||||
|
---
|
||||||
graph TD;
|
graph TD;
|
||||||
StrOutputParserInput([StrOutputParserInput]):::first
|
StrOutputParserInput([StrOutputParserInput]):::first
|
||||||
StrOutputParser(StrOutputParser)
|
StrOutputParser(StrOutputParser)
|
||||||
@ -1992,7 +2033,11 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_parallel_subgraph_mermaid[mermaid]
|
# name: test_parallel_subgraph_mermaid[mermaid]
|
||||||
'''
|
'''
|
||||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
---
|
||||||
|
config:
|
||||||
|
flowchart:
|
||||||
|
curve: linear
|
||||||
|
---
|
||||||
graph TD;
|
graph TD;
|
||||||
__start__([<p>__start__</p>]):::first
|
__start__([<p>__start__</p>]):::first
|
||||||
outer_1(outer_1)
|
outer_1(outer_1)
|
||||||
@ -2022,7 +2067,11 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_single_node_subgraph_mermaid[mermaid]
|
# name: test_single_node_subgraph_mermaid[mermaid]
|
||||||
'''
|
'''
|
||||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
---
|
||||||
|
config:
|
||||||
|
flowchart:
|
||||||
|
curve: linear
|
||||||
|
---
|
||||||
graph TD;
|
graph TD;
|
||||||
__start__([<p>__start__</p>]):::first
|
__start__([<p>__start__</p>]):::first
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
|
@ -535,3 +535,28 @@ def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None:
|
|||||||
)
|
)
|
||||||
graph = sequence.get_graph()
|
graph = sequence.get_graph()
|
||||||
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid")
|
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid")
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_mermaid_frontmatter_config(snapshot: SnapshotAssertion) -> None:
|
||||||
|
graph = Graph(
|
||||||
|
nodes={
|
||||||
|
"__start__": Node(
|
||||||
|
id="__start__", name="__start__", data=BaseModel, metadata=None
|
||||||
|
),
|
||||||
|
"my_node": Node(
|
||||||
|
id="my_node", name="my_node", data=BaseModel, metadata=None
|
||||||
|
),
|
||||||
|
},
|
||||||
|
edges=[
|
||||||
|
Edge(source="__start__", target="my_node", data=None, conditional=False)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert graph.draw_mermaid(
|
||||||
|
frontmatter_config={
|
||||||
|
"config": {
|
||||||
|
"theme": "neutral",
|
||||||
|
"look": "handDrawn",
|
||||||
|
"themeVariables": {"primaryColor": "#e2e2e2"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
) == snapshot(name="mermaid")
|
||||||
|
Loading…
Reference in New Issue
Block a user