diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index cebf2a667c1..22f0b8ba35d 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -614,7 +614,6 @@ class Graph: Returns: The Mermaid syntax string. - """ # Import locally to prevent circular import from langchain_core.runnables.graph_mermaid import draw_mermaid # noqa: PLC0415 @@ -648,6 +647,7 @@ class Graph: max_retries: int = 1, retry_delay: float = 1.0, frontmatter_config: Optional[dict[str, Any]] = None, + base_url: Optional[str] = None, ) -> bytes: """Draw the graph as a PNG image using Mermaid. @@ -683,6 +683,8 @@ class Graph: "themeVariables": { "primaryColor": "#e2e2e2"}, } } + base_url: The base URL of the Mermaid server for rendering via API. + Defaults to None. Returns: The PNG image as bytes. @@ -707,6 +709,7 @@ class Graph: padding=padding, max_retries=max_retries, retry_delay=retry_delay, + base_url=base_url, ) diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index df1468fc437..fe945dace4d 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -277,6 +277,7 @@ def draw_mermaid_png( padding: int = 10, max_retries: int = 1, retry_delay: float = 1.0, + base_url: Optional[str] = None, ) -> bytes: """Draws a Mermaid graph as PNG using provided syntax. @@ -293,6 +294,8 @@ def draw_mermaid_png( Defaults to 1. retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API). Defaults to 1.0. + base_url (str, optional): Base URL for the Mermaid.ink API. + Defaults to None. Returns: bytes: PNG image bytes. @@ -313,6 +316,7 @@ def draw_mermaid_png( background_color=background_color, max_retries=max_retries, retry_delay=retry_delay, + base_url=base_url, ) else: supported_methods = ", ".join([m.value for m in MermaidDrawMethod]) @@ -404,8 +408,12 @@ def _render_mermaid_using_api( file_type: Optional[Literal["jpeg", "png", "webp"]] = "png", max_retries: int = 1, retry_delay: float = 1.0, + base_url: Optional[str] = None, ) -> bytes: """Renders Mermaid graph using the Mermaid.INK API.""" + # Defaults to using the public mermaid.ink server. + base_url = base_url if base_url is not None else "https://mermaid.ink" + if not _HAS_REQUESTS: msg = ( "Install the `requests` module to use the Mermaid.INK API: " @@ -425,7 +433,7 @@ def _render_mermaid_using_api( background_color = f"!{background_color}" image_url = ( - f"https://mermaid.ink/img/{mermaid_syntax_encoded}" + f"{base_url}/img/{mermaid_syntax_encoded}" f"?type={file_type}&bgColor={background_color}" ) @@ -457,7 +465,7 @@ def _render_mermaid_using_api( # For other status codes, fail immediately msg = ( - "Failed to reach https://mermaid.ink/ API while trying to render " + f"Failed to reach {base_url} API while trying to render " f"your graph. Status code: {response.status_code}.\n\n" ) + error_msg_suffix raise ValueError(msg) @@ -469,14 +477,14 @@ def _render_mermaid_using_api( time.sleep(sleep_time) else: msg = ( - "Failed to reach https://mermaid.ink/ API while trying to render " + f"Failed to reach {base_url} API while trying to render " f"your graph after {max_retries} retries. " ) + error_msg_suffix raise ValueError(msg) from e # This should not be reached, but just in case msg = ( - "Failed to reach https://mermaid.ink/ API while trying to render " + f"Failed to reach {base_url} API while trying to render " f"your graph after {max_retries} retries. " ) + error_msg_suffix raise ValueError(msg) diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index fd9ff2f813e..398ae45e66e 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -1,4 +1,5 @@ from typing import Any, Optional +from unittest.mock import MagicMock, patch from packaging import version from pydantic import BaseModel @@ -12,8 +13,12 @@ from langchain_core.output_parsers.xml import XMLOutputParser from langchain_core.prompts.prompt import PromptTemplate from langchain_core.runnables import RunnableConfig from langchain_core.runnables.base import Runnable -from langchain_core.runnables.graph import Edge, Graph, Node -from langchain_core.runnables.graph_mermaid import _escape_node_label +from langchain_core.runnables.graph import Edge, Graph, MermaidDrawMethod, Node +from langchain_core.runnables.graph_mermaid import ( + _escape_node_label, + _render_mermaid_using_api, + draw_mermaid_png, +) from langchain_core.utils.pydantic import PYDANTIC_VERSION from tests.unit_tests.pydantic_utils import _normalize_schema @@ -561,3 +566,90 @@ def test_graph_mermaid_frontmatter_config(snapshot: SnapshotAssertion) -> None: } } ) == snapshot(name="mermaid") + + +def test_mermaid_base_url_default() -> None: + """Test that _render_mermaid_using_api defaults to mermaid.ink when None.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"fake image data" + + with patch("requests.get", return_value=mock_response) as mock_get: + # Call the function with base_url=None (default) + _render_mermaid_using_api( + "graph TD;\n A --> B;", + base_url=None, + ) + + # Verify that the URL was constructed with the default base URL + assert mock_get.called + args, kwargs = mock_get.call_args + url = args[0] # First argument to request.get is the URL + assert url.startswith("https://mermaid.ink") + + +def test_mermaid_base_url_custom() -> None: + """Test that _render_mermaid_using_api uses custom base_url when provided.""" + custom_url = "https://custom.mermaid.com" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"fake image data" + + with patch("requests.get", return_value=mock_response) as mock_get: + # Call the function with custom base_url. + _render_mermaid_using_api( + "graph TD;\n A --> B;", + base_url=custom_url, + ) + + # Verify that the URL was constructed with our custom base URL + assert mock_get.called + args, kwargs = mock_get.call_args + url = args[0] # First argument to request.get is the URL + assert url.startswith(custom_url) + + +def test_draw_mermaid_png_function_base_url() -> None: + """Test that draw_mermaid_png function passes base_url to API renderer.""" + custom_url = "https://custom.mermaid.com" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"fake image data" + + with patch("requests.get", return_value=mock_response) as mock_get: + # Call draw_mermaid_png with custom base_url + draw_mermaid_png( + "graph TD;\n A --> B;", + draw_method=MermaidDrawMethod.API, + base_url=custom_url, + ) + + # Verify that the URL was constructed with our custom base URL + assert mock_get.called + args, kwargs = mock_get.call_args + url = args[0] # First argument to request.get is the URL + assert url.startswith(custom_url) + + +def test_graph_draw_mermaid_png_base_url() -> None: + """Test that Graph.draw_mermaid_png method passes base_url to renderer.""" + custom_url = "https://custom.mermaid.com" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"fake image data" + + with patch("requests.get", return_value=mock_response) as mock_get: + # Create a simple graph + graph = Graph() + start_node = graph.add_node(BaseModel, id="start") + end_node = graph.add_node(BaseModel, id="end") + graph.add_edge(start_node, end_node) + + # Call draw_mermaid_png with custom base_url + graph.draw_mermaid_png(draw_method=MermaidDrawMethod.API, base_url=custom_url) + + # Verify that the URL was constructed with our custom base URL + assert mock_get.called + args, kwargs = mock_get.call_args + url = args[0] # First argument to request.get is the URL + assert url.startswith(custom_url)