From c599732e1add8d86f887e47d83a8a3e831fdb4e7 Mon Sep 17 00:00:00 2001 From: Tom Clelford Date: Mon, 3 Jun 2024 21:26:59 +0100 Subject: [PATCH] text-splitters[patch]: fix HTMLSectionSplitter parsing of xslt paths (#22176) ## Description This PR allows passing the HTMLSectionSplitter paths to xslt files. It does so by fixing two trivial bugs with how passed paths were being handled. It also changes the default value of the param `xslt_path` to `None` so the special case where the file was part of the langchain package could be handled. ## Issue #22175 --- .../langchain_text_splitters/html.py | 20 ++++++------ .../tests/test_data/test_splitter.xslt | 9 ++++++ .../tests/unit_tests/test_text_splitters.py | 31 +++++++++++++++++++ 3 files changed, 50 insertions(+), 10 deletions(-) create mode 100644 libs/text-splitters/tests/test_data/test_splitter.xslt diff --git a/libs/text-splitters/langchain_text_splitters/html.py b/libs/text-splitters/langchain_text_splitters/html.py index 6ad27314c02..89113313967 100644 --- a/libs/text-splitters/langchain_text_splitters/html.py +++ b/libs/text-splitters/langchain_text_splitters/html.py @@ -1,7 +1,6 @@ from __future__ import annotations import copy -import os import pathlib from io import BytesIO, StringIO from typing import Any, Dict, Iterable, List, Optional, Tuple, TypedDict, cast @@ -173,7 +172,7 @@ class HTMLSectionSplitter: def __init__( self, headers_to_split_on: List[Tuple[str, str]], - xslt_path: str = "xsl/converting_to_header.xslt", + xslt_path: Optional[str] = None, **kwargs: Any, ) -> None: """Create a new HTMLSectionSplitter. @@ -183,10 +182,17 @@ class HTMLSectionSplitter: (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4, h5, h6 e.g. [("h1", "Header 1"), ("h2", "Header 2"]. xslt_path: path to xslt file for document transformation. + Uses a default if not passed. Needed for html contents that using different format and layouts. """ self.headers_to_split_on = dict(headers_to_split_on) - self.xslt_path = xslt_path + + if xslt_path is None: + self.xslt_path = ( + pathlib.Path(__file__).parent / "xsl/converting_to_header.xslt" + ).absolute() + else: + self.xslt_path = pathlib.Path(xslt_path).absolute() self.kwargs = kwargs def split_documents(self, documents: Iterable[Document]) -> List[Document]: @@ -284,13 +290,7 @@ class HTMLSectionSplitter: parser = etree.HTMLParser() tree = etree.parse(StringIO(html_content), parser) - # document transformation for "structure-aware" chunking is handled with xsl. - # this is needed for htmls files that using different font sizes and layouts - # check to see if self.xslt_path is a relative path or absolute path - if not os.path.isabs(self.xslt_path): - xslt_path = pathlib.Path(__file__).parent / self.xslt_path - - xslt_tree = etree.parse(xslt_path) + xslt_tree = etree.parse(self.xslt_path) transform = etree.XSLT(xslt_tree) result = transform(tree) return str(result) diff --git a/libs/text-splitters/tests/test_data/test_splitter.xslt b/libs/text-splitters/tests/test_data/test_splitter.xslt new file mode 100644 index 00000000000..cbb5828bf12 --- /dev/null +++ b/libs/text-splitters/tests/test_data/test_splitter.xslt @@ -0,0 +1,9 @@ + + + + + + + + \ No newline at end of file diff --git a/libs/text-splitters/tests/unit_tests/test_text_splitters.py b/libs/text-splitters/tests/unit_tests/test_text_splitters.py index 062f4d089d1..9f9c76b98d7 100644 --- a/libs/text-splitters/tests/unit_tests/test_text_splitters.py +++ b/libs/text-splitters/tests/unit_tests/test_text_splitters.py @@ -1619,6 +1619,37 @@ def test_happy_path_splitting_based_on_header_with_whitespace_chars() -> None: assert docs[2].metadata["Header 2"] == "Baz" +@pytest.mark.requires("lxml") +@pytest.mark.requires("bs4") +def test_section_splitter_accepts_a_relative_path() -> None: + html_string = """

Foo

""" + test_file = Path("tests/test_data/test_splitter.xslt") + assert test_file.is_file() + + sec_splitter = HTMLSectionSplitter( + headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")], + xslt_path=test_file.as_posix(), + ) + + sec_splitter.split_text(html_string) + + +@pytest.mark.requires("lxml") +@pytest.mark.requires("bs4") +def test_section_splitter_accepts_an_absolute_path() -> None: + html_string = """

Foo

""" + test_file = Path("tests/test_data/test_splitter.xslt").absolute() + assert test_file.is_absolute() + assert test_file.is_file() + + sec_splitter = HTMLSectionSplitter( + headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")], + xslt_path=test_file.as_posix(), + ) + + sec_splitter.split_text(html_string) + + def test_split_json() -> None: """Test json text splitter""" max_chunk = 800