diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index b0e7f3dd5de..c9b3025bc85 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -2,7 +2,7 @@ import asyncio import base64 import re from dataclasses import asdict -from typing import Optional +from typing import Literal, Optional from langchain_core.runnables.graph import ( CurveStyle, @@ -306,6 +306,7 @@ def _render_mermaid_using_api( mermaid_syntax: str, output_file_path: Optional[str] = None, background_color: Optional[str] = "white", + file_type: Optional[Literal["jpeg", "png", "webp"]] = "png", ) -> bytes: """Renders Mermaid graph using the Mermaid.INK API.""" try: @@ -329,7 +330,8 @@ def _render_mermaid_using_api( background_color = f"!{background_color}" image_url = ( - f"https://mermaid.ink/img/{mermaid_syntax_encoded}?bgColor={background_color}" + f"https://mermaid.ink/img/{mermaid_syntax_encoded}" + f"?type={file_type}&bgColor={background_color}" ) response = requests.get(image_url, timeout=10) if response.status_code == 200: