feat(core): allow custom Mermaid URL (#32831)

- **Description:** Currently,
`langchain_core.runnables.graph_mermaid.py` is hardcoded to use
mermaid.ink to render graph diagrams. It would be nice to allow users to
specify a custom URL, e.g. for self-hosted instances of the Mermaid
server.
- **Issue:** [Langchain Forum: allow custom mermaid API
URL](https://forum.langchain.com/t/feature-request-allow-custom-mermaid-api-url/1472)
  - **Dependencies:** None

- [X] **Add tests and docs**: Added unit tests using mock requests.
- [X] **Lint and test**: Run `make format`, `make lint` and `make test`.

Minimal example using the feature:

```python
import os
import operator
from pathlib import Path
from typing import Any, Annotated, TypedDict

from langgraph.graph import StateGraph

class State(TypedDict):
    messages: Annotated[list[dict[str, Any]], operator.add]

def hello_node(state: State) -> State:
    return {"messages": [{"role": "assistant", "content": "pong!"}]}

builder = StateGraph(State)
builder.add_node("hello_node", hello_node)
builder.add_edge("__start__", "hello_node")
builder.add_edge("hello_node", "__end__")

graph = builder.compile()

# Run graph
output = graph.invoke({"messages": [{"role": "user", "content": "ping?"}]})

# Draw graph
Path("graph.png").write_bytes(graph.get_graph().draw_mermaid_png(base_url="https://custom-mermaid.ink"))
```

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Daniel Barker
2025-09-10 16:14:50 -05:00
committed by GitHub
parent 38001699d5
commit 25c34bd9b2
3 changed files with 110 additions and 7 deletions

View File

@@ -614,7 +614,6 @@ class Graph:
Returns:
The Mermaid syntax string.
"""
# Import locally to prevent circular import
from langchain_core.runnables.graph_mermaid import draw_mermaid # noqa: PLC0415
@@ -648,6 +647,7 @@ class Graph:
max_retries: int = 1,
retry_delay: float = 1.0,
frontmatter_config: Optional[dict[str, Any]] = None,
base_url: Optional[str] = None,
) -> bytes:
"""Draw the graph as a PNG image using Mermaid.
@@ -683,6 +683,8 @@ class Graph:
"themeVariables": { "primaryColor": "#e2e2e2"},
}
}
base_url: The base URL of the Mermaid server for rendering via API.
Defaults to None.
Returns:
The PNG image as bytes.
@@ -707,6 +709,7 @@ class Graph:
padding=padding,
max_retries=max_retries,
retry_delay=retry_delay,
base_url=base_url,
)

View File

@@ -277,6 +277,7 @@ def draw_mermaid_png(
padding: int = 10,
max_retries: int = 1,
retry_delay: float = 1.0,
base_url: Optional[str] = None,
) -> bytes:
"""Draws a Mermaid graph as PNG using provided syntax.
@@ -293,6 +294,8 @@ def draw_mermaid_png(
Defaults to 1.
retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API).
Defaults to 1.0.
base_url (str, optional): Base URL for the Mermaid.ink API.
Defaults to None.
Returns:
bytes: PNG image bytes.
@@ -313,6 +316,7 @@ def draw_mermaid_png(
background_color=background_color,
max_retries=max_retries,
retry_delay=retry_delay,
base_url=base_url,
)
else:
supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
@@ -404,8 +408,12 @@ def _render_mermaid_using_api(
file_type: Optional[Literal["jpeg", "png", "webp"]] = "png",
max_retries: int = 1,
retry_delay: float = 1.0,
base_url: Optional[str] = None,
) -> bytes:
"""Renders Mermaid graph using the Mermaid.INK API."""
# Defaults to using the public mermaid.ink server.
base_url = base_url if base_url is not None else "https://mermaid.ink"
if not _HAS_REQUESTS:
msg = (
"Install the `requests` module to use the Mermaid.INK API: "
@@ -425,7 +433,7 @@ def _render_mermaid_using_api(
background_color = f"!{background_color}"
image_url = (
f"https://mermaid.ink/img/{mermaid_syntax_encoded}"
f"{base_url}/img/{mermaid_syntax_encoded}"
f"?type={file_type}&bgColor={background_color}"
)
@@ -457,7 +465,7 @@ def _render_mermaid_using_api(
# For other status codes, fail immediately
msg = (
"Failed to reach https://mermaid.ink/ API while trying to render "
f"Failed to reach {base_url} API while trying to render "
f"your graph. Status code: {response.status_code}.\n\n"
) + error_msg_suffix
raise ValueError(msg)
@@ -469,14 +477,14 @@ def _render_mermaid_using_api(
time.sleep(sleep_time)
else:
msg = (
"Failed to reach https://mermaid.ink/ API while trying to render "
f"Failed to reach {base_url} API while trying to render "
f"your graph after {max_retries} retries. "
) + error_msg_suffix
raise ValueError(msg) from e
# This should not be reached, but just in case
msg = (
"Failed to reach https://mermaid.ink/ API while trying to render "
f"Failed to reach {base_url} API while trying to render "
f"your graph after {max_retries} retries. "
) + error_msg_suffix
raise ValueError(msg)

View File

@@ -1,4 +1,5 @@
from typing import Any, Optional
from unittest.mock import MagicMock, patch
from packaging import version
from pydantic import BaseModel
@@ -12,8 +13,12 @@ from langchain_core.output_parsers.xml import XMLOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.base import Runnable
from langchain_core.runnables.graph import Edge, Graph, Node
from langchain_core.runnables.graph_mermaid import _escape_node_label
from langchain_core.runnables.graph import Edge, Graph, MermaidDrawMethod, Node
from langchain_core.runnables.graph_mermaid import (
_escape_node_label,
_render_mermaid_using_api,
draw_mermaid_png,
)
from langchain_core.utils.pydantic import PYDANTIC_VERSION
from tests.unit_tests.pydantic_utils import _normalize_schema
@@ -561,3 +566,90 @@ def test_graph_mermaid_frontmatter_config(snapshot: SnapshotAssertion) -> None:
}
}
) == snapshot(name="mermaid")
def test_mermaid_base_url_default() -> None:
"""Test that _render_mermaid_using_api defaults to mermaid.ink when None."""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = b"fake image data"
with patch("requests.get", return_value=mock_response) as mock_get:
# Call the function with base_url=None (default)
_render_mermaid_using_api(
"graph TD;\n A --> B;",
base_url=None,
)
# Verify that the URL was constructed with the default base URL
assert mock_get.called
args, kwargs = mock_get.call_args
url = args[0] # First argument to request.get is the URL
assert url.startswith("https://mermaid.ink")
def test_mermaid_base_url_custom() -> None:
"""Test that _render_mermaid_using_api uses custom base_url when provided."""
custom_url = "https://custom.mermaid.com"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = b"fake image data"
with patch("requests.get", return_value=mock_response) as mock_get:
# Call the function with custom base_url.
_render_mermaid_using_api(
"graph TD;\n A --> B;",
base_url=custom_url,
)
# Verify that the URL was constructed with our custom base URL
assert mock_get.called
args, kwargs = mock_get.call_args
url = args[0] # First argument to request.get is the URL
assert url.startswith(custom_url)
def test_draw_mermaid_png_function_base_url() -> None:
"""Test that draw_mermaid_png function passes base_url to API renderer."""
custom_url = "https://custom.mermaid.com"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = b"fake image data"
with patch("requests.get", return_value=mock_response) as mock_get:
# Call draw_mermaid_png with custom base_url
draw_mermaid_png(
"graph TD;\n A --> B;",
draw_method=MermaidDrawMethod.API,
base_url=custom_url,
)
# Verify that the URL was constructed with our custom base URL
assert mock_get.called
args, kwargs = mock_get.call_args
url = args[0] # First argument to request.get is the URL
assert url.startswith(custom_url)
def test_graph_draw_mermaid_png_base_url() -> None:
"""Test that Graph.draw_mermaid_png method passes base_url to renderer."""
custom_url = "https://custom.mermaid.com"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = b"fake image data"
with patch("requests.get", return_value=mock_response) as mock_get:
# Create a simple graph
graph = Graph()
start_node = graph.add_node(BaseModel, id="start")
end_node = graph.add_node(BaseModel, id="end")
graph.add_edge(start_node, end_node)
# Call draw_mermaid_png with custom base_url
graph.draw_mermaid_png(draw_method=MermaidDrawMethod.API, base_url=custom_url)
# Verify that the URL was constructed with our custom base URL
assert mock_get.called
args, kwargs = mock_get.call_args
url = args[0] # First argument to request.get is the URL
assert url.startswith(custom_url)