From 2703a1b061d704bb47840a9e564b8d45dd4d5956 Mon Sep 17 00:00:00 2001
From: unifyh <18213435+unifyh@users.noreply.github.com>
Date: Wed, 29 Nov 2023 00:52:38 +0800
Subject: [PATCH] Fix `MarkdownHeaderTextSplitter` not recognizing tilde-fenced code blocks (#13511)

- **Description:** Previously `MarkdownHeaderTextSplitter` did not consider tilde-fenced code blocks (https://spec.commonmark.org/0.30/#fenced-code-blocks). This PR fixes that.

````md
# Bug caused by previous implementation:

~~~py
foo()
# This is a comment that would be considered header
bar()
~~~
````

- **Tag maintainer:** @baskaryan
---
 libs/langchain/langchain/text_splitter.py     | 17 +++--
 .../tests/unit_tests/test_text_splitter.py    | 71 +++++++++++++++++++
 2 files changed, 83 insertions(+), 5 deletions(-)

diff --git a/libs/langchain/langchain/text_splitter.py b/libs/langchain/langchain/text_splitter.py
index 45b8daa9699..6000cd5ebec 100644
--- a/libs/langchain/langchain/text_splitter.py
+++ b/libs/langchain/langchain/text_splitter.py
@@ -389,16 +389,23 @@ class MarkdownHeaderTextSplitter:
         initial_metadata: Dict[str, str] = {}
 
         in_code_block = False
+        opening_fence = ""
 
         for line in lines:
             stripped_line = line.strip()
 
-            if stripped_line.startswith("```"):
-                # code block in one row
-                if stripped_line.count("```") >= 2:
+            if not in_code_block:
+                # Exclude inline code spans
+                if stripped_line.startswith("```") and stripped_line.count("```") == 1:
+                    in_code_block = True
+                    opening_fence = "```"
+                elif stripped_line.startswith("~~~"):
+                    in_code_block = True
+                    opening_fence = "~~~"
+            else:
+                if stripped_line.startswith(opening_fence):
                     in_code_block = False
-                else:
-                    in_code_block = not in_code_block
+                    opening_fence = ""
 
             if in_code_block:
                 current_content.append(stripped_line)
diff --git a/libs/langchain/tests/unit_tests/test_text_splitter.py b/libs/langchain/tests/unit_tests/test_text_splitter.py
index 8edf76892b0..7bb1d97fba5 100644
--- a/libs/langchain/tests/unit_tests/test_text_splitter.py
+++ b/libs/langchain/tests/unit_tests/test_text_splitter.py
@@ -1031,6 +1031,77 @@ def test_md_header_text_splitter_3() -> None:
     assert output == expected_output
 
 
+@pytest.mark.parametrize("fence", [("```"), ("~~~")])
+def test_md_header_text_splitter_fenced_code_block(fence: str) -> None:
+    """Test markdown splitter by header: Fenced code block."""
+
+    markdown_document = (
+        "# This is a Header\n\n"
+        f"{fence}\n"
+        "foo()\n"
+        "# Not a header\n"
+        "bar()\n"
+        f"{fence}"
+    )
+
+    headers_to_split_on = [
+        ("#", "Header 1"),
+        ("##", "Header 2"),
+    ]
+
+    markdown_splitter = MarkdownHeaderTextSplitter(
+        headers_to_split_on=headers_to_split_on,
+    )
+    output = markdown_splitter.split_text(markdown_document)
+
+    expected_output = [
+        Document(
+            page_content=f"{fence}\nfoo()\n# Not a header\nbar()\n{fence}",
+            metadata={"Header 1": "This is a Header"},
+        ),
+    ]
+
+    assert output == expected_output
+
+
+@pytest.mark.parametrize(["fence", "other_fence"], [("```", "~~~"), ("~~~", "```")])
+def test_md_header_text_splitter_fenced_code_block_interleaved(
+    fence: str, other_fence: str
+) -> None:
+    """Test markdown splitter by header: Interleaved fenced code block."""
+
+    markdown_document = (
+        "# This is a Header\n\n"
+        f"{fence}\n"
+        "foo\n"
+        "# Not a header\n"
+        f"{other_fence}\n"
+        "# Not a header\n"
+        f"{fence}"
+    )
+
+    headers_to_split_on = [
+        ("#", "Header 1"),
+        ("##", "Header 2"),
+    ]
+
+    markdown_splitter = MarkdownHeaderTextSplitter(
+        headers_to_split_on=headers_to_split_on,
+    )
+    output = markdown_splitter.split_text(markdown_document)
+
+    expected_output = [
+        Document(
+            page_content=(
+                f"{fence}\nfoo\n# Not a header\n{other_fence}\n# Not a header\n{fence}"
+            ),
+            metadata={"Header 1": "This is a Header"},
+        ),
+    ]
+
+    assert output == expected_output
+
+
 def test_solidity_code_splitter() -> None:
     splitter = RecursiveCharacterTextSplitter.from_language(
         Language.SOL, chunk_size=CHUNK_SIZE, chunk_overlap=0