diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 1d06095b7b7..1e05a5af610 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -11,6 +11,8 @@ from langchain_core.runnables.graph import ( NodeStyles, ) +MARKDOWN_SPECIAL_CHARS = "*_`" + def draw_mermaid( nodes: Dict[str, Node], @@ -58,13 +60,19 @@ def draw_mermaid( default_class_label = "default" format_dict = {default_class_label: "{0}({1})"} if first_node is not None: - format_dict[first_node] = "{0}([{0}]):::first" + format_dict[first_node] = "{0}([{1}]):::first" if last_node is not None: - format_dict[last_node] = "{0}([{0}]):::last" + format_dict[last_node] = "{0}([{1}]):::last" # Add nodes to the graph for key, node in nodes.items(): - label = node.name.split(":")[-1] + node_name = node.name.split(":")[-1] + label = ( + f"

{node_name}

" + if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS)) + and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS)) + else node_name + ) if node.metadata: label = ( f"{label}
" diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index d1f1b20a49c..399f66aef09 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -1040,7 +1040,7 @@ PromptTemplate(PromptTemplate) FakeListLLM(FakeListLLM) Parallel_as_list_as_str_Input(ParallelInput) - Parallel_as_list_as_str_Output([Parallel_as_list_as_str_Output]):::last + Parallel_as_list_as_str_Output([ParallelOutput]):::last CommaSeparatedListOutputParser(CommaSeparatedListOutputParser) conditional_str_parser_input(conditional_str_parser_input) conditional_str_parser_output(conditional_str_parser_output) @@ -1067,14 +1067,14 @@ ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% graph TD; - __start__([__start__]):::first + __start__([

__start__

]):::first outer_1(outer_1) inner_1_inner_1(inner_1) inner_1_inner_2(inner_2
__interrupt = before) inner_2_inner_1(inner_1) inner_2_inner_2(inner_2) outer_2(outer_2) - __end__([__end__]):::last + __end__([

__end__

]):::last __start__ --> outer_1; inner_1_inner_2 --> outer_2; inner_2_inner_2 --> outer_2; @@ -1097,13 +1097,13 @@ ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% graph TD; - __start__([__start__]):::first + __start__([

__start__

]):::first parent_1(parent_1) child_child_1_grandchild_1(grandchild_1) child_child_1_grandchild_2(grandchild_2
__interrupt = before) child_child_2(child_2) parent_2(parent_2) - __end__([__end__]):::last + __end__([

__end__

]):::last __start__ --> parent_1; child_child_2 --> parent_2; parent_1 --> child_child_1_grandchild_1;