feat(core): allow custom Mermaid URL (#32831)

- **Description:** Currently,
`langchain_core.runnables.graph_mermaid.py` is hardcoded to use
mermaid.ink to render graph diagrams. It would be nice to allow users to
specify a custom URL, e.g. for self-hosted instances of the Mermaid
server.
- **Issue:** [Langchain Forum: allow custom mermaid API
URL](https://forum.langchain.com/t/feature-request-allow-custom-mermaid-api-url/1472)
  - **Dependencies:** None

- [X] **Add tests and docs**: Added unit tests using mock requests.
- [X] **Lint and test**: Run `make format`, `make lint` and `make test`.

Minimal example using the feature:

```python
import os
import operator
from pathlib import Path
from typing import Any, Annotated, TypedDict

from langgraph.graph import StateGraph

class State(TypedDict):
    messages: Annotated[list[dict[str, Any]], operator.add]

def hello_node(state: State) -> State:
    return {"messages": [{"role": "assistant", "content": "pong!"}]}

builder = StateGraph(State)
builder.add_node("hello_node", hello_node)
builder.add_edge("__start__", "hello_node")
builder.add_edge("hello_node", "__end__")

graph = builder.compile()

# Run graph
output = graph.invoke({"messages": [{"role": "user", "content": "ping?"}]})

# Draw graph
Path("graph.png").write_bytes(graph.get_graph().draw_mermaid_png(base_url="https://custom-mermaid.ink"))
```

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Daniel Barker
2025-09-10 16:14:50 -05:00
committed by GitHub
parent 38001699d5
commit 25c34bd9b2
3 changed files with 110 additions and 7 deletions

View File

@@ -614,7 +614,6 @@ class Graph:
Returns: Returns:
The Mermaid syntax string. The Mermaid syntax string.
""" """
# Import locally to prevent circular import # Import locally to prevent circular import
from langchain_core.runnables.graph_mermaid import draw_mermaid # noqa: PLC0415 from langchain_core.runnables.graph_mermaid import draw_mermaid # noqa: PLC0415
@@ -648,6 +647,7 @@ class Graph:
max_retries: int = 1, max_retries: int = 1,
retry_delay: float = 1.0, retry_delay: float = 1.0,
frontmatter_config: Optional[dict[str, Any]] = None, frontmatter_config: Optional[dict[str, Any]] = None,
base_url: Optional[str] = None,
) -> bytes: ) -> bytes:
"""Draw the graph as a PNG image using Mermaid. """Draw the graph as a PNG image using Mermaid.
@@ -683,6 +683,8 @@ class Graph:
"themeVariables": { "primaryColor": "#e2e2e2"}, "themeVariables": { "primaryColor": "#e2e2e2"},
} }
} }
base_url: The base URL of the Mermaid server for rendering via API.
Defaults to None.
Returns: Returns:
The PNG image as bytes. The PNG image as bytes.
@@ -707,6 +709,7 @@ class Graph:
padding=padding, padding=padding,
max_retries=max_retries, max_retries=max_retries,
retry_delay=retry_delay, retry_delay=retry_delay,
base_url=base_url,
) )

View File

@@ -277,6 +277,7 @@ def draw_mermaid_png(
padding: int = 10, padding: int = 10,
max_retries: int = 1, max_retries: int = 1,
retry_delay: float = 1.0, retry_delay: float = 1.0,
base_url: Optional[str] = None,
) -> bytes: ) -> bytes:
"""Draws a Mermaid graph as PNG using provided syntax. """Draws a Mermaid graph as PNG using provided syntax.
@@ -293,6 +294,8 @@ def draw_mermaid_png(
Defaults to 1. Defaults to 1.
retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API). retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API).
Defaults to 1.0. Defaults to 1.0.
base_url (str, optional): Base URL for the Mermaid.ink API.
Defaults to None.
Returns: Returns:
bytes: PNG image bytes. bytes: PNG image bytes.
@@ -313,6 +316,7 @@ def draw_mermaid_png(
background_color=background_color, background_color=background_color,
max_retries=max_retries, max_retries=max_retries,
retry_delay=retry_delay, retry_delay=retry_delay,
base_url=base_url,
) )
else: else:
supported_methods = ", ".join([m.value for m in MermaidDrawMethod]) supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
@@ -404,8 +408,12 @@ def _render_mermaid_using_api(
file_type: Optional[Literal["jpeg", "png", "webp"]] = "png", file_type: Optional[Literal["jpeg", "png", "webp"]] = "png",
max_retries: int = 1, max_retries: int = 1,
retry_delay: float = 1.0, retry_delay: float = 1.0,
base_url: Optional[str] = None,
) -> bytes: ) -> bytes:
"""Renders Mermaid graph using the Mermaid.INK API.""" """Renders Mermaid graph using the Mermaid.INK API."""
# Defaults to using the public mermaid.ink server.
base_url = base_url if base_url is not None else "https://mermaid.ink"
if not _HAS_REQUESTS: if not _HAS_REQUESTS:
msg = ( msg = (
"Install the `requests` module to use the Mermaid.INK API: " "Install the `requests` module to use the Mermaid.INK API: "
@@ -425,7 +433,7 @@ def _render_mermaid_using_api(
background_color = f"!{background_color}" background_color = f"!{background_color}"
image_url = ( image_url = (
f"https://mermaid.ink/img/{mermaid_syntax_encoded}" f"{base_url}/img/{mermaid_syntax_encoded}"
f"?type={file_type}&bgColor={background_color}" f"?type={file_type}&bgColor={background_color}"
) )
@@ -457,7 +465,7 @@ def _render_mermaid_using_api(
# For other status codes, fail immediately # For other status codes, fail immediately
msg = ( msg = (
"Failed to reach https://mermaid.ink/ API while trying to render " f"Failed to reach {base_url} API while trying to render "
f"your graph. Status code: {response.status_code}.\n\n" f"your graph. Status code: {response.status_code}.\n\n"
) + error_msg_suffix ) + error_msg_suffix
raise ValueError(msg) raise ValueError(msg)
@@ -469,14 +477,14 @@ def _render_mermaid_using_api(
time.sleep(sleep_time) time.sleep(sleep_time)
else: else:
msg = ( msg = (
"Failed to reach https://mermaid.ink/ API while trying to render " f"Failed to reach {base_url} API while trying to render "
f"your graph after {max_retries} retries. " f"your graph after {max_retries} retries. "
) + error_msg_suffix ) + error_msg_suffix
raise ValueError(msg) from e raise ValueError(msg) from e
# This should not be reached, but just in case # This should not be reached, but just in case
msg = ( msg = (
"Failed to reach https://mermaid.ink/ API while trying to render " f"Failed to reach {base_url} API while trying to render "
f"your graph after {max_retries} retries. " f"your graph after {max_retries} retries. "
) + error_msg_suffix ) + error_msg_suffix
raise ValueError(msg) raise ValueError(msg)

View File

@@ -1,4 +1,5 @@
from typing import Any, Optional from typing import Any, Optional
from unittest.mock import MagicMock, patch
from packaging import version from packaging import version
from pydantic import BaseModel from pydantic import BaseModel
@@ -12,8 +13,12 @@ from langchain_core.output_parsers.xml import XMLOutputParser
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.base import Runnable from langchain_core.runnables.base import Runnable
from langchain_core.runnables.graph import Edge, Graph, Node from langchain_core.runnables.graph import Edge, Graph, MermaidDrawMethod, Node
from langchain_core.runnables.graph_mermaid import _escape_node_label from langchain_core.runnables.graph_mermaid import (
_escape_node_label,
_render_mermaid_using_api,
draw_mermaid_png,
)
from langchain_core.utils.pydantic import PYDANTIC_VERSION from langchain_core.utils.pydantic import PYDANTIC_VERSION
from tests.unit_tests.pydantic_utils import _normalize_schema from tests.unit_tests.pydantic_utils import _normalize_schema
@@ -561,3 +566,90 @@ def test_graph_mermaid_frontmatter_config(snapshot: SnapshotAssertion) -> None:
} }
} }
) == snapshot(name="mermaid") ) == snapshot(name="mermaid")
def test_mermaid_base_url_default() -> None:
"""Test that _render_mermaid_using_api defaults to mermaid.ink when None."""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = b"fake image data"
with patch("requests.get", return_value=mock_response) as mock_get:
# Call the function with base_url=None (default)
_render_mermaid_using_api(
"graph TD;\n A --> B;",
base_url=None,
)
# Verify that the URL was constructed with the default base URL
assert mock_get.called
args, kwargs = mock_get.call_args
url = args[0] # First argument to request.get is the URL
assert url.startswith("https://mermaid.ink")
def test_mermaid_base_url_custom() -> None:
"""Test that _render_mermaid_using_api uses custom base_url when provided."""
custom_url = "https://custom.mermaid.com"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = b"fake image data"
with patch("requests.get", return_value=mock_response) as mock_get:
# Call the function with custom base_url.
_render_mermaid_using_api(
"graph TD;\n A --> B;",
base_url=custom_url,
)
# Verify that the URL was constructed with our custom base URL
assert mock_get.called
args, kwargs = mock_get.call_args
url = args[0] # First argument to request.get is the URL
assert url.startswith(custom_url)
def test_draw_mermaid_png_function_base_url() -> None:
"""Test that draw_mermaid_png function passes base_url to API renderer."""
custom_url = "https://custom.mermaid.com"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = b"fake image data"
with patch("requests.get", return_value=mock_response) as mock_get:
# Call draw_mermaid_png with custom base_url
draw_mermaid_png(
"graph TD;\n A --> B;",
draw_method=MermaidDrawMethod.API,
base_url=custom_url,
)
# Verify that the URL was constructed with our custom base URL
assert mock_get.called
args, kwargs = mock_get.call_args
url = args[0] # First argument to request.get is the URL
assert url.startswith(custom_url)
def test_graph_draw_mermaid_png_base_url() -> None:
"""Test that Graph.draw_mermaid_png method passes base_url to renderer."""
custom_url = "https://custom.mermaid.com"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = b"fake image data"
with patch("requests.get", return_value=mock_response) as mock_get:
# Create a simple graph
graph = Graph()
start_node = graph.add_node(BaseModel, id="start")
end_node = graph.add_node(BaseModel, id="end")
graph.add_edge(start_node, end_node)
# Call draw_mermaid_png with custom base_url
graph.draw_mermaid_png(draw_method=MermaidDrawMethod.API, base_url=custom_url)
# Verify that the URL was constructed with our custom base URL
assert mock_get.called
args, kwargs = mock_get.call_args
url = args[0] # First argument to request.get is the URL
assert url.startswith(custom_url)