From 25c34bd9b226d0635db3853fa11f55303079dfab Mon Sep 17 00:00:00 2001 From: Daniel Barker Date: Wed, 10 Sep 2025 16:14:50 -0500 Subject: [PATCH] feat(core): allow custom Mermaid URL (#32831) - **Description:** Currently, `langchain_core.runnables.graph_mermaid.py` is hardcoded to use mermaid.ink to render graph diagrams. It would be nice to allow users to specify a custom URL, e.g. for self-hosted instances of the Mermaid server. - **Issue:** [Langchain Forum: allow custom mermaid API URL](https://forum.langchain.com/t/feature-request-allow-custom-mermaid-api-url/1472) - **Dependencies:** None - [X] **Add tests and docs**: Added unit tests using mock requests. - [X] **Lint and test**: Run `make format`, `make lint` and `make test`. Minimal example using the feature: ```python import os import operator from pathlib import Path from typing import Any, Annotated, TypedDict from langgraph.graph import StateGraph class State(TypedDict): messages: Annotated[list[dict[str, Any]], operator.add] def hello_node(state: State) -> State: return {"messages": [{"role": "assistant", "content": "pong!"}]} builder = StateGraph(State) builder.add_node("hello_node", hello_node) builder.add_edge("__start__", "hello_node") builder.add_edge("hello_node", "__end__") graph = builder.compile() # Run graph output = graph.invoke({"messages": [{"role": "user", "content": "ping?"}]}) # Draw graph Path("graph.png").write_bytes(graph.get_graph().draw_mermaid_png(base_url="https://custom-mermaid.ink")) ``` --------- Co-authored-by: Eugene Yurtsev --- libs/core/langchain_core/runnables/graph.py | 5 +- .../langchain_core/runnables/graph_mermaid.py | 16 +++- .../tests/unit_tests/runnables/test_graph.py | 96 ++++++++++++++++++- 3 files changed, 110 insertions(+), 7 deletions(-) diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index cebf2a667c1..22f0b8ba35d 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -614,7 +614,6 @@ class Graph: Returns: The Mermaid syntax string. - """ # Import locally to prevent circular import from langchain_core.runnables.graph_mermaid import draw_mermaid # noqa: PLC0415 @@ -648,6 +647,7 @@ class Graph: max_retries: int = 1, retry_delay: float = 1.0, frontmatter_config: Optional[dict[str, Any]] = None, + base_url: Optional[str] = None, ) -> bytes: """Draw the graph as a PNG image using Mermaid. @@ -683,6 +683,8 @@ class Graph: "themeVariables": { "primaryColor": "#e2e2e2"}, } } + base_url: The base URL of the Mermaid server for rendering via API. + Defaults to None. Returns: The PNG image as bytes. @@ -707,6 +709,7 @@ class Graph: padding=padding, max_retries=max_retries, retry_delay=retry_delay, + base_url=base_url, ) diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index df1468fc437..fe945dace4d 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -277,6 +277,7 @@ def draw_mermaid_png( padding: int = 10, max_retries: int = 1, retry_delay: float = 1.0, + base_url: Optional[str] = None, ) -> bytes: """Draws a Mermaid graph as PNG using provided syntax. @@ -293,6 +294,8 @@ def draw_mermaid_png( Defaults to 1. retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API). Defaults to 1.0. + base_url (str, optional): Base URL for the Mermaid.ink API. + Defaults to None. Returns: bytes: PNG image bytes. @@ -313,6 +316,7 @@ def draw_mermaid_png( background_color=background_color, max_retries=max_retries, retry_delay=retry_delay, + base_url=base_url, ) else: supported_methods = ", ".join([m.value for m in MermaidDrawMethod]) @@ -404,8 +408,12 @@ def _render_mermaid_using_api( file_type: Optional[Literal["jpeg", "png", "webp"]] = "png", max_retries: int = 1, retry_delay: float = 1.0, + base_url: Optional[str] = None, ) -> bytes: """Renders Mermaid graph using the Mermaid.INK API.""" + # Defaults to using the public mermaid.ink server. + base_url = base_url if base_url is not None else "https://mermaid.ink" + if not _HAS_REQUESTS: msg = ( "Install the `requests` module to use the Mermaid.INK API: " @@ -425,7 +433,7 @@ def _render_mermaid_using_api( background_color = f"!{background_color}" image_url = ( - f"https://mermaid.ink/img/{mermaid_syntax_encoded}" + f"{base_url}/img/{mermaid_syntax_encoded}" f"?type={file_type}&bgColor={background_color}" ) @@ -457,7 +465,7 @@ def _render_mermaid_using_api( # For other status codes, fail immediately msg = ( - "Failed to reach https://mermaid.ink/ API while trying to render " + f"Failed to reach {base_url} API while trying to render " f"your graph. Status code: {response.status_code}.\n\n" ) + error_msg_suffix raise ValueError(msg) @@ -469,14 +477,14 @@ def _render_mermaid_using_api( time.sleep(sleep_time) else: msg = ( - "Failed to reach https://mermaid.ink/ API while trying to render " + f"Failed to reach {base_url} API while trying to render " f"your graph after {max_retries} retries. " ) + error_msg_suffix raise ValueError(msg) from e # This should not be reached, but just in case msg = ( - "Failed to reach https://mermaid.ink/ API while trying to render " + f"Failed to reach {base_url} API while trying to render " f"your graph after {max_retries} retries. " ) + error_msg_suffix raise ValueError(msg) diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index fd9ff2f813e..398ae45e66e 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -1,4 +1,5 @@ from typing import Any, Optional +from unittest.mock import MagicMock, patch from packaging import version from pydantic import BaseModel @@ -12,8 +13,12 @@ from langchain_core.output_parsers.xml import XMLOutputParser from langchain_core.prompts.prompt import PromptTemplate from langchain_core.runnables import RunnableConfig from langchain_core.runnables.base import Runnable -from langchain_core.runnables.graph import Edge, Graph, Node -from langchain_core.runnables.graph_mermaid import _escape_node_label +from langchain_core.runnables.graph import Edge, Graph, MermaidDrawMethod, Node +from langchain_core.runnables.graph_mermaid import ( + _escape_node_label, + _render_mermaid_using_api, + draw_mermaid_png, +) from langchain_core.utils.pydantic import PYDANTIC_VERSION from tests.unit_tests.pydantic_utils import _normalize_schema @@ -561,3 +566,90 @@ def test_graph_mermaid_frontmatter_config(snapshot: SnapshotAssertion) -> None: } } ) == snapshot(name="mermaid") + + +def test_mermaid_base_url_default() -> None: + """Test that _render_mermaid_using_api defaults to mermaid.ink when None.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"fake image data" + + with patch("requests.get", return_value=mock_response) as mock_get: + # Call the function with base_url=None (default) + _render_mermaid_using_api( + "graph TD;\n A --> B;", + base_url=None, + ) + + # Verify that the URL was constructed with the default base URL + assert mock_get.called + args, kwargs = mock_get.call_args + url = args[0] # First argument to request.get is the URL + assert url.startswith("https://mermaid.ink") + + +def test_mermaid_base_url_custom() -> None: + """Test that _render_mermaid_using_api uses custom base_url when provided.""" + custom_url = "https://custom.mermaid.com" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"fake image data" + + with patch("requests.get", return_value=mock_response) as mock_get: + # Call the function with custom base_url. + _render_mermaid_using_api( + "graph TD;\n A --> B;", + base_url=custom_url, + ) + + # Verify that the URL was constructed with our custom base URL + assert mock_get.called + args, kwargs = mock_get.call_args + url = args[0] # First argument to request.get is the URL + assert url.startswith(custom_url) + + +def test_draw_mermaid_png_function_base_url() -> None: + """Test that draw_mermaid_png function passes base_url to API renderer.""" + custom_url = "https://custom.mermaid.com" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"fake image data" + + with patch("requests.get", return_value=mock_response) as mock_get: + # Call draw_mermaid_png with custom base_url + draw_mermaid_png( + "graph TD;\n A --> B;", + draw_method=MermaidDrawMethod.API, + base_url=custom_url, + ) + + # Verify that the URL was constructed with our custom base URL + assert mock_get.called + args, kwargs = mock_get.call_args + url = args[0] # First argument to request.get is the URL + assert url.startswith(custom_url) + + +def test_graph_draw_mermaid_png_base_url() -> None: + """Test that Graph.draw_mermaid_png method passes base_url to renderer.""" + custom_url = "https://custom.mermaid.com" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"fake image data" + + with patch("requests.get", return_value=mock_response) as mock_get: + # Create a simple graph + graph = Graph() + start_node = graph.add_node(BaseModel, id="start") + end_node = graph.add_node(BaseModel, id="end") + graph.add_edge(start_node, end_node) + + # Call draw_mermaid_png with custom base_url + graph.draw_mermaid_png(draw_method=MermaidDrawMethod.API, base_url=custom_url) + + # Verify that the URL was constructed with our custom base URL + assert mock_get.called + args, kwargs = mock_get.call_args + url = args[0] # First argument to request.get is the URL + assert url.startswith(custom_url)