diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 20997e52126..eeac8fc51ac 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -630,6 +630,8 @@ class Graph: draw_method: MermaidDrawMethod = MermaidDrawMethod.API, background_color: str = "white", padding: int = 10, + max_retries: int = 1, + retry_delay: float = 1.0, frontmatter_config: Optional[dict[str, Any]] = None, ) -> bytes: """Draw the graph as a PNG image using Mermaid. @@ -645,6 +647,10 @@ class Graph: Defaults to MermaidDrawMethod.API. background_color: The color of the background. Defaults to "white". padding: The padding around the graph. Defaults to 10. + max_retries: The maximum number of retries (MermaidDrawMethod.API). + Defaults to 1. + retry_delay: The delay between retries (MermaidDrawMethod.API). + Defaults to 1.0. frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config. Can be used to customize theme and styles. Will be converted to YAML and added to the beginning of the mermaid graph. Defaults to None. @@ -680,6 +686,8 @@ class Graph: draw_method=draw_method, background_color=background_color, padding=padding, + max_retries=max_retries, + retry_delay=retry_delay, ) diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index c47359e29a0..410c6c56652 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -2,7 +2,9 @@ import asyncio import base64 +import random import re +import time from dataclasses import asdict from pathlib import Path from typing import Any, Literal, Optional @@ -254,6 +256,8 @@ def draw_mermaid_png( draw_method: MermaidDrawMethod = MermaidDrawMethod.API, background_color: Optional[str] = "white", padding: int = 10, + max_retries: int = 1, + retry_delay: float = 1.0, ) -> bytes: """Draws a Mermaid graph as PNG using provided syntax. @@ -266,6 +270,10 @@ def draw_mermaid_png( background_color (str, optional): Background color of the image. Defaults to "white". padding (int, optional): Padding around the image. Defaults to 10. + max_retries (int, optional): Maximum number of retries (MermaidDrawMethod.API). + Defaults to 1. + retry_delay (float, optional): Delay between retries (MermaidDrawMethod.API). + Defaults to 1.0. Returns: bytes: PNG image bytes. @@ -283,7 +291,11 @@ def draw_mermaid_png( ) elif draw_method == MermaidDrawMethod.API: img_bytes = _render_mermaid_using_api( - mermaid_syntax, output_file_path, background_color + mermaid_syntax, + output_file_path=output_file_path, + background_color=background_color, + max_retries=max_retries, + retry_delay=retry_delay, ) else: supported_methods = ", ".join([m.value for m in MermaidDrawMethod]) @@ -371,9 +383,12 @@ async def _render_mermaid_using_pyppeteer( def _render_mermaid_using_api( mermaid_syntax: str, + *, output_file_path: Optional[str] = None, background_color: Optional[str] = "white", file_type: Optional[Literal["jpeg", "png", "webp"]] = "png", + max_retries: int = 1, + retry_delay: float = 1.0, ) -> bytes: """Renders Mermaid graph using the Mermaid.INK API.""" try: @@ -400,15 +415,55 @@ def _render_mermaid_using_api( f"https://mermaid.ink/img/{mermaid_syntax_encoded}" f"?type={file_type}&bgColor={background_color}" ) - response = requests.get(image_url, timeout=10) - if response.status_code == requests.codes.ok: - img_bytes = response.content - if output_file_path is not None: - Path(output_file_path).write_bytes(response.content) - return img_bytes - msg = ( - f"Failed to render the graph using the Mermaid.INK API. " - f"Status code: {response.status_code}." + error_msg_suffix = ( + "To resolve this issue:\n" + "1. Check your internet connection and try again\n" + "2. Try with higher retry settings: " + "`draw_mermaid_png(..., max_retries=5, retry_delay=2.0)`\n" + "3. Use the Pyppeteer rendering method which will render your graph locally " + "in a browser: `draw_mermaid_png(..., draw_method=MermaidDrawMethod.PYPPETEER)`" ) + + for attempt in range(max_retries + 1): + try: + response = requests.get(image_url, timeout=10) + if response.status_code == requests.codes.ok: + img_bytes = response.content + if output_file_path is not None: + Path(output_file_path).write_bytes(response.content) + + return img_bytes + + # If we get a server error (5xx), retry + if 500 <= response.status_code < 600 and attempt < max_retries: + # Exponential backoff with jitter + sleep_time = retry_delay * (2**attempt) * (0.5 + 0.5 * random.random()) # noqa: S311 not used for crypto + time.sleep(sleep_time) + continue + + # For other status codes, fail immediately + msg = ( + "Failed to reach https://mermaid.ink/ API while trying to render " + f"your graph. Status code: {response.status_code}.\n\n" + ) + error_msg_suffix + raise ValueError(msg) + + except (requests.RequestException, requests.Timeout) as e: + if attempt < max_retries: + # Exponential backoff with jitter + sleep_time = retry_delay * (2**attempt) * (0.5 + 0.5 * random.random()) # noqa: S311 not used for crypto + time.sleep(sleep_time) + else: + msg = ( + "Failed to reach https://mermaid.ink/ API while trying to render " + f"your graph after {max_retries} retries. " + ) + error_msg_suffix + raise ValueError(msg) from e + + # This should not be reached, but just in case + msg = ( + "Failed to reach https://mermaid.ink/ API while trying to render " + f"your graph after {max_retries} retries. " + ) + error_msg_suffix raise ValueError(msg)