mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
core[minor]: update draw_mermaid node label processing (#23285)
This fixes processing issue for nodes with numbers in their labels (e.g. `"node_1"`, which would previously be relabeled as `"node__"`, and now are correctly processed as `"node_1"`)
This commit is contained in:
parent
7ee2822ec2
commit
9ac302cb97
@ -114,7 +114,7 @@ def draw_mermaid(
|
||||
|
||||
def _escape_node_label(node_label: str) -> str:
|
||||
"""Escapes the node label for Mermaid syntax."""
|
||||
return re.sub(r"[^a-zA-Z-_]", "_", node_label)
|
||||
return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
|
||||
|
||||
|
||||
def _adjust_mermaid_edge(
|
||||
|
@ -8,6 +8,7 @@ from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables.base import Runnable, RunnableConfig
|
||||
from langchain_core.runnables.graph_mermaid import _escape_node_label
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
|
||||
|
||||
@ -734,3 +735,11 @@ def test_runnable_get_graph_with_invalid_output_type() -> None:
|
||||
assert runnable.invoke(1) == 1
|
||||
# check whether runnable.get_graph works
|
||||
runnable.get_graph()
|
||||
|
||||
|
||||
def test_graph_mermaid_escape_node_label() -> None:
|
||||
"""Test that node labels are correctly preprocessed for draw_mermaid"""
|
||||
assert _escape_node_label("foo") == "foo"
|
||||
assert _escape_node_label("foo-bar") == "foo-bar"
|
||||
assert _escape_node_label("foo_1") == "foo_1"
|
||||
assert _escape_node_label("#foo*&!") == "_foo___"
|
||||
|
Loading…
Reference in New Issue
Block a user