mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
core[patch]: Update styles for mermaid graphs (#24147)
This commit is contained in:
parent
c481a2715d
commit
03fba07d15
@ -117,12 +117,12 @@ class CurveStyle(Enum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeColors:
|
||||
class NodeStyles:
|
||||
"""Schema for Hexadecimal color codes for different node types"""
|
||||
|
||||
start: str = "#ffdfba"
|
||||
end: str = "#baffc9"
|
||||
other: str = "#fad7de"
|
||||
default: str = "fill:#f2f0ff,line-height:1.2"
|
||||
first: str = "fill-opacity:0"
|
||||
last: str = "fill:#bfb6fc"
|
||||
|
||||
|
||||
class MermaidDrawMethod(Enum):
|
||||
@ -447,9 +447,7 @@ class Graph:
|
||||
*,
|
||||
with_styles: bool = True,
|
||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||
node_colors: NodeColors = NodeColors(
|
||||
start="#ffdfba", end="#baffc9", other="#fad7de"
|
||||
),
|
||||
node_colors: NodeStyles = NodeStyles(),
|
||||
wrap_label_n_words: int = 9,
|
||||
) -> str:
|
||||
from langchain_core.runnables.graph_mermaid import draw_mermaid
|
||||
@ -465,7 +463,7 @@ class Graph:
|
||||
last_node=last_node.id if last_node else None,
|
||||
with_styles=with_styles,
|
||||
curve_style=curve_style,
|
||||
node_colors=node_colors,
|
||||
node_styles=node_colors,
|
||||
wrap_label_n_words=wrap_label_n_words,
|
||||
)
|
||||
|
||||
@ -473,9 +471,7 @@ class Graph:
|
||||
self,
|
||||
*,
|
||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||
node_colors: NodeColors = NodeColors(
|
||||
start="#ffdfba", end="#baffc9", other="#fad7de"
|
||||
),
|
||||
node_colors: NodeStyles = NodeStyles(),
|
||||
wrap_label_n_words: int = 9,
|
||||
output_file_path: Optional[str] = None,
|
||||
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
|
||||
|
@ -8,7 +8,7 @@ from langchain_core.runnables.graph import (
|
||||
Edge,
|
||||
MermaidDrawMethod,
|
||||
Node,
|
||||
NodeColors,
|
||||
NodeStyles,
|
||||
)
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ def draw_mermaid(
|
||||
last_node: Optional[str] = None,
|
||||
with_styles: bool = True,
|
||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||
node_colors: NodeColors = NodeColors(),
|
||||
node_styles: NodeStyles = NodeStyles(),
|
||||
wrap_label_n_words: int = 9,
|
||||
) -> str:
|
||||
"""Draws a Mermaid graph using the provided graph data
|
||||
@ -49,23 +49,27 @@ def draw_mermaid(
|
||||
if with_styles:
|
||||
# Node formatting templates
|
||||
default_class_label = "default"
|
||||
format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
|
||||
format_dict = {default_class_label: "{0}({1})"}
|
||||
if first_node is not None:
|
||||
format_dict[first_node] = "{0}[{0}]:::startclass"
|
||||
format_dict[first_node] = "{0}([{0}]):::first"
|
||||
if last_node is not None:
|
||||
format_dict[last_node] = "{0}[{0}]:::endclass"
|
||||
format_dict[last_node] = "{0}([{0}]):::last"
|
||||
|
||||
# Add nodes to the graph
|
||||
for key, node in nodes.items():
|
||||
label = node.name.split(":")[-1]
|
||||
if node.metadata:
|
||||
label = f"<strong>{label}</strong>\n" + "\n".join(
|
||||
f"{key} = {value}" for key, value in node.metadata.items()
|
||||
label = (
|
||||
f"{label}<hr/>\n<small><em>"
|
||||
+ "\n".join(
|
||||
f"{key} = {value}" for key, value in node.metadata.items()
|
||||
)
|
||||
+ "</em></small>"
|
||||
)
|
||||
node_label = format_dict.get(key, format_dict[default_class_label]).format(
|
||||
_escape_node_label(key), label
|
||||
)
|
||||
mermaid_graph += f"\t{node_label};\n"
|
||||
mermaid_graph += f"\t{node_label}\n"
|
||||
|
||||
subgraph = ""
|
||||
# Add edges to the graph
|
||||
@ -89,16 +93,14 @@ def draw_mermaid(
|
||||
words = str(edge_data).split() # Split the string into words
|
||||
# Group words into chunks of wrap_label_n_words size
|
||||
if len(words) > wrap_label_n_words:
|
||||
edge_data = "<br>".join(
|
||||
[
|
||||
" ".join(words[i : i + wrap_label_n_words])
|
||||
for i in range(0, len(words), wrap_label_n_words)
|
||||
]
|
||||
edge_data = " <br> ".join(
|
||||
" ".join(words[i : i + wrap_label_n_words])
|
||||
for i in range(0, len(words), wrap_label_n_words)
|
||||
)
|
||||
if edge.conditional:
|
||||
edge_label = f" -. {edge_data} .-> "
|
||||
edge_label = f" -.  {edge_data}  .-> "
|
||||
else:
|
||||
edge_label = f" -- {edge_data} --> "
|
||||
edge_label = f" --  {edge_data}  --> "
|
||||
else:
|
||||
if edge.conditional:
|
||||
edge_label = " -.-> "
|
||||
@ -113,7 +115,7 @@ def draw_mermaid(
|
||||
|
||||
# Add custom styles for nodes
|
||||
if with_styles:
|
||||
mermaid_graph += _generate_mermaid_graph_styles(node_colors)
|
||||
mermaid_graph += _generate_mermaid_graph_styles(node_styles)
|
||||
return mermaid_graph
|
||||
|
||||
|
||||
@ -122,11 +124,11 @@ def _escape_node_label(node_label: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
|
||||
|
||||
|
||||
def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str:
|
||||
def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
|
||||
"""Generates Mermaid graph styles for different node types."""
|
||||
styles = ""
|
||||
for class_name, color in asdict(node_colors).items():
|
||||
styles += f"\tclassDef {class_name}class fill:{color};\n"
|
||||
for class_name, style in asdict(node_colors).items():
|
||||
styles += f"\tclassDef {class_name} {style}\n"
|
||||
return styles
|
||||
|
||||
|
||||
|
@ -34,19 +34,19 @@
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
graph TD;
|
||||
PromptInput[PromptInput]:::startclass;
|
||||
PromptTemplate([PromptTemplate]):::otherclass;
|
||||
FakeListLLM([<strong>FakeListLLM</strong>
|
||||
key = 2]):::otherclass;
|
||||
CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
|
||||
CommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass;
|
||||
PromptInput([PromptInput]):::first
|
||||
PromptTemplate(PromptTemplate)
|
||||
FakeListLLM(FakeListLLM<hr/>
|
||||
<small><em>key = 2</em></small>)
|
||||
CommaSeparatedListOutputParser(CommaSeparatedListOutputParser)
|
||||
CommaSeparatedListOutputParserOutput([CommaSeparatedListOutputParserOutput]):::last
|
||||
PromptInput --> PromptTemplate;
|
||||
PromptTemplate --> FakeListLLM;
|
||||
CommaSeparatedListOutputParser --> CommaSeparatedListOutputParserOutput;
|
||||
FakeListLLM --> CommaSeparatedListOutputParser;
|
||||
classDef startclass fill:#ffdfba;
|
||||
classDef endclass fill:#baffc9;
|
||||
classDef otherclass fill:#fad7de;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
@ -1011,16 +1011,16 @@
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
graph TD;
|
||||
PromptInput[PromptInput]:::startclass;
|
||||
PromptTemplate([PromptTemplate]):::otherclass;
|
||||
FakeListLLM([FakeListLLM]):::otherclass;
|
||||
Parallel_as_list_as_str_Input([Parallel<as_list,as_str>Input]):::otherclass;
|
||||
Parallel_as_list_as_str_Output[Parallel_as_list_as_str_Output]:::endclass;
|
||||
CommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
|
||||
conditional_str_parser_input([conditional_str_parser_input]):::otherclass;
|
||||
conditional_str_parser_output([conditional_str_parser_output]):::otherclass;
|
||||
StrOutputParser([StrOutputParser]):::otherclass;
|
||||
XMLOutputParser([XMLOutputParser]):::otherclass;
|
||||
PromptInput([PromptInput]):::first
|
||||
PromptTemplate(PromptTemplate)
|
||||
FakeListLLM(FakeListLLM)
|
||||
Parallel_as_list_as_str_Input(Parallel<as_list,as_str>Input)
|
||||
Parallel_as_list_as_str_Output([Parallel_as_list_as_str_Output]):::last
|
||||
CommaSeparatedListOutputParser(CommaSeparatedListOutputParser)
|
||||
conditional_str_parser_input(conditional_str_parser_input)
|
||||
conditional_str_parser_output(conditional_str_parser_output)
|
||||
StrOutputParser(StrOutputParser)
|
||||
XMLOutputParser(XMLOutputParser)
|
||||
PromptInput --> PromptTemplate;
|
||||
PromptTemplate --> FakeListLLM;
|
||||
Parallel_as_list_as_str_Input --> CommaSeparatedListOutputParser;
|
||||
@ -1032,9 +1032,9 @@
|
||||
Parallel_as_list_as_str_Input --> conditional_str_parser_input;
|
||||
conditional_str_parser_output --> Parallel_as_list_as_str_Output;
|
||||
FakeListLLM --> Parallel_as_list_as_str_Input;
|
||||
classDef startclass fill:#ffdfba;
|
||||
classDef endclass fill:#baffc9;
|
||||
classDef otherclass fill:#fad7de;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
@ -1061,14 +1061,14 @@
|
||||
'''
|
||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||
graph TD;
|
||||
StrOutputParserInput[StrOutputParserInput]:::startclass;
|
||||
StrOutputParser([StrOutputParser]):::otherclass;
|
||||
StrOutputParserOutput[StrOutputParserOutput]:::endclass;
|
||||
StrOutputParserInput([StrOutputParserInput]):::first
|
||||
StrOutputParser(StrOutputParser)
|
||||
StrOutputParserOutput([StrOutputParserOutput]):::last
|
||||
StrOutputParserInput --> StrOutputParser;
|
||||
StrOutputParser --> StrOutputParserOutput;
|
||||
classDef startclass fill:#ffdfba;
|
||||
classDef endclass fill:#baffc9;
|
||||
classDef otherclass fill:#fad7de;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
|
Loading…
Reference in New Issue
Block a user