diff --git a/libs/langchain/langchain/text_splitter.py b/libs/langchain/langchain/text_splitter.py index 45b8daa9699..6000cd5ebec 100644 --- a/libs/langchain/langchain/text_splitter.py +++ b/libs/langchain/langchain/text_splitter.py @@ -389,16 +389,23 @@ class MarkdownHeaderTextSplitter: initial_metadata: Dict[str, str] = {} in_code_block = False + opening_fence = "" for line in lines: stripped_line = line.strip() - if stripped_line.startswith("```"): - # code block in one row - if stripped_line.count("```") >= 2: + if not in_code_block: + # Exclude inline code spans + if stripped_line.startswith("```") and stripped_line.count("```") == 1: + in_code_block = True + opening_fence = "```" + elif stripped_line.startswith("~~~"): + in_code_block = True + opening_fence = "~~~" + else: + if stripped_line.startswith(opening_fence): in_code_block = False - else: - in_code_block = not in_code_block + opening_fence = "" if in_code_block: current_content.append(stripped_line) diff --git a/libs/langchain/tests/unit_tests/test_text_splitter.py b/libs/langchain/tests/unit_tests/test_text_splitter.py index 8edf76892b0..7bb1d97fba5 100644 --- a/libs/langchain/tests/unit_tests/test_text_splitter.py +++ b/libs/langchain/tests/unit_tests/test_text_splitter.py @@ -1031,6 +1031,77 @@ def test_md_header_text_splitter_3() -> None: assert output == expected_output +@pytest.mark.parametrize("fence", [("```"), ("~~~")]) +def test_md_header_text_splitter_fenced_code_block(fence: str) -> None: + """Test markdown splitter by header: Fenced code block.""" + + markdown_document = ( + "# This is a Header\n\n" + f"{fence}\n" + "foo()\n" + "# Not a header\n" + "bar()\n" + f"{fence}" + ) + + headers_to_split_on = [ + ("#", "Header 1"), + ("##", "Header 2"), + ] + + markdown_splitter = MarkdownHeaderTextSplitter( + headers_to_split_on=headers_to_split_on, + ) + output = markdown_splitter.split_text(markdown_document) + + expected_output = [ + Document( + page_content=f"{fence}\nfoo()\n# Not a header\nbar()\n{fence}", + metadata={"Header 1": "This is a Header"}, + ), + ] + + assert output == expected_output + + +@pytest.mark.parametrize(["fence", "other_fence"], [("```", "~~~"), ("~~~", "```")]) +def test_md_header_text_splitter_fenced_code_block_interleaved( + fence: str, other_fence: str +) -> None: + """Test markdown splitter by header: Interleaved fenced code block.""" + + markdown_document = ( + "# This is a Header\n\n" + f"{fence}\n" + "foo\n" + "# Not a header\n" + f"{other_fence}\n" + "# Not a header\n" + f"{fence}" + ) + + headers_to_split_on = [ + ("#", "Header 1"), + ("##", "Header 2"), + ] + + markdown_splitter = MarkdownHeaderTextSplitter( + headers_to_split_on=headers_to_split_on, + ) + output = markdown_splitter.split_text(markdown_document) + + expected_output = [ + Document( + page_content=( + f"{fence}\nfoo\n# Not a header\n{other_fence}\n# Not a header\n{fence}" + ), + metadata={"Header 1": "This is a Header"}, + ), + ] + + assert output == expected_output + + def test_solidity_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.SOL, chunk_size=CHUNK_SIZE, chunk_overlap=0