From 03fba07d157e0244f4bd19c22a73e950bbe9aa03 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 11 Jul 2024 14:19:36 -0700 Subject: [PATCH] core[patch]: Update styles for mermaid graphs (#24147) --- libs/core/langchain_core/runnables/graph.py | 18 +++--- .../langchain_core/runnables/graph_mermaid.py | 40 ++++++------- .../runnables/__snapshots__/test_graph.ambr | 56 +++++++++---------- 3 files changed, 56 insertions(+), 58 deletions(-) diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 0298b3c72fd..48a68d58fd5 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -117,12 +117,12 @@ class CurveStyle(Enum): @dataclass -class NodeColors: +class NodeStyles: """Schema for Hexadecimal color codes for different node types""" - start: str = "#ffdfba" - end: str = "#baffc9" - other: str = "#fad7de" + default: str = "fill:#f2f0ff,line-height:1.2" + first: str = "fill-opacity:0" + last: str = "fill:#bfb6fc" class MermaidDrawMethod(Enum): @@ -447,9 +447,7 @@ class Graph: *, with_styles: bool = True, curve_style: CurveStyle = CurveStyle.LINEAR, - node_colors: NodeColors = NodeColors( - start="#ffdfba", end="#baffc9", other="#fad7de" - ), + node_colors: NodeStyles = NodeStyles(), wrap_label_n_words: int = 9, ) -> str: from langchain_core.runnables.graph_mermaid import draw_mermaid @@ -465,7 +463,7 @@ class Graph: last_node=last_node.id if last_node else None, with_styles=with_styles, curve_style=curve_style, - node_colors=node_colors, + node_styles=node_colors, wrap_label_n_words=wrap_label_n_words, ) @@ -473,9 +471,7 @@ class Graph: self, *, curve_style: CurveStyle = CurveStyle.LINEAR, - node_colors: NodeColors = NodeColors( - start="#ffdfba", end="#baffc9", other="#fad7de" - ), + node_colors: NodeStyles = NodeStyles(), wrap_label_n_words: int = 9, output_file_path: Optional[str] = None, draw_method: MermaidDrawMethod = MermaidDrawMethod.API, diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 09741c07592..877cb109303 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -8,7 +8,7 @@ from langchain_core.runnables.graph import ( Edge, MermaidDrawMethod, Node, - NodeColors, + NodeStyles, ) @@ -20,7 +20,7 @@ def draw_mermaid( last_node: Optional[str] = None, with_styles: bool = True, curve_style: CurveStyle = CurveStyle.LINEAR, - node_colors: NodeColors = NodeColors(), + node_styles: NodeStyles = NodeStyles(), wrap_label_n_words: int = 9, ) -> str: """Draws a Mermaid graph using the provided graph data @@ -49,23 +49,27 @@ def draw_mermaid( if with_styles: # Node formatting templates default_class_label = "default" - format_dict = {default_class_label: "{0}([{1}]):::otherclass"} + format_dict = {default_class_label: "{0}({1})"} if first_node is not None: - format_dict[first_node] = "{0}[{0}]:::startclass" + format_dict[first_node] = "{0}([{0}]):::first" if last_node is not None: - format_dict[last_node] = "{0}[{0}]:::endclass" + format_dict[last_node] = "{0}([{0}]):::last" # Add nodes to the graph for key, node in nodes.items(): label = node.name.split(":")[-1] if node.metadata: - label = f"{label}\n" + "\n".join( - f"{key} = {value}" for key, value in node.metadata.items() + label = ( + f"{label}
\n" + + "\n".join( + f"{key} = {value}" for key, value in node.metadata.items() + ) + + "" ) node_label = format_dict.get(key, format_dict[default_class_label]).format( _escape_node_label(key), label ) - mermaid_graph += f"\t{node_label};\n" + mermaid_graph += f"\t{node_label}\n" subgraph = "" # Add edges to the graph @@ -89,16 +93,14 @@ def draw_mermaid( words = str(edge_data).split() # Split the string into words # Group words into chunks of wrap_label_n_words size if len(words) > wrap_label_n_words: - edge_data = "
".join( - [ - " ".join(words[i : i + wrap_label_n_words]) - for i in range(0, len(words), wrap_label_n_words) - ] + edge_data = " 
 ".join( + " ".join(words[i : i + wrap_label_n_words]) + for i in range(0, len(words), wrap_label_n_words) ) if edge.conditional: - edge_label = f" -. {edge_data} .-> " + edge_label = f" -.  {edge_data}  .-> " else: - edge_label = f" -- {edge_data} --> " + edge_label = f" --  {edge_data}  --> " else: if edge.conditional: edge_label = " -.-> " @@ -113,7 +115,7 @@ def draw_mermaid( # Add custom styles for nodes if with_styles: - mermaid_graph += _generate_mermaid_graph_styles(node_colors) + mermaid_graph += _generate_mermaid_graph_styles(node_styles) return mermaid_graph @@ -122,11 +124,11 @@ def _escape_node_label(node_label: str) -> str: return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label) -def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str: +def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str: """Generates Mermaid graph styles for different node types.""" styles = "" - for class_name, color in asdict(node_colors).items(): - styles += f"\tclassDef {class_name}class fill:{color};\n" + for class_name, style in asdict(node_colors).items(): + styles += f"\tclassDef {class_name} {style}\n" return styles diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index 6d92219bb71..af617e3c11c 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -34,19 +34,19 @@ ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% graph TD; - PromptInput[PromptInput]:::startclass; - PromptTemplate([PromptTemplate]):::otherclass; - FakeListLLM([FakeListLLM - key = 2]):::otherclass; - CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass; - CommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass; + PromptInput([PromptInput]):::first + PromptTemplate(PromptTemplate) + FakeListLLM(FakeListLLM
+ key = 2) + CommaSeparatedListOutputParser(CommaSeparatedListOutputParser) + CommaSeparatedListOutputParserOutput([CommaSeparatedListOutputParserOutput]):::last PromptInput --> PromptTemplate; PromptTemplate --> FakeListLLM; CommaSeparatedListOutputParser --> CommaSeparatedListOutputParserOutput; FakeListLLM --> CommaSeparatedListOutputParser; - classDef startclass fill:#ffdfba; - classDef endclass fill:#baffc9; - classDef otherclass fill:#fad7de; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc ''' # --- @@ -1011,16 +1011,16 @@ ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% graph TD; - PromptInput[PromptInput]:::startclass; - PromptTemplate([PromptTemplate]):::otherclass; - FakeListLLM([FakeListLLM]):::otherclass; - Parallel_as_list_as_str_Input([ParallelInput]):::otherclass; - Parallel_as_list_as_str_Output[Parallel_as_list_as_str_Output]:::endclass; - CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass; - conditional_str_parser_input([conditional_str_parser_input]):::otherclass; - conditional_str_parser_output([conditional_str_parser_output]):::otherclass; - StrOutputParser([StrOutputParser]):::otherclass; - XMLOutputParser([XMLOutputParser]):::otherclass; + PromptInput([PromptInput]):::first + PromptTemplate(PromptTemplate) + FakeListLLM(FakeListLLM) + Parallel_as_list_as_str_Input(ParallelInput) + Parallel_as_list_as_str_Output([Parallel_as_list_as_str_Output]):::last + CommaSeparatedListOutputParser(CommaSeparatedListOutputParser) + conditional_str_parser_input(conditional_str_parser_input) + conditional_str_parser_output(conditional_str_parser_output) + StrOutputParser(StrOutputParser) + XMLOutputParser(XMLOutputParser) PromptInput --> PromptTemplate; PromptTemplate --> FakeListLLM; Parallel_as_list_as_str_Input --> CommaSeparatedListOutputParser; @@ -1032,9 +1032,9 @@ Parallel_as_list_as_str_Input --> conditional_str_parser_input; conditional_str_parser_output --> Parallel_as_list_as_str_Output; FakeListLLM --> Parallel_as_list_as_str_Input; - classDef startclass fill:#ffdfba; - classDef endclass fill:#baffc9; - classDef otherclass fill:#fad7de; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc ''' # --- @@ -1061,14 +1061,14 @@ ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% graph TD; - StrOutputParserInput[StrOutputParserInput]:::startclass; - StrOutputParser([StrOutputParser]):::otherclass; - StrOutputParserOutput[StrOutputParserOutput]:::endclass; + StrOutputParserInput([StrOutputParserInput]):::first + StrOutputParser(StrOutputParser) + StrOutputParserOutput([StrOutputParserOutput]):::last StrOutputParserInput --> StrOutputParser; StrOutputParser --> StrOutputParserOutput; - classDef startclass fill:#ffdfba; - classDef endclass fill:#baffc9; - classDef otherclass fill:#fad7de; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc ''' # ---