Runnable graph viz improvements (#20529)

- Add conditional: bool property to json representation of the graphs
- Add option to generate mermaid graph stripped of styles (useful as a
text representation of graph)
This commit is contained in:
Nuno Campos 2024-04-16 13:17:47 -07:00 committed by GitHub
parent f3aa26d6bf
commit 806a54908c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 55 additions and 29 deletions

View File

@ -170,6 +170,17 @@ class Graph:
node.id: i if is_uuid(node.id) else node.id node.id: i if is_uuid(node.id) else node.id
for i, node in enumerate(self.nodes.values()) for i, node in enumerate(self.nodes.values())
} }
edges: List[Dict[str, Any]] = []
for edge in self.edges:
edge_dict = {
"source": stable_node_ids[edge.source],
"target": stable_node_ids[edge.target],
}
if edge.data is not None:
edge_dict["data"] = edge.data
if edge.conditional:
edge_dict["conditional"] = True
edges.append(edge_dict)
return { return {
"nodes": [ "nodes": [
@ -179,19 +190,7 @@ class Graph:
} }
for node in self.nodes.values() for node in self.nodes.values()
], ],
"edges": [ "edges": edges,
{
"source": stable_node_ids[edge.source],
"target": stable_node_ids[edge.target],
"data": edge.data,
}
if edge.data is not None
else {
"source": stable_node_ids[edge.source],
"target": stable_node_ids[edge.target],
}
for edge in self.edges
],
} }
def __bool__(self) -> bool: def __bool__(self) -> bool:
@ -345,6 +344,7 @@ class Graph:
def draw_mermaid( def draw_mermaid(
self, self,
*, *,
with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR, curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors( node_colors: NodeColors = NodeColors(
start="#ffdfba", end="#baffc9", other="#fad7de" start="#ffdfba", end="#baffc9", other="#fad7de"
@ -366,6 +366,7 @@ class Graph:
edges=self.edges, edges=self.edges,
first_node_label=first_label, first_node_label=first_label,
last_node_label=last_label, last_node_label=last_label,
with_styles=with_styles,
curve_style=curve_style, curve_style=curve_style,
node_colors=node_colors, node_colors=node_colors,
wrap_label_n_words=wrap_label_n_words, wrap_label_n_words=wrap_label_n_words,

View File

@ -17,6 +17,7 @@ def draw_mermaid(
*, *,
first_node_label: Optional[str] = None, first_node_label: Optional[str] = None,
last_node_label: Optional[str] = None, last_node_label: Optional[str] = None,
with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR, curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(), node_colors: NodeColors = NodeColors(),
wrap_label_n_words: int = 9, wrap_label_n_words: int = 9,
@ -36,10 +37,15 @@ def draw_mermaid(
""" """
# Initialize Mermaid graph configuration # Initialize Mermaid graph configuration
mermaid_graph = ( mermaid_graph = (
(
f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'" f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
f"}}}}}}%%\ngraph TD;\n" f"}}}}}}%%\ngraph TD;\n"
) )
if with_styles
else "graph TD;\n"
)
if with_styles:
# Node formatting templates # Node formatting templates
default_class_label = "default" default_class_label = "default"
format_dict = {default_class_label: "{0}([{0}]):::otherclass"} format_dict = {default_class_label: "{0}([{0}]):::otherclass"}
@ -92,6 +98,7 @@ def draw_mermaid(
) )
# Add custom styles for nodes # Add custom styles for nodes
if with_styles:
mermaid_graph += _generate_mermaid_graph_styles(node_colors) mermaid_graph += _generate_mermaid_graph_styles(node_colors)
return mermaid_graph return mermaid_graph

View File

@ -98,6 +98,23 @@
+--------------------------------+ +--------------------------------+
''' '''
# --- # ---
# name: test_graph_sequence_map[mermaid-simple]
'''
graph TD;
PromptInput --> PromptTemplate;
PromptTemplate --> FakeListLLM;
Parallel_as_list_as_str_Input --> CommaSeparatedListOutputParser;
CommaSeparatedListOutputParser --> Parallel_as_list_as_str_Output;
conditional_str_parser_input --> StrOutputParser;
StrOutputParser --> conditional_str_parser_output;
conditional_str_parser_input --> XMLOutputParser;
XMLOutputParser --> conditional_str_parser_output;
Parallel_as_list_as_str_Input --> conditional_str_parser_input;
conditional_str_parser_output --> Parallel_as_list_as_str_Output;
FakeListLLM --> Parallel_as_list_as_str_Input;
'''
# ---
# name: test_graph_sequence_map[mermaid] # name: test_graph_sequence_map[mermaid]
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% %%{init: {'flowchart': {'curve': 'linear'}}}%%

View File

@ -660,3 +660,4 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
} }
assert graph.draw_ascii() == snapshot(name="ascii") assert graph.draw_ascii() == snapshot(name="ascii")
assert graph.draw_mermaid() == snapshot(name="mermaid") assert graph.draw_mermaid() == snapshot(name="mermaid")
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid-simple")