feat(core): allow custom Mermaid URL (#32831)

- **Description:** Currently,
`langchain_core.runnables.graph_mermaid.py` is hardcoded to use
mermaid.ink to render graph diagrams. It would be nice to allow users to
specify a custom URL, e.g. for self-hosted instances of the Mermaid
server.
- **Issue:** [Langchain Forum: allow custom mermaid API
URL](https://forum.langchain.com/t/feature-request-allow-custom-mermaid-api-url/1472)
  - **Dependencies:** None

- [X] **Add tests and docs**: Added unit tests using mock requests.
- [X] **Lint and test**: Run `make format`, `make lint` and `make test`.

Minimal example using the feature:

```python
import os
import operator
from pathlib import Path
from typing import Any, Annotated, TypedDict

from langgraph.graph import StateGraph

class State(TypedDict):
    messages: Annotated[list[dict[str, Any]], operator.add]

def hello_node(state: State) -> State:
    return {"messages": [{"role": "assistant", "content": "pong!"}]}

builder = StateGraph(State)
builder.add_node("hello_node", hello_node)
builder.add_edge("__start__", "hello_node")
builder.add_edge("hello_node", "__end__")

graph = builder.compile()

# Run graph
output = graph.invoke({"messages": [{"role": "user", "content": "ping?"}]})

# Draw graph
Path("graph.png").write_bytes(graph.get_graph().draw_mermaid_png(base_url="https://custom-mermaid.ink"))
```

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Daniel Barker
2025-09-10 16:14:50 -05:00
committed by GitHub
parent 38001699d5
commit 25c34bd9b2
3 changed files with 110 additions and 7 deletions

View File

@@ -614,7 +614,6 @@ class Graph:
Returns:
The Mermaid syntax string.
"""
# Import locally to prevent circular import
from langchain_core.runnables.graph_mermaid import draw_mermaid # noqa: PLC0415
@@ -648,6 +647,7 @@ class Graph:
max_retries: int = 1,
retry_delay: float = 1.0,
frontmatter_config: Optional[dict[str, Any]] = None,
base_url: Optional[str] = None,
) -> bytes:
"""Draw the graph as a PNG image using Mermaid.
@@ -683,6 +683,8 @@ class Graph:
"themeVariables": { "primaryColor": "#e2e2e2"},
}
}
base_url: The base URL of the Mermaid server for rendering via API.
Defaults to None.
Returns:
The PNG image as bytes.
@@ -707,6 +709,7 @@ class Graph:
padding=padding,
max_retries=max_retries,
retry_delay=retry_delay,
base_url=base_url,
)

View File

@@ -277,6 +277,7 @@ def draw_mermaid_png(
padding: int = 10,
max_retries: int = 1,
retry_delay: float = 1.0,
base_url: Optional[str] = None,
) -> bytes:
"""Draws a Mermaid graph as PNG using provided syntax.
@@ -293,6 +294,8 @@ def draw_mermaid_png(
Defaults to 1.
retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API).
Defaults to 1.0.
base_url (str, optional): Base URL for the Mermaid.ink API.
Defaults to None.
Returns:
bytes: PNG image bytes.
@@ -313,6 +316,7 @@ def draw_mermaid_png(
background_color=background_color,
max_retries=max_retries,
retry_delay=retry_delay,
base_url=base_url,
)
else:
supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
@@ -404,8 +408,12 @@ def _render_mermaid_using_api(
file_type: Optional[Literal["jpeg", "png", "webp"]] = "png",
max_retries: int = 1,
retry_delay: float = 1.0,
base_url: Optional[str] = None,
) -> bytes:
"""Renders Mermaid graph using the Mermaid.INK API."""
# Defaults to using the public mermaid.ink server.
base_url = base_url if base_url is not None else "https://mermaid.ink"
if not _HAS_REQUESTS:
msg = (
"Install the `requests` module to use the Mermaid.INK API: "
@@ -425,7 +433,7 @@ def _render_mermaid_using_api(
background_color = f"!{background_color}"
image_url = (
f"https://mermaid.ink/img/{mermaid_syntax_encoded}"
f"{base_url}/img/{mermaid_syntax_encoded}"
f"?type={file_type}&bgColor={background_color}"
)
@@ -457,7 +465,7 @@ def _render_mermaid_using_api(
# For other status codes, fail immediately
msg = (
"Failed to reach https://mermaid.ink/ API while trying to render "
f"Failed to reach {base_url} API while trying to render "
f"your graph. Status code: {response.status_code}.\n\n"
) + error_msg_suffix
raise ValueError(msg)
@@ -469,14 +477,14 @@ def _render_mermaid_using_api(
time.sleep(sleep_time)
else:
msg = (
"Failed to reach https://mermaid.ink/ API while trying to render "
f"Failed to reach {base_url} API while trying to render "
f"your graph after {max_retries} retries. "
) + error_msg_suffix
raise ValueError(msg) from e
# This should not be reached, but just in case
msg = (
"Failed to reach https://mermaid.ink/ API while trying to render "
f"Failed to reach {base_url} API while trying to render "
f"your graph after {max_retries} retries. "
) + error_msg_suffix
raise ValueError(msg)