mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
Runnable graph viz improvements (#20529)
- Add conditional: bool property to json representation of the graphs - Add option to generate mermaid graph stripped of styles (useful as a text representation of graph)
This commit is contained in:
parent
f3aa26d6bf
commit
806a54908c
@ -170,6 +170,17 @@ class Graph:
|
|||||||
node.id: i if is_uuid(node.id) else node.id
|
node.id: i if is_uuid(node.id) else node.id
|
||||||
for i, node in enumerate(self.nodes.values())
|
for i, node in enumerate(self.nodes.values())
|
||||||
}
|
}
|
||||||
|
edges: List[Dict[str, Any]] = []
|
||||||
|
for edge in self.edges:
|
||||||
|
edge_dict = {
|
||||||
|
"source": stable_node_ids[edge.source],
|
||||||
|
"target": stable_node_ids[edge.target],
|
||||||
|
}
|
||||||
|
if edge.data is not None:
|
||||||
|
edge_dict["data"] = edge.data
|
||||||
|
if edge.conditional:
|
||||||
|
edge_dict["conditional"] = True
|
||||||
|
edges.append(edge_dict)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"nodes": [
|
"nodes": [
|
||||||
@ -179,19 +190,7 @@ class Graph:
|
|||||||
}
|
}
|
||||||
for node in self.nodes.values()
|
for node in self.nodes.values()
|
||||||
],
|
],
|
||||||
"edges": [
|
"edges": edges,
|
||||||
{
|
|
||||||
"source": stable_node_ids[edge.source],
|
|
||||||
"target": stable_node_ids[edge.target],
|
|
||||||
"data": edge.data,
|
|
||||||
}
|
|
||||||
if edge.data is not None
|
|
||||||
else {
|
|
||||||
"source": stable_node_ids[edge.source],
|
|
||||||
"target": stable_node_ids[edge.target],
|
|
||||||
}
|
|
||||||
for edge in self.edges
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
def __bool__(self) -> bool:
|
||||||
@ -345,6 +344,7 @@ class Graph:
|
|||||||
def draw_mermaid(
|
def draw_mermaid(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
with_styles: bool = True,
|
||||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||||
node_colors: NodeColors = NodeColors(
|
node_colors: NodeColors = NodeColors(
|
||||||
start="#ffdfba", end="#baffc9", other="#fad7de"
|
start="#ffdfba", end="#baffc9", other="#fad7de"
|
||||||
@ -366,6 +366,7 @@ class Graph:
|
|||||||
edges=self.edges,
|
edges=self.edges,
|
||||||
first_node_label=first_label,
|
first_node_label=first_label,
|
||||||
last_node_label=last_label,
|
last_node_label=last_label,
|
||||||
|
with_styles=with_styles,
|
||||||
curve_style=curve_style,
|
curve_style=curve_style,
|
||||||
node_colors=node_colors,
|
node_colors=node_colors,
|
||||||
wrap_label_n_words=wrap_label_n_words,
|
wrap_label_n_words=wrap_label_n_words,
|
||||||
|
@ -17,6 +17,7 @@ def draw_mermaid(
|
|||||||
*,
|
*,
|
||||||
first_node_label: Optional[str] = None,
|
first_node_label: Optional[str] = None,
|
||||||
last_node_label: Optional[str] = None,
|
last_node_label: Optional[str] = None,
|
||||||
|
with_styles: bool = True,
|
||||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||||
node_colors: NodeColors = NodeColors(),
|
node_colors: NodeColors = NodeColors(),
|
||||||
wrap_label_n_words: int = 9,
|
wrap_label_n_words: int = 9,
|
||||||
@ -36,10 +37,15 @@ def draw_mermaid(
|
|||||||
"""
|
"""
|
||||||
# Initialize Mermaid graph configuration
|
# Initialize Mermaid graph configuration
|
||||||
mermaid_graph = (
|
mermaid_graph = (
|
||||||
|
(
|
||||||
f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
|
f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
|
||||||
f"}}}}}}%%\ngraph TD;\n"
|
f"}}}}}}%%\ngraph TD;\n"
|
||||||
)
|
)
|
||||||
|
if with_styles
|
||||||
|
else "graph TD;\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if with_styles:
|
||||||
# Node formatting templates
|
# Node formatting templates
|
||||||
default_class_label = "default"
|
default_class_label = "default"
|
||||||
format_dict = {default_class_label: "{0}([{0}]):::otherclass"}
|
format_dict = {default_class_label: "{0}([{0}]):::otherclass"}
|
||||||
@ -92,6 +98,7 @@ def draw_mermaid(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Add custom styles for nodes
|
# Add custom styles for nodes
|
||||||
|
if with_styles:
|
||||||
mermaid_graph += _generate_mermaid_graph_styles(node_colors)
|
mermaid_graph += _generate_mermaid_graph_styles(node_colors)
|
||||||
return mermaid_graph
|
return mermaid_graph
|
||||||
|
|
||||||
|
@ -98,6 +98,23 @@
|
|||||||
+--------------------------------+
|
+--------------------------------+
|
||||||
'''
|
'''
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_graph_sequence_map[mermaid-simple]
|
||||||
|
'''
|
||||||
|
graph TD;
|
||||||
|
PromptInput --> PromptTemplate;
|
||||||
|
PromptTemplate --> FakeListLLM;
|
||||||
|
Parallel_as_list_as_str_Input --> CommaSeparatedListOutputParser;
|
||||||
|
CommaSeparatedListOutputParser --> Parallel_as_list_as_str_Output;
|
||||||
|
conditional_str_parser_input --> StrOutputParser;
|
||||||
|
StrOutputParser --> conditional_str_parser_output;
|
||||||
|
conditional_str_parser_input --> XMLOutputParser;
|
||||||
|
XMLOutputParser --> conditional_str_parser_output;
|
||||||
|
Parallel_as_list_as_str_Input --> conditional_str_parser_input;
|
||||||
|
conditional_str_parser_output --> Parallel_as_list_as_str_Output;
|
||||||
|
FakeListLLM --> Parallel_as_list_as_str_Input;
|
||||||
|
|
||||||
|
'''
|
||||||
|
# ---
|
||||||
# name: test_graph_sequence_map[mermaid]
|
# name: test_graph_sequence_map[mermaid]
|
||||||
'''
|
'''
|
||||||
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
%%{init: {'flowchart': {'curve': 'linear'}}}%%
|
||||||
|
@ -660,3 +660,4 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
|
|||||||
}
|
}
|
||||||
assert graph.draw_ascii() == snapshot(name="ascii")
|
assert graph.draw_ascii() == snapshot(name="ascii")
|
||||||
assert graph.draw_mermaid() == snapshot(name="mermaid")
|
assert graph.draw_mermaid() == snapshot(name="mermaid")
|
||||||
|
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid-simple")
|
||||||
|
Loading…
Reference in New Issue
Block a user