diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 557a2028c38..5c7a8e9985d 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -2,9 +2,11 @@ from __future__ import annotations import inspect from dataclasses import dataclass, field +from enum import Enum from typing import ( TYPE_CHECKING, Any, + Callable, Dict, List, NamedTuple, @@ -51,6 +53,46 @@ class Node(NamedTuple): data: Union[Type[BaseModel], RunnableType] +class Branch(NamedTuple): + """Branch in a graph.""" + + condition: Callable[..., str] + ends: Optional[dict[str, str]] + + +class CurveStyle(Enum): + """Enum for different curve styles supported by Mermaid""" + + BASIS = "basis" + BUMP_X = "bumpX" + BUMP_Y = "bumpY" + CARDINAL = "cardinal" + CATMULL_ROM = "catmullRom" + LINEAR = "linear" + MONOTONE_X = "monotoneX" + MONOTONE_Y = "monotoneY" + NATURAL = "natural" + STEP = "step" + STEP_AFTER = "stepAfter" + STEP_BEFORE = "stepBefore" + + +@dataclass +class NodeColors: + """Schema for Hexadecimal color codes for different node types""" + + start: str = "#ffdfba" + end: str = "#baffc9" + other: str = "#fad7de" + + +class MermaidDrawMethod(Enum): + """Enum for different draw methods supported by Mermaid""" + + PYPPETEER = "pyppeteer" # Uses Pyppeteer to render the graph + API = "api" # Uses Mermaid.INK API to render the graph + + def node_data_str(node: Node) -> str: from langchain_core.runnables.base import Runnable @@ -112,6 +154,7 @@ class Graph: nodes: Dict[str, Node] = field(default_factory=dict) edges: List[Edge] = field(default_factory=list) + branches: Optional[Dict[str, List[Branch]]] = field(default_factory=dict) def to_json(self) -> Dict[str, List[Dict[str, Any]]]: """Convert the graph to a JSON-serializable format.""" @@ -277,3 +320,59 @@ class Graph: edges=labels["edges"] if labels is not None else {}, ), ).draw(self, output_file_path) + + def draw_mermaid( + self, + curve_style: CurveStyle = CurveStyle.LINEAR, + node_colors: NodeColors = NodeColors( + start="#ffdfba", end="#baffc9", other="#fad7de" + ), + wrap_label_n_words: int = 9, + ) -> str: + from langchain_core.runnables.graph_mermaid import draw_mermaid + + nodes = {node.id: node_data_str(node) for node in self.nodes.values()} + + first_node = self.first_node() + first_label = node_data_str(first_node) if first_node is not None else None + + last_node = self.last_node() + last_label = node_data_str(last_node) if last_node is not None else None + + return draw_mermaid( + nodes=nodes, + edges=self.edges, + branches=self.branches, + first_node_label=first_label, + last_node_label=last_label, + curve_style=curve_style, + node_colors=node_colors, + wrap_label_n_words=wrap_label_n_words, + ) + + def draw_mermaid_png( + self, + curve_style: CurveStyle = CurveStyle.LINEAR, + node_colors: NodeColors = NodeColors( + start="#ffdfba", end="#baffc9", other="#fad7de" + ), + wrap_label_n_words: int = 9, + output_file_path: str = "graph.png", + draw_method: MermaidDrawMethod = MermaidDrawMethod.API, + background_color: str = "white", + padding: int = 10, + ) -> None: + from langchain_core.runnables.graph_mermaid import draw_mermaid_png + + mermaid_syntax = self.draw_mermaid( + curve_style=curve_style, + node_colors=node_colors, + wrap_label_n_words=wrap_label_n_words, + ) + draw_mermaid_png( + mermaid_syntax=mermaid_syntax, + output_file_path=output_file_path, + draw_method=draw_method, + background_color=background_color, + padding=padding, + ) diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py new file mode 100644 index 00000000000..ad7f012bdb9 --- /dev/null +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -0,0 +1,292 @@ +import base64 +import re +from dataclasses import asdict +from typing import Dict, List, Optional, Tuple + +from langchain_core.runnables.graph import ( + Branch, + CurveStyle, + Edge, + MermaidDrawMethod, + NodeColors, +) + + +def draw_mermaid( + nodes: Dict[str, str], + edges: List[Edge], + branches: Optional[Dict[str, List[Branch]]] = None, + first_node_label: Optional[str] = None, + last_node_label: Optional[str] = None, + curve_style: CurveStyle = CurveStyle.LINEAR, + node_colors: NodeColors = NodeColors(), + wrap_label_n_words: int = 9, +) -> str: + """Draws a Mermaid graph using the provided graph data + + Args: + nodes (dict[str, str]): List of node ids + edges (List[Edge]): List of edges, object with source, + target and data. + branches (defaultdict[str, list[Branch]]): Branches for the graph ( + in case of langgraph) to remove intermediate condition nodes. + curve_style (CurveStyle, optional): Curve style for the edges. + node_colors (NodeColors, optional): Node colors for different types. + wrap_label_n_words (int, optional): Words to wrap the edge labels. + + Returns: + str: Mermaid graph syntax + """ + # Initialize Mermaid graph configuration + mermaid_graph = ( + f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'" + f"}}}}}}%%\ngraph TD;\n" + ) + + # Node formatting templates + default_class_label = "default" + format_dict = {default_class_label: "{0}([{0}]):::otherclass"} + if first_node_label is not None: + format_dict[first_node_label] = "{0}[{0}]:::startclass" + if last_node_label is not None: + format_dict[last_node_label] = "{0}[{0}]:::endclass" + + # Filter out nodes that were created due to conditional edges + # Remove combinations where node name is the same as a branch + condition + mapping_intermediate_node_pure_node = {} + if branches is not None: + for agent, agent_branches in branches.items(): + for branch in agent_branches: + condition_name = branch.condition.__name__ + intermediate_node_label = f"{agent}_{condition_name}" + if intermediate_node_label in nodes: + mapping_intermediate_node_pure_node[intermediate_node_label] = agent + + # Not intermediate nodes + pure_nodes = { + id: value + for id, value in nodes.items() + if value not in mapping_intermediate_node_pure_node.keys() + } + + # Add __end__ node if it is in any of the edges.target + if any("__end__" in edge.target for edge in edges): + pure_nodes["__end__"] = "__end__" + + # Add nodes to the graph + for node in pure_nodes.values(): + node_label = format_dict.get(node, format_dict[default_class_label]).format( + _escape_node_label(node) + ) + mermaid_graph += f"\t{node_label};\n" + + # Add edges to the graph + for edge in edges: + adjusted_edge = _adjust_mermaid_edge( + edge, nodes, mapping_intermediate_node_pure_node + ) + if ( + adjusted_edge is None + ): # Ignore if it is connection between source and intermediate node + continue + + source, target = adjusted_edge + + # Add BR every wrap_label_n_words words + if edge.data is not None: + edge_data = edge.data + words = edge_data.split() # Split the string into words + # Group words into chunks of wrap_label_n_words size + if len(words) > wrap_label_n_words: + edge_data = "
".join( + [ + " ".join(words[i : i + wrap_label_n_words]) + for i in range(0, len(words), wrap_label_n_words) + ] + ) + edge_label = f" -- {edge_data} --> " + else: + edge_label = " --> " + mermaid_graph += ( + f"\t{_escape_node_label(source)}{edge_label}" + f"{_escape_node_label(target)};\n" + ) + + # Add custom styles for nodes + mermaid_graph += _generate_mermaid_graph_styles(node_colors) + return mermaid_graph + + +def _escape_node_label(node_label: str) -> str: + """Escapes the node label for Mermaid syntax.""" + return re.sub(r"[^a-zA-Z-_]", "_", node_label) + + +def _adjust_mermaid_edge( + edge: Edge, + nodes: Dict[str, str], + mapping_intermediate_node_pure_node: Dict[str, str], +) -> Optional[Tuple[str, str]]: + """Adjusts Mermaid edge to map conditional nodes to pure nodes.""" + source_node_label = nodes.get(edge.source, edge.source) + target_node_label = nodes.get(edge.target, edge.target) + + # Remove nodes between source node to intermediate node + if target_node_label in mapping_intermediate_node_pure_node.keys(): + return None + + # Replace intermediate nodes by source nodes + if source_node_label in mapping_intermediate_node_pure_node.keys(): + source_node_label = mapping_intermediate_node_pure_node[source_node_label] + + return source_node_label, target_node_label + + +def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str: + """Generates Mermaid graph styles for different node types.""" + styles = "" + for class_name, color in asdict(node_colors).items(): + styles += f"\tclassDef {class_name}class fill:{color};\n" + return styles + + +def draw_mermaid_png( + mermaid_syntax: str, + output_file_path: Optional[str] = None, + draw_method: MermaidDrawMethod = MermaidDrawMethod.API, + background_color: Optional[str] = "white", + padding: int = 10, +) -> bytes: + """Draws a Mermaid graph as PNG using provided syntax.""" + if draw_method == MermaidDrawMethod.PYPPETEER: + import asyncio + + img_bytes = asyncio.run( + _render_mermaid_using_pyppeteer( + mermaid_syntax, output_file_path, background_color, padding + ) + ) + elif draw_method == MermaidDrawMethod.API: + img_bytes = _render_mermaid_using_api( + mermaid_syntax, output_file_path, background_color + ) + else: + supported_methods = ", ".join([m.value for m in MermaidDrawMethod]) + raise ValueError( + f"Invalid draw method: {draw_method}. " + f"Supported draw methods are: {supported_methods}" + ) + + return img_bytes + + +async def _render_mermaid_using_pyppeteer( + mermaid_syntax: str, + output_file_path: Optional[str] = None, + background_color: Optional[str] = "white", + padding: int = 10, +) -> bytes: + """Renders Mermaid graph using Pyppeteer.""" + try: + from pyppeteer import launch # type: ignore[import] + except ImportError as e: + raise ImportError( + "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`." + ) from e + + browser = await launch() + page = await browser.newPage() + + # Setup Mermaid JS + await page.goto("about:blank") + await page.addScriptTag({"url": "https://unpkg.com/mermaid/dist/mermaid.min.js"}) + await page.evaluate( + """() => { + mermaid.initialize({startOnLoad:true}); + }""" + ) + + # Render SVG + svg_code = await page.evaluate( + """(mermaidGraph) => { + return mermaid.mermaidAPI.render('mermaid', mermaidGraph); + }""", + mermaid_syntax, + ) + + # Set the page background to white + await page.evaluate( + """(svg, background_color) => { + document.body.innerHTML = svg; + document.body.style.background = background_color; + }""", + svg_code["svg"], + background_color, + ) + + # Take a screenshot + dimensions = await page.evaluate( + """() => { + const svgElement = document.querySelector('svg'); + const rect = svgElement.getBoundingClientRect(); + return { width: rect.width, height: rect.height }; + }""" + ) + await page.setViewport( + { + "width": int(dimensions["width"] + padding), + "height": int(dimensions["height"] + padding), + } + ) + + img_bytes = await page.screenshot({"fullPage": False}) + await browser.close() + + if output_file_path is not None: + with open(output_file_path, "wb") as file: + file.write(img_bytes) + + return img_bytes + + +def _render_mermaid_using_api( + mermaid_syntax: str, + output_file_path: Optional[str] = None, + background_color: Optional[str] = "white", +) -> bytes: + """Renders Mermaid graph using the Mermaid.INK API.""" + try: + import requests # type: ignore[import] + except ImportError as e: + raise ImportError( + "Install the `requests` module to use the Mermaid.INK API: " + "`pip install requests`." + ) from e + + # Use Mermaid API to render the image + mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode( + "ascii" + ) + + # Check if the background color is a hexadecimal color code using regex + if background_color is not None: + hex_color_pattern = re.compile(r"^#(?:[0-9a-fA-F]{3}){1,2}$") + if not hex_color_pattern.match(background_color): + background_color = f"!{background_color}" + + image_url = ( + f"https://mermaid.ink/img/{mermaid_syntax_encoded}?bgColor={background_color}" + ) + response = requests.get(image_url) + if response.status_code == 200: + img_bytes = response.content + if output_file_path is not None: + with open(output_file_path, "wb") as file: + file.write(response.content) + + return img_bytes + else: + raise ValueError( + f"Failed to render the graph using the Mermaid.INK API. " + f"Status code: {response.status_code}." + ) diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index a76baa67871..b887d8d1f14 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -1,5 +1,5 @@ # serializer version: 1 -# name: test_graph_sequence +# name: test_graph_sequence[ascii] ''' +-------------+ | PromptInput | @@ -30,7 +30,26 @@ +--------------------------------------+ ''' # --- -# name: test_graph_sequence_map +# name: test_graph_sequence[mermaid] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + PromptInput[PromptInput]:::startclass; + PromptTemplate([PromptTemplate]):::otherclass; + FakeListLLM([FakeListLLM]):::otherclass; + CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass; + CommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass; + PromptInput --> PromptTemplate; + PromptTemplate --> FakeListLLM; + CommaSeparatedListOutputParser --> CommaSeparatedListOutputParserOutput; + FakeListLLM --> CommaSeparatedListOutputParser; + classDef startclass fill:#ffdfba; + classDef endclass fill:#baffc9; + classDef otherclass fill:#fad7de; + + ''' +# --- +# name: test_graph_sequence_map[ascii] ''' +-------------+ | PromptInput | @@ -79,7 +98,38 @@ +--------------------------------+ ''' # --- -# name: test_graph_single_runnable +# name: test_graph_sequence_map[mermaid] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + PromptInput[PromptInput]:::startclass; + PromptTemplate([PromptTemplate]):::otherclass; + FakeListLLM([FakeListLLM]):::otherclass; + Parallel_as_list_as_str_Input([Parallel_as_list_as_str_Input]):::otherclass; + Parallel_as_list_as_str_Output[Parallel_as_list_as_str_Output]:::endclass; + CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass; + conditional_str_parser_input([conditional_str_parser_input]):::otherclass; + conditional_str_parser_output([conditional_str_parser_output]):::otherclass; + StrOutputParser([StrOutputParser]):::otherclass; + XMLOutputParser([XMLOutputParser]):::otherclass; + PromptInput --> PromptTemplate; + PromptTemplate --> FakeListLLM; + Parallel_as_list_as_str_Input --> CommaSeparatedListOutputParser; + CommaSeparatedListOutputParser --> Parallel_as_list_as_str_Output; + conditional_str_parser_input --> StrOutputParser; + StrOutputParser --> conditional_str_parser_output; + conditional_str_parser_input --> XMLOutputParser; + XMLOutputParser --> conditional_str_parser_output; + Parallel_as_list_as_str_Input --> conditional_str_parser_input; + conditional_str_parser_output --> Parallel_as_list_as_str_Output; + FakeListLLM --> Parallel_as_list_as_str_Input; + classDef startclass fill:#ffdfba; + classDef endclass fill:#baffc9; + classDef otherclass fill:#fad7de; + + ''' +# --- +# name: test_graph_single_runnable[ascii] ''' +----------------------+ | StrOutputParserInput | @@ -98,3 +148,18 @@ +-----------------------+ ''' # --- +# name: test_graph_single_runnable[mermaid] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + StrOutputParserInput[StrOutputParserInput]:::startclass; + StrOutputParser([StrOutputParser]):::otherclass; + StrOutputParserOutput[StrOutputParserOutput]:::endclass; + StrOutputParserInput --> StrOutputParser; + StrOutputParser --> StrOutputParserOutput; + classDef startclass fill:#ffdfba; + classDef endclass fill:#baffc9; + classDef otherclass fill:#fad7de; + + ''' +# --- diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 40361b2b919..3bee036acc5 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -21,7 +21,8 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None: assert len(graph.edges) == 2 assert graph.edges[0].source == first_node.id assert graph.edges[1].target == last_node.id - assert graph.draw_ascii() == snapshot + assert graph.draw_ascii() == snapshot(name="ascii") + assert graph.draw_mermaid() == snapshot(name="mermaid") def test_graph_sequence(snapshot: SnapshotAssertion) -> None: @@ -88,7 +89,8 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None: {"source": 2, "target": 3}, ], } - assert graph.draw_ascii() == snapshot + assert graph.draw_ascii() == snapshot(name="ascii") + assert graph.draw_mermaid() == snapshot(name="mermaid") def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: @@ -482,4 +484,5 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: {"source": 2, "target": 3}, ], } - assert graph.draw_ascii() == snapshot + assert graph.draw_ascii() == snapshot(name="ascii") + assert graph.draw_mermaid() == snapshot(name="mermaid")