diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py
index 0298b3c72fd..48a68d58fd5 100644
--- a/libs/core/langchain_core/runnables/graph.py
+++ b/libs/core/langchain_core/runnables/graph.py
@@ -117,12 +117,12 @@ class CurveStyle(Enum):
@dataclass
-class NodeColors:
+class NodeStyles:
"""Schema for Hexadecimal color codes for different node types"""
- start: str = "#ffdfba"
- end: str = "#baffc9"
- other: str = "#fad7de"
+ default: str = "fill:#f2f0ff,line-height:1.2"
+ first: str = "fill-opacity:0"
+ last: str = "fill:#bfb6fc"
class MermaidDrawMethod(Enum):
@@ -447,9 +447,7 @@ class Graph:
*,
with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR,
- node_colors: NodeColors = NodeColors(
- start="#ffdfba", end="#baffc9", other="#fad7de"
- ),
+ node_colors: NodeStyles = NodeStyles(),
wrap_label_n_words: int = 9,
) -> str:
from langchain_core.runnables.graph_mermaid import draw_mermaid
@@ -465,7 +463,7 @@ class Graph:
last_node=last_node.id if last_node else None,
with_styles=with_styles,
curve_style=curve_style,
- node_colors=node_colors,
+ node_styles=node_colors,
wrap_label_n_words=wrap_label_n_words,
)
@@ -473,9 +471,7 @@ class Graph:
self,
*,
curve_style: CurveStyle = CurveStyle.LINEAR,
- node_colors: NodeColors = NodeColors(
- start="#ffdfba", end="#baffc9", other="#fad7de"
- ),
+ node_colors: NodeStyles = NodeStyles(),
wrap_label_n_words: int = 9,
output_file_path: Optional[str] = None,
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py
index 09741c07592..877cb109303 100644
--- a/libs/core/langchain_core/runnables/graph_mermaid.py
+++ b/libs/core/langchain_core/runnables/graph_mermaid.py
@@ -8,7 +8,7 @@ from langchain_core.runnables.graph import (
Edge,
MermaidDrawMethod,
Node,
- NodeColors,
+ NodeStyles,
)
@@ -20,7 +20,7 @@ def draw_mermaid(
last_node: Optional[str] = None,
with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR,
- node_colors: NodeColors = NodeColors(),
+ node_styles: NodeStyles = NodeStyles(),
wrap_label_n_words: int = 9,
) -> str:
"""Draws a Mermaid graph using the provided graph data
@@ -49,23 +49,27 @@ def draw_mermaid(
if with_styles:
# Node formatting templates
default_class_label = "default"
- format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
+ format_dict = {default_class_label: "{0}({1})"}
if first_node is not None:
- format_dict[first_node] = "{0}[{0}]:::startclass"
+ format_dict[first_node] = "{0}([{0}]):::first"
if last_node is not None:
- format_dict[last_node] = "{0}[{0}]:::endclass"
+ format_dict[last_node] = "{0}([{0}]):::last"
# Add nodes to the graph
for key, node in nodes.items():
label = node.name.split(":")[-1]
if node.metadata:
- label = f"{label}\n" + "\n".join(
- f"{key} = {value}" for key, value in node.metadata.items()
+ label = (
+ f"{label}
\n"
+ + "\n".join(
+ f"{key} = {value}" for key, value in node.metadata.items()
+ )
+ + ""
)
node_label = format_dict.get(key, format_dict[default_class_label]).format(
_escape_node_label(key), label
)
- mermaid_graph += f"\t{node_label};\n"
+ mermaid_graph += f"\t{node_label}\n"
subgraph = ""
# Add edges to the graph
@@ -89,16 +93,14 @@ def draw_mermaid(
words = str(edge_data).split() # Split the string into words
# Group words into chunks of wrap_label_n_words size
if len(words) > wrap_label_n_words:
- edge_data = "
".join(
- [
- " ".join(words[i : i + wrap_label_n_words])
- for i in range(0, len(words), wrap_label_n_words)
- ]
+ edge_data = " 
 ".join(
+ " ".join(words[i : i + wrap_label_n_words])
+ for i in range(0, len(words), wrap_label_n_words)
)
if edge.conditional:
- edge_label = f" -. {edge_data} .-> "
+ edge_label = f" -.  {edge_data}  .-> "
else:
- edge_label = f" -- {edge_data} --> "
+ edge_label = f" --  {edge_data}  --> "
else:
if edge.conditional:
edge_label = " -.-> "
@@ -113,7 +115,7 @@ def draw_mermaid(
# Add custom styles for nodes
if with_styles:
- mermaid_graph += _generate_mermaid_graph_styles(node_colors)
+ mermaid_graph += _generate_mermaid_graph_styles(node_styles)
return mermaid_graph
@@ -122,11 +124,11 @@ def _escape_node_label(node_label: str) -> str:
return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
-def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str:
+def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
"""Generates Mermaid graph styles for different node types."""
styles = ""
- for class_name, color in asdict(node_colors).items():
- styles += f"\tclassDef {class_name}class fill:{color};\n"
+ for class_name, style in asdict(node_colors).items():
+ styles += f"\tclassDef {class_name} {style}\n"
return styles
diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr
index 6d92219bb71..af617e3c11c 100644
--- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr
+++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr
@@ -34,19 +34,19 @@
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
- PromptInput[PromptInput]:::startclass;
- PromptTemplate([PromptTemplate]):::otherclass;
- FakeListLLM([FakeListLLM
- key = 2]):::otherclass;
- CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
- CommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass;
+ PromptInput([PromptInput]):::first
+ PromptTemplate(PromptTemplate)
+ FakeListLLM(FakeListLLM
+ key = 2)
+ CommaSeparatedListOutputParser(CommaSeparatedListOutputParser)
+ CommaSeparatedListOutputParserOutput([CommaSeparatedListOutputParserOutput]):::last
PromptInput --> PromptTemplate;
PromptTemplate --> FakeListLLM;
CommaSeparatedListOutputParser --> CommaSeparatedListOutputParserOutput;
FakeListLLM --> CommaSeparatedListOutputParser;
- classDef startclass fill:#ffdfba;
- classDef endclass fill:#baffc9;
- classDef otherclass fill:#fad7de;
+ classDef default fill:#f2f0ff,line-height:1.2
+ classDef first fill-opacity:0
+ classDef last fill:#bfb6fc
'''
# ---
@@ -1011,16 +1011,16 @@
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
- PromptInput[PromptInput]:::startclass;
- PromptTemplate([PromptTemplate]):::otherclass;
- FakeListLLM([FakeListLLM]):::otherclass;
- Parallel_as_list_as_str_Input([ParallelInput]):::otherclass;
- Parallel_as_list_as_str_Output[Parallel_as_list_as_str_Output]:::endclass;
- CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
- conditional_str_parser_input([conditional_str_parser_input]):::otherclass;
- conditional_str_parser_output([conditional_str_parser_output]):::otherclass;
- StrOutputParser([StrOutputParser]):::otherclass;
- XMLOutputParser([XMLOutputParser]):::otherclass;
+ PromptInput([PromptInput]):::first
+ PromptTemplate(PromptTemplate)
+ FakeListLLM(FakeListLLM)
+ Parallel_as_list_as_str_Input(ParallelInput)
+ Parallel_as_list_as_str_Output([Parallel_as_list_as_str_Output]):::last
+ CommaSeparatedListOutputParser(CommaSeparatedListOutputParser)
+ conditional_str_parser_input(conditional_str_parser_input)
+ conditional_str_parser_output(conditional_str_parser_output)
+ StrOutputParser(StrOutputParser)
+ XMLOutputParser(XMLOutputParser)
PromptInput --> PromptTemplate;
PromptTemplate --> FakeListLLM;
Parallel_as_list_as_str_Input --> CommaSeparatedListOutputParser;
@@ -1032,9 +1032,9 @@
Parallel_as_list_as_str_Input --> conditional_str_parser_input;
conditional_str_parser_output --> Parallel_as_list_as_str_Output;
FakeListLLM --> Parallel_as_list_as_str_Input;
- classDef startclass fill:#ffdfba;
- classDef endclass fill:#baffc9;
- classDef otherclass fill:#fad7de;
+ classDef default fill:#f2f0ff,line-height:1.2
+ classDef first fill-opacity:0
+ classDef last fill:#bfb6fc
'''
# ---
@@ -1061,14 +1061,14 @@
'''
%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
- StrOutputParserInput[StrOutputParserInput]:::startclass;
- StrOutputParser([StrOutputParser]):::otherclass;
- StrOutputParserOutput[StrOutputParserOutput]:::endclass;
+ StrOutputParserInput([StrOutputParserInput]):::first
+ StrOutputParser(StrOutputParser)
+ StrOutputParserOutput([StrOutputParserOutput]):::last
StrOutputParserInput --> StrOutputParser;
StrOutputParser --> StrOutputParserOutput;
- classDef startclass fill:#ffdfba;
- classDef endclass fill:#baffc9;
- classDef otherclass fill:#fad7de;
+ classDef default fill:#f2f0ff,line-height:1.2
+ classDef first fill-opacity:0
+ classDef last fill:#bfb6fc
'''
# ---