core[patch]: wrap mermaid node names w/ markdown in <p> tag (#26235)

This fixes the issue where `__start__` and `__end__` node labels are
being interpreted as markdown, as of the most recent Mermaid update
This commit is contained in:
Vadym Barda 2024-09-09 20:11:00 -04:00 committed by GitHub
parent 3e48c728d5
commit bab9de581c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 8 deletions

View File

@ -11,6 +11,8 @@ from langchain_core.runnables.graph import (
NodeStyles, NodeStyles,
) )
MARKDOWN_SPECIAL_CHARS = "*_`"
def draw_mermaid( def draw_mermaid(
nodes: Dict[str, Node], nodes: Dict[str, Node],
@ -58,13 +60,19 @@ def draw_mermaid(
default_class_label = "default" default_class_label = "default"
format_dict = {default_class_label: "{0}({1})"} format_dict = {default_class_label: "{0}({1})"}
if first_node is not None: if first_node is not None:
format_dict[first_node] = "{0}([{0}]):::first" format_dict[first_node] = "{0}([{1}]):::first"
if last_node is not None: if last_node is not None:
format_dict[last_node] = "{0}([{0}]):::last" format_dict[last_node] = "{0}([{1}]):::last"
# Add nodes to the graph # Add nodes to the graph
for key, node in nodes.items(): for key, node in nodes.items():
label = node.name.split(":")[-1] node_name = node.name.split(":")[-1]
label = (
f"<p>{node_name}</p>"
if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS))
and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS))
else node_name
)
if node.metadata: if node.metadata:
label = ( label = (
f"{label}<hr/><small><em>" f"{label}<hr/><small><em>"

View File

@ -1040,7 +1040,7 @@
PromptTemplate(PromptTemplate) PromptTemplate(PromptTemplate)
FakeListLLM(FakeListLLM) FakeListLLM(FakeListLLM)
Parallel_as_list_as_str_Input(Parallel<as_list,as_str>Input) Parallel_as_list_as_str_Input(Parallel<as_list,as_str>Input)
Parallel_as_list_as_str_Output([Parallel_as_list_as_str_Output]):::last Parallel_as_list_as_str_Output([Parallel<as_list,as_str>Output]):::last
CommaSeparatedListOutputParser(CommaSeparatedListOutputParser) CommaSeparatedListOutputParser(CommaSeparatedListOutputParser)
conditional_str_parser_input(conditional_str_parser_input) conditional_str_parser_input(conditional_str_parser_input)
conditional_str_parser_output(conditional_str_parser_output) conditional_str_parser_output(conditional_str_parser_output)
@ -1067,14 +1067,14 @@
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% %%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD; graph TD;
__start__([__start__]):::first __start__([<p>__start__</p>]):::first
outer_1(outer_1) outer_1(outer_1)
inner_1_inner_1(inner_1) inner_1_inner_1(inner_1)
inner_1_inner_2(inner_2<hr/><small><em>__interrupt = before</em></small>) inner_1_inner_2(inner_2<hr/><small><em>__interrupt = before</em></small>)
inner_2_inner_1(inner_1) inner_2_inner_1(inner_1)
inner_2_inner_2(inner_2) inner_2_inner_2(inner_2)
outer_2(outer_2) outer_2(outer_2)
__end__([__end__]):::last __end__([<p>__end__</p>]):::last
__start__ --> outer_1; __start__ --> outer_1;
inner_1_inner_2 --> outer_2; inner_1_inner_2 --> outer_2;
inner_2_inner_2 --> outer_2; inner_2_inner_2 --> outer_2;
@ -1097,13 +1097,13 @@
''' '''
%%{init: {'flowchart': {'curve': 'linear'}}}%% %%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD; graph TD;
__start__([__start__]):::first __start__([<p>__start__</p>]):::first
parent_1(parent_1) parent_1(parent_1)
child_child_1_grandchild_1(grandchild_1) child_child_1_grandchild_1(grandchild_1)
child_child_1_grandchild_2(grandchild_2<hr/><small><em>__interrupt = before</em></small>) child_child_1_grandchild_2(grandchild_2<hr/><small><em>__interrupt = before</em></small>)
child_child_2(child_2) child_child_2(child_2)
parent_2(parent_2) parent_2(parent_2)
__end__([__end__]):::last __end__([<p>__end__</p>]):::last
__start__ --> parent_1; __start__ --> parent_1;
child_child_2 --> parent_2; child_child_2 --> parent_2;
parent_1 --> child_child_1_grandchild_1; parent_1 --> child_child_1_grandchild_1;