diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 1d06095b7b7..1e05a5af610 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -11,6 +11,8 @@ from langchain_core.runnables.graph import ( NodeStyles, ) +MARKDOWN_SPECIAL_CHARS = "*_`" + def draw_mermaid( nodes: Dict[str, Node], @@ -58,13 +60,19 @@ def draw_mermaid( default_class_label = "default" format_dict = {default_class_label: "{0}({1})"} if first_node is not None: - format_dict[first_node] = "{0}([{0}]):::first" + format_dict[first_node] = "{0}([{1}]):::first" if last_node is not None: - format_dict[last_node] = "{0}([{0}]):::last" + format_dict[last_node] = "{0}([{1}]):::last" # Add nodes to the graph for key, node in nodes.items(): - label = node.name.split(":")[-1] + node_name = node.name.split(":")[-1] + label = ( + f"
{node_name}
" + if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS)) + and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS)) + else node_name + ) if node.metadata: label = ( f"{label}__start__
]):::first outer_1(outer_1) inner_1_inner_1(inner_1) inner_1_inner_2(inner_2__end__
]):::last __start__ --> outer_1; inner_1_inner_2 --> outer_2; inner_2_inner_2 --> outer_2; @@ -1097,13 +1097,13 @@ ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% graph TD; - __start__([__start__]):::first + __start__([__start__
]):::first parent_1(parent_1) child_child_1_grandchild_1(grandchild_1) child_child_1_grandchild_2(grandchild_2__end__
]):::last __start__ --> parent_1; child_child_2 --> parent_2; parent_1 --> child_child_1_grandchild_1;