mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 09:28:48 +00:00
core(mermaid): allow greater customization (#29939)
Adds greater style customization by allowing a custom frontmatter config. This allows to set a `theme` and `look` or to adjust theme by setting `themeVariables` Example: ```python node_colors = NodeStyles( default="fill:#e2e2e2,line-height:1.2,stroke:#616161", first="fill:#cfeab8,fill-opacity:0", last="fill:#eac3b8", ) frontmatter_config = { "config": { "theme": "neutral", "look": "handDrawn" } } graph.get_graph().draw_mermaid_png(node_colors=node_colors, frontmatter_config=frontmatter_config) ```  --------- Co-authored-by: vbarda <vadym@langchain.dev>
This commit is contained in:
parent
07823cd41c
commit
3933a4abc3
@ -563,6 +563,7 @@ class Graph:
|
||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||
node_colors: Optional[NodeStyles] = None,
|
||||
wrap_label_n_words: int = 9,
|
||||
frontmatter_config: Optional[dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Draw the graph as a Mermaid syntax string.
|
||||
|
||||
@ -572,6 +573,22 @@ class Graph:
|
||||
node_colors: The colors of the nodes. Defaults to NodeStyles().
|
||||
wrap_label_n_words: The number of words to wrap the node labels at.
|
||||
Defaults to 9.
|
||||
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
|
||||
Can be used to customize theme and styles. Will be converted to YAML and
|
||||
added to the beginning of the mermaid graph. Defaults to None.
|
||||
|
||||
See more here: https://mermaid.js.org/config/configuration.html.
|
||||
|
||||
Example:
|
||||
```python
|
||||
{
|
||||
"config": {
|
||||
"theme": "neutral",
|
||||
"look": "handDrawn",
|
||||
"themeVariables": { "primaryColor": "#e2e2e2"},
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Returns:
|
||||
The Mermaid syntax string.
|
||||
@ -591,6 +608,7 @@ class Graph:
|
||||
curve_style=curve_style,
|
||||
node_styles=node_colors,
|
||||
wrap_label_n_words=wrap_label_n_words,
|
||||
frontmatter_config=frontmatter_config,
|
||||
)
|
||||
|
||||
def draw_mermaid_png(
|
||||
@ -603,6 +621,7 @@ class Graph:
|
||||
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
|
||||
background_color: str = "white",
|
||||
padding: int = 10,
|
||||
frontmatter_config: Optional[dict[str, Any]] = None,
|
||||
) -> bytes:
|
||||
"""Draw the graph as a PNG image using Mermaid.
|
||||
|
||||
@ -617,6 +636,22 @@ class Graph:
|
||||
Defaults to MermaidDrawMethod.API.
|
||||
background_color: The color of the background. Defaults to "white".
|
||||
padding: The padding around the graph. Defaults to 10.
|
||||
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
|
||||
Can be used to customize theme and styles. Will be converted to YAML and
|
||||
added to the beginning of the mermaid graph. Defaults to None.
|
||||
|
||||
See more here: https://mermaid.js.org/config/configuration.html.
|
||||
|
||||
Example:
|
||||
```python
|
||||
{
|
||||
"config": {
|
||||
"theme": "neutral",
|
||||
"look": "handDrawn",
|
||||
"themeVariables": { "primaryColor": "#e2e2e2"},
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Returns:
|
||||
The PNG image as bytes.
|
||||
@ -627,6 +662,7 @@ class Graph:
|
||||
curve_style=curve_style,
|
||||
node_colors=node_colors,
|
||||
wrap_label_n_words=wrap_label_n_words,
|
||||
frontmatter_config=frontmatter_config,
|
||||
)
|
||||
return draw_mermaid_png(
|
||||
mermaid_syntax=mermaid_syntax,
|
||||
|
@ -3,7 +3,9 @@ import base64
|
||||
import re
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from langchain_core.runnables.graph import (
|
||||
CurveStyle,
|
||||
@ -26,6 +28,7 @@ def draw_mermaid(
|
||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||
node_styles: Optional[NodeStyles] = None,
|
||||
wrap_label_n_words: int = 9,
|
||||
frontmatter_config: Optional[dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Draws a Mermaid graph using the provided graph data.
|
||||
|
||||
@ -43,15 +46,44 @@ def draw_mermaid(
|
||||
Defaults to NodeStyles().
|
||||
wrap_label_n_words (int, optional): Words to wrap the edge labels.
|
||||
Defaults to 9.
|
||||
frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
|
||||
Can be used to customize theme and styles. Will be converted to YAML and
|
||||
added to the beginning of the mermaid graph. Defaults to None.
|
||||
|
||||
See more here: https://mermaid.js.org/config/configuration.html.
|
||||
|
||||
Example:
|
||||
```python
|
||||
{
|
||||
"config": {
|
||||
"theme": "neutral",
|
||||
"look": "handDrawn",
|
||||
"themeVariables": { "primaryColor": "#e2e2e2"},
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Returns:
|
||||
str: Mermaid graph syntax.
|
||||
"""
|
||||
# Initialize Mermaid graph configuration
|
||||
original_frontmatter_config = frontmatter_config or {}
|
||||
original_flowchart_config = original_frontmatter_config.get("config", {}).get(
|
||||
"flowchart", {}
|
||||
)
|
||||
frontmatter_config = {
|
||||
**original_frontmatter_config,
|
||||
"config": {
|
||||
**original_frontmatter_config.get("config", {}),
|
||||
"flowchart": {**original_flowchart_config, "curve": curve_style.value},
|
||||
},
|
||||
}
|
||||
|
||||
mermaid_graph = (
|
||||
(
|
||||
f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
|
||||
f"}}}}}}%%\ngraph TD;\n"
|
||||
"---\n"
|
||||
+ yaml.dump(frontmatter_config, default_flow_style=False)
|
||||
+ "---\ngraph TD;\n"
|
||||
)
|
||||
if with_styles
|
||||
else "graph TD;\n"
|
||||
|
@ -1,7 +1,11 @@
|
||||
# serializer version: 1
|
||||
# name: test_double_nested_subgraph_mermaid[mermaid]
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
parent_1(parent_1)
|
||||
@ -28,7 +32,11 @@
|
||||
# ---
|
||||
# name: test_triple_nested_subgraph_mermaid[mermaid]
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
parent_1(parent_1)
|
||||
@ -71,6 +79,27 @@
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_graph_mermaid_frontmatter_config[mermaid]
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
look: handDrawn
|
||||
theme: neutral
|
||||
themeVariables:
|
||||
primaryColor: '#e2e2e2'
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
my_node([my_node]):::last
|
||||
__start__ --> my_node;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_graph_sequence[ascii]
|
||||
'''
|
||||
+-------------+
|
||||
@ -104,7 +133,11 @@
|
||||
# ---
|
||||
# name: test_graph_sequence[mermaid]
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
PromptInput([PromptInput]):::first
|
||||
PromptTemplate(PromptTemplate)
|
||||
@ -1927,7 +1960,11 @@
|
||||
# ---
|
||||
# name: test_graph_sequence_map[mermaid]
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
PromptInput([PromptInput]):::first
|
||||
PromptTemplate(PromptTemplate)
|
||||
@ -1977,7 +2014,11 @@
|
||||
# ---
|
||||
# name: test_graph_single_runnable[mermaid]
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
StrOutputParserInput([StrOutputParserInput]):::first
|
||||
StrOutputParser(StrOutputParser)
|
||||
@ -1992,7 +2033,11 @@
|
||||
# ---
|
||||
# name: test_parallel_subgraph_mermaid[mermaid]
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
outer_1(outer_1)
|
||||
@ -2022,7 +2067,11 @@
|
||||
# ---
|
||||
# name: test_single_node_subgraph_mermaid[mermaid]
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
__end__([<p>__end__</p>]):::last
|
||||
|
@ -535,3 +535,28 @@ def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None:
|
||||
)
|
||||
graph = sequence.get_graph()
|
||||
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid")
|
||||
|
||||
|
||||
def test_graph_mermaid_frontmatter_config(snapshot: SnapshotAssertion) -> None:
|
||||
graph = Graph(
|
||||
nodes={
|
||||
"__start__": Node(
|
||||
id="__start__", name="__start__", data=BaseModel, metadata=None
|
||||
),
|
||||
"my_node": Node(
|
||||
id="my_node", name="my_node", data=BaseModel, metadata=None
|
||||
),
|
||||
},
|
||||
edges=[
|
||||
Edge(source="__start__", target="my_node", data=None, conditional=False)
|
||||
],
|
||||
)
|
||||
assert graph.draw_mermaid(
|
||||
frontmatter_config={
|
||||
"config": {
|
||||
"theme": "neutral",
|
||||
"look": "handDrawn",
|
||||
"themeVariables": {"primaryColor": "#e2e2e2"},
|
||||
}
|
||||
}
|
||||
) == snapshot(name="mermaid")
|
||||
|
Loading…
Reference in New Issue
Block a user