mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 13:36:15 +00:00
feat(core): allow custom Mermaid URL (#32831)
- **Description:** Currently, `langchain_core.runnables.graph_mermaid.py` is hardcoded to use mermaid.ink to render graph diagrams. It would be nice to allow users to specify a custom URL, e.g. for self-hosted instances of the Mermaid server. - **Issue:** [Langchain Forum: allow custom mermaid API URL](https://forum.langchain.com/t/feature-request-allow-custom-mermaid-api-url/1472) - **Dependencies:** None - [X] **Add tests and docs**: Added unit tests using mock requests. - [X] **Lint and test**: Run `make format`, `make lint` and `make test`. Minimal example using the feature: ```python import os import operator from pathlib import Path from typing import Any, Annotated, TypedDict from langgraph.graph import StateGraph class State(TypedDict): messages: Annotated[list[dict[str, Any]], operator.add] def hello_node(state: State) -> State: return {"messages": [{"role": "assistant", "content": "pong!"}]} builder = StateGraph(State) builder.add_node("hello_node", hello_node) builder.add_edge("__start__", "hello_node") builder.add_edge("hello_node", "__end__") graph = builder.compile() # Run graph output = graph.invoke({"messages": [{"role": "user", "content": "ping?"}]}) # Draw graph Path("graph.png").write_bytes(graph.get_graph().draw_mermaid_png(base_url="https://custom-mermaid.ink")) ``` --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -614,7 +614,6 @@ class Graph:
|
||||
|
||||
Returns:
|
||||
The Mermaid syntax string.
|
||||
|
||||
"""
|
||||
# Import locally to prevent circular import
|
||||
from langchain_core.runnables.graph_mermaid import draw_mermaid # noqa: PLC0415
|
||||
@@ -648,6 +647,7 @@ class Graph:
|
||||
max_retries: int = 1,
|
||||
retry_delay: float = 1.0,
|
||||
frontmatter_config: Optional[dict[str, Any]] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> bytes:
|
||||
"""Draw the graph as a PNG image using Mermaid.
|
||||
|
||||
@@ -683,6 +683,8 @@ class Graph:
|
||||
"themeVariables": { "primaryColor": "#e2e2e2"},
|
||||
}
|
||||
}
|
||||
base_url: The base URL of the Mermaid server for rendering via API.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
The PNG image as bytes.
|
||||
@@ -707,6 +709,7 @@ class Graph:
|
||||
padding=padding,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
|
||||
|
@@ -277,6 +277,7 @@ def draw_mermaid_png(
|
||||
padding: int = 10,
|
||||
max_retries: int = 1,
|
||||
retry_delay: float = 1.0,
|
||||
base_url: Optional[str] = None,
|
||||
) -> bytes:
|
||||
"""Draws a Mermaid graph as PNG using provided syntax.
|
||||
|
||||
@@ -293,6 +294,8 @@ def draw_mermaid_png(
|
||||
Defaults to 1.
|
||||
retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API).
|
||||
Defaults to 1.0.
|
||||
base_url (str, optional): Base URL for the Mermaid.ink API.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
bytes: PNG image bytes.
|
||||
@@ -313,6 +316,7 @@ def draw_mermaid_png(
|
||||
background_color=background_color,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
base_url=base_url,
|
||||
)
|
||||
else:
|
||||
supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
|
||||
@@ -404,8 +408,12 @@ def _render_mermaid_using_api(
|
||||
file_type: Optional[Literal["jpeg", "png", "webp"]] = "png",
|
||||
max_retries: int = 1,
|
||||
retry_delay: float = 1.0,
|
||||
base_url: Optional[str] = None,
|
||||
) -> bytes:
|
||||
"""Renders Mermaid graph using the Mermaid.INK API."""
|
||||
# Defaults to using the public mermaid.ink server.
|
||||
base_url = base_url if base_url is not None else "https://mermaid.ink"
|
||||
|
||||
if not _HAS_REQUESTS:
|
||||
msg = (
|
||||
"Install the `requests` module to use the Mermaid.INK API: "
|
||||
@@ -425,7 +433,7 @@ def _render_mermaid_using_api(
|
||||
background_color = f"!{background_color}"
|
||||
|
||||
image_url = (
|
||||
f"https://mermaid.ink/img/{mermaid_syntax_encoded}"
|
||||
f"{base_url}/img/{mermaid_syntax_encoded}"
|
||||
f"?type={file_type}&bgColor={background_color}"
|
||||
)
|
||||
|
||||
@@ -457,7 +465,7 @@ def _render_mermaid_using_api(
|
||||
|
||||
# For other status codes, fail immediately
|
||||
msg = (
|
||||
"Failed to reach https://mermaid.ink/ API while trying to render "
|
||||
f"Failed to reach {base_url} API while trying to render "
|
||||
f"your graph. Status code: {response.status_code}.\n\n"
|
||||
) + error_msg_suffix
|
||||
raise ValueError(msg)
|
||||
@@ -469,14 +477,14 @@ def _render_mermaid_using_api(
|
||||
time.sleep(sleep_time)
|
||||
else:
|
||||
msg = (
|
||||
"Failed to reach https://mermaid.ink/ API while trying to render "
|
||||
f"Failed to reach {base_url} API while trying to render "
|
||||
f"your graph after {max_retries} retries. "
|
||||
) + error_msg_suffix
|
||||
raise ValueError(msg) from e
|
||||
|
||||
# This should not be reached, but just in case
|
||||
msg = (
|
||||
"Failed to reach https://mermaid.ink/ API while trying to render "
|
||||
f"Failed to reach {base_url} API while trying to render "
|
||||
f"your graph after {max_retries} retries. "
|
||||
) + error_msg_suffix
|
||||
raise ValueError(msg)
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from packaging import version
|
||||
from pydantic import BaseModel
|
||||
@@ -12,8 +13,12 @@ from langchain_core.output_parsers.xml import XMLOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.runnables.base import Runnable
|
||||
from langchain_core.runnables.graph import Edge, Graph, Node
|
||||
from langchain_core.runnables.graph_mermaid import _escape_node_label
|
||||
from langchain_core.runnables.graph import Edge, Graph, MermaidDrawMethod, Node
|
||||
from langchain_core.runnables.graph_mermaid import (
|
||||
_escape_node_label,
|
||||
_render_mermaid_using_api,
|
||||
draw_mermaid_png,
|
||||
)
|
||||
from langchain_core.utils.pydantic import PYDANTIC_VERSION
|
||||
from tests.unit_tests.pydantic_utils import _normalize_schema
|
||||
|
||||
@@ -561,3 +566,90 @@ def test_graph_mermaid_frontmatter_config(snapshot: SnapshotAssertion) -> None:
|
||||
}
|
||||
}
|
||||
) == snapshot(name="mermaid")
|
||||
|
||||
|
||||
def test_mermaid_base_url_default() -> None:
|
||||
"""Test that _render_mermaid_using_api defaults to mermaid.ink when None."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = b"fake image data"
|
||||
|
||||
with patch("requests.get", return_value=mock_response) as mock_get:
|
||||
# Call the function with base_url=None (default)
|
||||
_render_mermaid_using_api(
|
||||
"graph TD;\n A --> B;",
|
||||
base_url=None,
|
||||
)
|
||||
|
||||
# Verify that the URL was constructed with the default base URL
|
||||
assert mock_get.called
|
||||
args, kwargs = mock_get.call_args
|
||||
url = args[0] # First argument to request.get is the URL
|
||||
assert url.startswith("https://mermaid.ink")
|
||||
|
||||
|
||||
def test_mermaid_base_url_custom() -> None:
|
||||
"""Test that _render_mermaid_using_api uses custom base_url when provided."""
|
||||
custom_url = "https://custom.mermaid.com"
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = b"fake image data"
|
||||
|
||||
with patch("requests.get", return_value=mock_response) as mock_get:
|
||||
# Call the function with custom base_url.
|
||||
_render_mermaid_using_api(
|
||||
"graph TD;\n A --> B;",
|
||||
base_url=custom_url,
|
||||
)
|
||||
|
||||
# Verify that the URL was constructed with our custom base URL
|
||||
assert mock_get.called
|
||||
args, kwargs = mock_get.call_args
|
||||
url = args[0] # First argument to request.get is the URL
|
||||
assert url.startswith(custom_url)
|
||||
|
||||
|
||||
def test_draw_mermaid_png_function_base_url() -> None:
|
||||
"""Test that draw_mermaid_png function passes base_url to API renderer."""
|
||||
custom_url = "https://custom.mermaid.com"
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = b"fake image data"
|
||||
|
||||
with patch("requests.get", return_value=mock_response) as mock_get:
|
||||
# Call draw_mermaid_png with custom base_url
|
||||
draw_mermaid_png(
|
||||
"graph TD;\n A --> B;",
|
||||
draw_method=MermaidDrawMethod.API,
|
||||
base_url=custom_url,
|
||||
)
|
||||
|
||||
# Verify that the URL was constructed with our custom base URL
|
||||
assert mock_get.called
|
||||
args, kwargs = mock_get.call_args
|
||||
url = args[0] # First argument to request.get is the URL
|
||||
assert url.startswith(custom_url)
|
||||
|
||||
|
||||
def test_graph_draw_mermaid_png_base_url() -> None:
|
||||
"""Test that Graph.draw_mermaid_png method passes base_url to renderer."""
|
||||
custom_url = "https://custom.mermaid.com"
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = b"fake image data"
|
||||
|
||||
with patch("requests.get", return_value=mock_response) as mock_get:
|
||||
# Create a simple graph
|
||||
graph = Graph()
|
||||
start_node = graph.add_node(BaseModel, id="start")
|
||||
end_node = graph.add_node(BaseModel, id="end")
|
||||
graph.add_edge(start_node, end_node)
|
||||
|
||||
# Call draw_mermaid_png with custom base_url
|
||||
graph.draw_mermaid_png(draw_method=MermaidDrawMethod.API, base_url=custom_url)
|
||||
|
||||
# Verify that the URL was constructed with our custom base URL
|
||||
assert mock_get.called
|
||||
args, kwargs = mock_get.call_args
|
||||
url = args[0] # First argument to request.get is the URL
|
||||
assert url.startswith(custom_url)
|
||||
|
Reference in New Issue
Block a user