From 806a54908cfcf16be0d9d3f9a8e04068e2d63062 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 16 Apr 2024 13:17:47 -0700 Subject: [PATCH] Runnable graph viz improvements (#20529) - Add conditional: bool property to json representation of the graphs - Add option to generate mermaid graph stripped of styles (useful as a text representation of graph) --- libs/core/langchain_core/runnables/graph.py | 27 ++++++------- .../langchain_core/runnables/graph_mermaid.py | 39 +++++++++++-------- .../runnables/__snapshots__/test_graph.ambr | 17 ++++++++ .../tests/unit_tests/runnables/test_graph.py | 1 + 4 files changed, 55 insertions(+), 29 deletions(-) diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 97fbf00ed9a..f92b4b064a2 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -170,6 +170,17 @@ class Graph: node.id: i if is_uuid(node.id) else node.id for i, node in enumerate(self.nodes.values()) } + edges: List[Dict[str, Any]] = [] + for edge in self.edges: + edge_dict = { + "source": stable_node_ids[edge.source], + "target": stable_node_ids[edge.target], + } + if edge.data is not None: + edge_dict["data"] = edge.data + if edge.conditional: + edge_dict["conditional"] = True + edges.append(edge_dict) return { "nodes": [ @@ -179,19 +190,7 @@ class Graph: } for node in self.nodes.values() ], - "edges": [ - { - "source": stable_node_ids[edge.source], - "target": stable_node_ids[edge.target], - "data": edge.data, - } - if edge.data is not None - else { - "source": stable_node_ids[edge.source], - "target": stable_node_ids[edge.target], - } - for edge in self.edges - ], + "edges": edges, } def __bool__(self) -> bool: @@ -345,6 +344,7 @@ class Graph: def draw_mermaid( self, *, + with_styles: bool = True, curve_style: CurveStyle = CurveStyle.LINEAR, node_colors: NodeColors = NodeColors( start="#ffdfba", end="#baffc9", other="#fad7de" @@ -366,6 +366,7 @@ class Graph: edges=self.edges, first_node_label=first_label, last_node_label=last_label, + with_styles=with_styles, curve_style=curve_style, node_colors=node_colors, wrap_label_n_words=wrap_label_n_words, diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 122020a47d6..93e052d8919 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -17,6 +17,7 @@ def draw_mermaid( *, first_node_label: Optional[str] = None, last_node_label: Optional[str] = None, + with_styles: bool = True, curve_style: CurveStyle = CurveStyle.LINEAR, node_colors: NodeColors = NodeColors(), wrap_label_n_words: int = 9, @@ -36,24 +37,29 @@ def draw_mermaid( """ # Initialize Mermaid graph configuration mermaid_graph = ( - f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'" - f"}}}}}}%%\ngraph TD;\n" + ( + f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'" + f"}}}}}}%%\ngraph TD;\n" + ) + if with_styles + else "graph TD;\n" ) - # Node formatting templates - default_class_label = "default" - format_dict = {default_class_label: "{0}([{0}]):::otherclass"} - if first_node_label is not None: - format_dict[first_node_label] = "{0}[{0}]:::startclass" - if last_node_label is not None: - format_dict[last_node_label] = "{0}[{0}]:::endclass" + if with_styles: + # Node formatting templates + default_class_label = "default" + format_dict = {default_class_label: "{0}([{0}]):::otherclass"} + if first_node_label is not None: + format_dict[first_node_label] = "{0}[{0}]:::startclass" + if last_node_label is not None: + format_dict[last_node_label] = "{0}[{0}]:::endclass" - # Add nodes to the graph - for node in nodes.values(): - node_label = format_dict.get(node, format_dict[default_class_label]).format( - _escape_node_label(node) - ) - mermaid_graph += f"\t{node_label};\n" + # Add nodes to the graph + for node in nodes.values(): + node_label = format_dict.get(node, format_dict[default_class_label]).format( + _escape_node_label(node) + ) + mermaid_graph += f"\t{node_label};\n" # Add edges to the graph for edge in edges: @@ -92,7 +98,8 @@ def draw_mermaid( ) # Add custom styles for nodes - mermaid_graph += _generate_mermaid_graph_styles(node_colors) + if with_styles: + mermaid_graph += _generate_mermaid_graph_styles(node_colors) return mermaid_graph diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index b887d8d1f14..f715157c5b5 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -98,6 +98,23 @@ +--------------------------------+ ''' # --- +# name: test_graph_sequence_map[mermaid-simple] + ''' + graph TD; + PromptInput --> PromptTemplate; + PromptTemplate --> FakeListLLM; + Parallel_as_list_as_str_Input --> CommaSeparatedListOutputParser; + CommaSeparatedListOutputParser --> Parallel_as_list_as_str_Output; + conditional_str_parser_input --> StrOutputParser; + StrOutputParser --> conditional_str_parser_output; + conditional_str_parser_input --> XMLOutputParser; + XMLOutputParser --> conditional_str_parser_output; + Parallel_as_list_as_str_Input --> conditional_str_parser_input; + conditional_str_parser_output --> Parallel_as_list_as_str_Output; + FakeListLLM --> Parallel_as_list_as_str_Input; + + ''' +# --- # name: test_graph_sequence_map[mermaid] ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index fe98da71a72..8a9aa12c004 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -660,3 +660,4 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: } assert graph.draw_ascii() == snapshot(name="ascii") assert graph.draw_mermaid() == snapshot(name="mermaid") + assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid-simple")