core[patch]: Update styles for mermaid graphs (#24147)

This commit is contained in:
Nuno Campos 2024-07-11 14:19:36 -07:00 committed by GitHub
parent c481a2715d
commit 03fba07d15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 56 additions and 58 deletions

View File

@ -117,12 +117,12 @@ class CurveStyle(Enum):
@dataclass @dataclass
class NodeColors: class NodeStyles:
"""Schema for Hexadecimal color codes for different node types""" """Schema for Hexadecimal color codes for different node types"""
start: str = "#ffdfba" default: str = "fill:#f2f0ff,line-height:1.2"
end: str = "#baffc9" first: str = "fill-opacity:0"
other: str = "#fad7de" last: str = "fill:#bfb6fc"
class MermaidDrawMethod(Enum): class MermaidDrawMethod(Enum):
@ -447,9 +447,7 @@ class Graph:
*, *,
with_styles: bool = True, with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR, curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors( node_colors: NodeStyles = NodeStyles(),
start="#ffdfba", end="#baffc9", other="#fad7de"
),
wrap_label_n_words: int = 9, wrap_label_n_words: int = 9,
) -> str: ) -> str:
from langchain_core.runnables.graph_mermaid import draw_mermaid from langchain_core.runnables.graph_mermaid import draw_mermaid
@ -465,7 +463,7 @@ class Graph:
last_node=last_node.id if last_node else None, last_node=last_node.id if last_node else None,
with_styles=with_styles, with_styles=with_styles,
curve_style=curve_style, curve_style=curve_style,
node_colors=node_colors, node_styles=node_colors,
wrap_label_n_words=wrap_label_n_words, wrap_label_n_words=wrap_label_n_words,
) )
@ -473,9 +471,7 @@ class Graph:
self, self,
*, *,
curve_style: CurveStyle = CurveStyle.LINEAR, curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors( node_colors: NodeStyles = NodeStyles(),
start="#ffdfba", end="#baffc9", other="#fad7de"
),
wrap_label_n_words: int = 9, wrap_label_n_words: int = 9,
output_file_path: Optional[str] = None, output_file_path: Optional[str] = None,
draw_method: MermaidDrawMethod = MermaidDrawMethod.API, draw_method: MermaidDrawMethod = MermaidDrawMethod.API,

View File

@ -8,7 +8,7 @@ from langchain_core.runnables.graph import (
Edge, Edge,
MermaidDrawMethod, MermaidDrawMethod,
Node, Node,
NodeColors, NodeStyles,
) )
@ -20,7 +20,7 @@ def draw_mermaid(
last_node: Optional[str] = None, last_node: Optional[str] = None,
with_styles: bool = True, with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR, curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(), node_styles: NodeStyles = NodeStyles(),
wrap_label_n_words: int = 9, wrap_label_n_words: int = 9,
) -> str: ) -> str:
"""Draws a Mermaid graph using the provided graph data """Draws a Mermaid graph using the provided graph data
@ -49,23 +49,27 @@ def draw_mermaid(
if with_styles: if with_styles:
# Node formatting templates # Node formatting templates
default_class_label = "default" default_class_label = "default"
format_dict = {default_class_label: "{0}([{1}]):::otherclass"} format_dict = {default_class_label: "{0}({1})"}
if first_node is not None: if first_node is not None:
format_dict[first_node] = "{0}[{0}]:::startclass" format_dict[first_node] = "{0}([{0}]):::first"
if last_node is not None: if last_node is not None:
format_dict[last_node] = "{0}[{0}]:::endclass" format_dict[last_node] = "{0}([{0}]):::last"
# Add nodes to the graph # Add nodes to the graph
for key, node in nodes.items(): for key, node in nodes.items():
label = node.name.split(":")[-1] label = node.name.split(":")[-1]
if node.metadata: if node.metadata:
label = f"<strong>{label}</strong>\n" + "\n".join( label = (
f"{key} = {value}" for key, value in node.metadata.items() f"{label}<hr/>\n<small><em>"
+ "\n".join(
f"{key} = {value}" for key, value in node.metadata.items()
)
+ "</em></small>"
) )
node_label = format_dict.get(key, format_dict[default_class_label]).format( node_label = format_dict.get(key, format_dict[default_class_label]).format(
_escape_node_label(key), label _escape_node_label(key), label
) )
mermaid_graph += f"\t{node_label};\n" mermaid_graph += f"\t{node_label}\n"
subgraph = "" subgraph = ""
# Add edges to the graph # Add edges to the graph
@ -89,16 +93,14 @@ def draw_mermaid(
words = str(edge_data).split() # Split the string into words words = str(edge_data).split() # Split the string into words
# Group words into chunks of wrap_label_n_words size # Group words into chunks of wrap_label_n_words size
if len(words) > wrap_label_n_words: if len(words) > wrap_label_n_words:
edge_data = "<br>".join( edge_data = "&nbsp<br>&nbsp".join(
[ " ".join(words[i : i + wrap_label_n_words])
" ".join(words[i : i + wrap_label_n_words]) for i in range(0, len(words), wrap_label_n_words)
for i in range(0, len(words), wrap_label_n_words)
]
) )
if edge.conditional: if edge.conditional:
edge_label = f" -. {edge_data} .-> " edge_label = f" -. &nbsp{edge_data}&nbsp .-> "
else: else:
edge_label = f" -- {edge_data} --> " edge_label = f" -- &nbsp{edge_data}&nbsp --> "
else: else:
if edge.conditional: if edge.conditional:
edge_label = " -.-> " edge_label = " -.-> "
@ -113,7 +115,7 @@ def draw_mermaid(
# Add custom styles for nodes # Add custom styles for nodes
if with_styles: if with_styles:
mermaid_graph += _generate_mermaid_graph_styles(node_colors) mermaid_graph += _generate_mermaid_graph_styles(node_styles)
return mermaid_graph return mermaid_graph
@ -122,11 +124,11 @@ def _escape_node_label(node_label: str) -> str:
return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label) return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str: def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
"""Generates Mermaid graph styles for different node types.""" """Generates Mermaid graph styles for different node types."""
styles = "" styles = ""
for class_name, color in asdict(node_colors).items(): for class_name, style in asdict(node_colors).items():
styles += f"\tclassDef {class_name}class fill:{color};\n" styles += f"\tclassDef {class_name} {style}\n"
return styles return styles

View File

@ -34,19 +34,19 @@
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% %%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD; graph TD;
PromptInput[PromptInput]:::startclass; PromptInput([PromptInput]):::first
PromptTemplate([PromptTemplate]):::otherclass; PromptTemplate(PromptTemplate)
FakeListLLM([<strong>FakeListLLM</strong> FakeListLLM(FakeListLLM<hr/>
key = 2]):::otherclass; <small><em>key = 2</em></small>)
CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass; CommaSeparatedListOutputParser(CommaSeparatedListOutputParser)
CommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass; CommaSeparatedListOutputParserOutput([CommaSeparatedListOutputParserOutput]):::last
PromptInput --> PromptTemplate; PromptInput --> PromptTemplate;
PromptTemplate --> FakeListLLM; PromptTemplate --> FakeListLLM;
CommaSeparatedListOutputParser --> CommaSeparatedListOutputParserOutput; CommaSeparatedListOutputParser --> CommaSeparatedListOutputParserOutput;
FakeListLLM --> CommaSeparatedListOutputParser; FakeListLLM --> CommaSeparatedListOutputParser;
classDef startclass fill:#ffdfba; classDef default fill:#f2f0ff,line-height:1.2
classDef endclass fill:#baffc9; classDef first fill-opacity:0
classDef otherclass fill:#fad7de; classDef last fill:#bfb6fc
''' '''
# --- # ---
@ -1011,16 +1011,16 @@
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% %%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD; graph TD;
PromptInput[PromptInput]:::startclass; PromptInput([PromptInput]):::first
PromptTemplate([PromptTemplate]):::otherclass; PromptTemplate(PromptTemplate)
FakeListLLM([FakeListLLM]):::otherclass; FakeListLLM(FakeListLLM)
Parallel_as_list_as_str_Input([Parallel<as_list,as_str>Input]):::otherclass; Parallel_as_list_as_str_Input(Parallel<as_list,as_str>Input)
Parallel_as_list_as_str_Output[Parallel_as_list_as_str_Output]:::endclass; Parallel_as_list_as_str_Output([Parallel_as_list_as_str_Output]):::last
CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass; CommaSeparatedListOutputParser(CommaSeparatedListOutputParser)
conditional_str_parser_input([conditional_str_parser_input]):::otherclass; conditional_str_parser_input(conditional_str_parser_input)
conditional_str_parser_output([conditional_str_parser_output]):::otherclass; conditional_str_parser_output(conditional_str_parser_output)
StrOutputParser([StrOutputParser]):::otherclass; StrOutputParser(StrOutputParser)
XMLOutputParser([XMLOutputParser]):::otherclass; XMLOutputParser(XMLOutputParser)
PromptInput --> PromptTemplate; PromptInput --> PromptTemplate;
PromptTemplate --> FakeListLLM; PromptTemplate --> FakeListLLM;
Parallel_as_list_as_str_Input --> CommaSeparatedListOutputParser; Parallel_as_list_as_str_Input --> CommaSeparatedListOutputParser;
@ -1032,9 +1032,9 @@
Parallel_as_list_as_str_Input --> conditional_str_parser_input; Parallel_as_list_as_str_Input --> conditional_str_parser_input;
conditional_str_parser_output --> Parallel_as_list_as_str_Output; conditional_str_parser_output --> Parallel_as_list_as_str_Output;
FakeListLLM --> Parallel_as_list_as_str_Input; FakeListLLM --> Parallel_as_list_as_str_Input;
classDef startclass fill:#ffdfba; classDef default fill:#f2f0ff,line-height:1.2
classDef endclass fill:#baffc9; classDef first fill-opacity:0
classDef otherclass fill:#fad7de; classDef last fill:#bfb6fc
''' '''
# --- # ---
@ -1061,14 +1061,14 @@
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% %%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD; graph TD;
StrOutputParserInput[StrOutputParserInput]:::startclass; StrOutputParserInput([StrOutputParserInput]):::first
StrOutputParser([StrOutputParser]):::otherclass; StrOutputParser(StrOutputParser)
StrOutputParserOutput[StrOutputParserOutput]:::endclass; StrOutputParserOutput([StrOutputParserOutput]):::last
StrOutputParserInput --> StrOutputParser; StrOutputParserInput --> StrOutputParser;
StrOutputParser --> StrOutputParserOutput; StrOutputParser --> StrOutputParserOutput;
classDef startclass fill:#ffdfba; classDef default fill:#f2f0ff,line-height:1.2
classDef endclass fill:#baffc9; classDef first fill-opacity:0
classDef otherclass fill:#fad7de; classDef last fill:#bfb6fc
''' '''
# --- # ---