mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 13:36:15 +00:00
feat(core): allow custom Mermaid URL (#32831)
- **Description:** Currently, `langchain_core.runnables.graph_mermaid.py` is hardcoded to use mermaid.ink to render graph diagrams. It would be nice to allow users to specify a custom URL, e.g. for self-hosted instances of the Mermaid server. - **Issue:** [Langchain Forum: allow custom mermaid API URL](https://forum.langchain.com/t/feature-request-allow-custom-mermaid-api-url/1472) - **Dependencies:** None - [X] **Add tests and docs**: Added unit tests using mock requests. - [X] **Lint and test**: Run `make format`, `make lint` and `make test`. Minimal example using the feature: ```python import os import operator from pathlib import Path from typing import Any, Annotated, TypedDict from langgraph.graph import StateGraph class State(TypedDict): messages: Annotated[list[dict[str, Any]], operator.add] def hello_node(state: State) -> State: return {"messages": [{"role": "assistant", "content": "pong!"}]} builder = StateGraph(State) builder.add_node("hello_node", hello_node) builder.add_edge("__start__", "hello_node") builder.add_edge("hello_node", "__end__") graph = builder.compile() # Run graph output = graph.invoke({"messages": [{"role": "user", "content": "ping?"}]}) # Draw graph Path("graph.png").write_bytes(graph.get_graph().draw_mermaid_png(base_url="https://custom-mermaid.ink")) ``` --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -614,7 +614,6 @@ class Graph:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The Mermaid syntax string.
|
The Mermaid syntax string.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Import locally to prevent circular import
|
# Import locally to prevent circular import
|
||||||
from langchain_core.runnables.graph_mermaid import draw_mermaid # noqa: PLC0415
|
from langchain_core.runnables.graph_mermaid import draw_mermaid # noqa: PLC0415
|
||||||
@@ -648,6 +647,7 @@ class Graph:
|
|||||||
max_retries: int = 1,
|
max_retries: int = 1,
|
||||||
retry_delay: float = 1.0,
|
retry_delay: float = 1.0,
|
||||||
frontmatter_config: Optional[dict[str, Any]] = None,
|
frontmatter_config: Optional[dict[str, Any]] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Draw the graph as a PNG image using Mermaid.
|
"""Draw the graph as a PNG image using Mermaid.
|
||||||
|
|
||||||
@@ -683,6 +683,8 @@ class Graph:
|
|||||||
"themeVariables": { "primaryColor": "#e2e2e2"},
|
"themeVariables": { "primaryColor": "#e2e2e2"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
base_url: The base URL of the Mermaid server for rendering via API.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The PNG image as bytes.
|
The PNG image as bytes.
|
||||||
@@ -707,6 +709,7 @@ class Graph:
|
|||||||
padding=padding,
|
padding=padding,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
retry_delay=retry_delay,
|
retry_delay=retry_delay,
|
||||||
|
base_url=base_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -277,6 +277,7 @@ def draw_mermaid_png(
|
|||||||
padding: int = 10,
|
padding: int = 10,
|
||||||
max_retries: int = 1,
|
max_retries: int = 1,
|
||||||
retry_delay: float = 1.0,
|
retry_delay: float = 1.0,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Draws a Mermaid graph as PNG using provided syntax.
|
"""Draws a Mermaid graph as PNG using provided syntax.
|
||||||
|
|
||||||
@@ -293,6 +294,8 @@ def draw_mermaid_png(
|
|||||||
Defaults to 1.
|
Defaults to 1.
|
||||||
retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API).
|
retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API).
|
||||||
Defaults to 1.0.
|
Defaults to 1.0.
|
||||||
|
base_url (str, optional): Base URL for the Mermaid.ink API.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bytes: PNG image bytes.
|
bytes: PNG image bytes.
|
||||||
@@ -313,6 +316,7 @@ def draw_mermaid_png(
|
|||||||
background_color=background_color,
|
background_color=background_color,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
retry_delay=retry_delay,
|
retry_delay=retry_delay,
|
||||||
|
base_url=base_url,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
|
supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
|
||||||
@@ -404,8 +408,12 @@ def _render_mermaid_using_api(
|
|||||||
file_type: Optional[Literal["jpeg", "png", "webp"]] = "png",
|
file_type: Optional[Literal["jpeg", "png", "webp"]] = "png",
|
||||||
max_retries: int = 1,
|
max_retries: int = 1,
|
||||||
retry_delay: float = 1.0,
|
retry_delay: float = 1.0,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Renders Mermaid graph using the Mermaid.INK API."""
|
"""Renders Mermaid graph using the Mermaid.INK API."""
|
||||||
|
# Defaults to using the public mermaid.ink server.
|
||||||
|
base_url = base_url if base_url is not None else "https://mermaid.ink"
|
||||||
|
|
||||||
if not _HAS_REQUESTS:
|
if not _HAS_REQUESTS:
|
||||||
msg = (
|
msg = (
|
||||||
"Install the `requests` module to use the Mermaid.INK API: "
|
"Install the `requests` module to use the Mermaid.INK API: "
|
||||||
@@ -425,7 +433,7 @@ def _render_mermaid_using_api(
|
|||||||
background_color = f"!{background_color}"
|
background_color = f"!{background_color}"
|
||||||
|
|
||||||
image_url = (
|
image_url = (
|
||||||
f"https://mermaid.ink/img/{mermaid_syntax_encoded}"
|
f"{base_url}/img/{mermaid_syntax_encoded}"
|
||||||
f"?type={file_type}&bgColor={background_color}"
|
f"?type={file_type}&bgColor={background_color}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -457,7 +465,7 @@ def _render_mermaid_using_api(
|
|||||||
|
|
||||||
# For other status codes, fail immediately
|
# For other status codes, fail immediately
|
||||||
msg = (
|
msg = (
|
||||||
"Failed to reach https://mermaid.ink/ API while trying to render "
|
f"Failed to reach {base_url} API while trying to render "
|
||||||
f"your graph. Status code: {response.status_code}.\n\n"
|
f"your graph. Status code: {response.status_code}.\n\n"
|
||||||
) + error_msg_suffix
|
) + error_msg_suffix
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@@ -469,14 +477,14 @@ def _render_mermaid_using_api(
|
|||||||
time.sleep(sleep_time)
|
time.sleep(sleep_time)
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
"Failed to reach https://mermaid.ink/ API while trying to render "
|
f"Failed to reach {base_url} API while trying to render "
|
||||||
f"your graph after {max_retries} retries. "
|
f"your graph after {max_retries} retries. "
|
||||||
) + error_msg_suffix
|
) + error_msg_suffix
|
||||||
raise ValueError(msg) from e
|
raise ValueError(msg) from e
|
||||||
|
|
||||||
# This should not be reached, but just in case
|
# This should not be reached, but just in case
|
||||||
msg = (
|
msg = (
|
||||||
"Failed to reach https://mermaid.ink/ API while trying to render "
|
f"Failed to reach {base_url} API while trying to render "
|
||||||
f"your graph after {max_retries} retries. "
|
f"your graph after {max_retries} retries. "
|
||||||
) + error_msg_suffix
|
) + error_msg_suffix
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -12,8 +13,12 @@ from langchain_core.output_parsers.xml import XMLOutputParser
|
|||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from langchain_core.runnables.base import Runnable
|
from langchain_core.runnables.base import Runnable
|
||||||
from langchain_core.runnables.graph import Edge, Graph, Node
|
from langchain_core.runnables.graph import Edge, Graph, MermaidDrawMethod, Node
|
||||||
from langchain_core.runnables.graph_mermaid import _escape_node_label
|
from langchain_core.runnables.graph_mermaid import (
|
||||||
|
_escape_node_label,
|
||||||
|
_render_mermaid_using_api,
|
||||||
|
draw_mermaid_png,
|
||||||
|
)
|
||||||
from langchain_core.utils.pydantic import PYDANTIC_VERSION
|
from langchain_core.utils.pydantic import PYDANTIC_VERSION
|
||||||
from tests.unit_tests.pydantic_utils import _normalize_schema
|
from tests.unit_tests.pydantic_utils import _normalize_schema
|
||||||
|
|
||||||
@@ -561,3 +566,90 @@ def test_graph_mermaid_frontmatter_config(snapshot: SnapshotAssertion) -> None:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
) == snapshot(name="mermaid")
|
) == snapshot(name="mermaid")
|
||||||
|
|
||||||
|
|
||||||
|
def test_mermaid_base_url_default() -> None:
|
||||||
|
"""Test that _render_mermaid_using_api defaults to mermaid.ink when None."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.content = b"fake image data"
|
||||||
|
|
||||||
|
with patch("requests.get", return_value=mock_response) as mock_get:
|
||||||
|
# Call the function with base_url=None (default)
|
||||||
|
_render_mermaid_using_api(
|
||||||
|
"graph TD;\n A --> B;",
|
||||||
|
base_url=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that the URL was constructed with the default base URL
|
||||||
|
assert mock_get.called
|
||||||
|
args, kwargs = mock_get.call_args
|
||||||
|
url = args[0] # First argument to request.get is the URL
|
||||||
|
assert url.startswith("https://mermaid.ink")
|
||||||
|
|
||||||
|
|
||||||
|
def test_mermaid_base_url_custom() -> None:
|
||||||
|
"""Test that _render_mermaid_using_api uses custom base_url when provided."""
|
||||||
|
custom_url = "https://custom.mermaid.com"
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.content = b"fake image data"
|
||||||
|
|
||||||
|
with patch("requests.get", return_value=mock_response) as mock_get:
|
||||||
|
# Call the function with custom base_url.
|
||||||
|
_render_mermaid_using_api(
|
||||||
|
"graph TD;\n A --> B;",
|
||||||
|
base_url=custom_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that the URL was constructed with our custom base URL
|
||||||
|
assert mock_get.called
|
||||||
|
args, kwargs = mock_get.call_args
|
||||||
|
url = args[0] # First argument to request.get is the URL
|
||||||
|
assert url.startswith(custom_url)
|
||||||
|
|
||||||
|
|
||||||
|
def test_draw_mermaid_png_function_base_url() -> None:
|
||||||
|
"""Test that draw_mermaid_png function passes base_url to API renderer."""
|
||||||
|
custom_url = "https://custom.mermaid.com"
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.content = b"fake image data"
|
||||||
|
|
||||||
|
with patch("requests.get", return_value=mock_response) as mock_get:
|
||||||
|
# Call draw_mermaid_png with custom base_url
|
||||||
|
draw_mermaid_png(
|
||||||
|
"graph TD;\n A --> B;",
|
||||||
|
draw_method=MermaidDrawMethod.API,
|
||||||
|
base_url=custom_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that the URL was constructed with our custom base URL
|
||||||
|
assert mock_get.called
|
||||||
|
args, kwargs = mock_get.call_args
|
||||||
|
url = args[0] # First argument to request.get is the URL
|
||||||
|
assert url.startswith(custom_url)
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_draw_mermaid_png_base_url() -> None:
|
||||||
|
"""Test that Graph.draw_mermaid_png method passes base_url to renderer."""
|
||||||
|
custom_url = "https://custom.mermaid.com"
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.content = b"fake image data"
|
||||||
|
|
||||||
|
with patch("requests.get", return_value=mock_response) as mock_get:
|
||||||
|
# Create a simple graph
|
||||||
|
graph = Graph()
|
||||||
|
start_node = graph.add_node(BaseModel, id="start")
|
||||||
|
end_node = graph.add_node(BaseModel, id="end")
|
||||||
|
graph.add_edge(start_node, end_node)
|
||||||
|
|
||||||
|
# Call draw_mermaid_png with custom base_url
|
||||||
|
graph.draw_mermaid_png(draw_method=MermaidDrawMethod.API, base_url=custom_url)
|
||||||
|
|
||||||
|
# Verify that the URL was constructed with our custom base URL
|
||||||
|
assert mock_get.called
|
||||||
|
args, kwargs = mock_get.call_args
|
||||||
|
url = args[0] # First argument to request.get is the URL
|
||||||
|
assert url.startswith(custom_url)
|
||||||
|
Reference in New Issue
Block a user