text-splitters[patch]: fix HTMLSectionSplitter parsing of xslt paths (#22176)

## Description
This PR allows passing the HTMLSectionSplitter paths to xslt files. It
does so by fixing two trivial bugs with how passed paths were being
handled. It also changes the default value of the param `xslt_path` to
`None` so the special case where the file was part of the langchain
package could be handled.

## Issue
#22175
This commit is contained in:
Tom Clelford 2024-06-03 21:26:59 +01:00 committed by GitHub
parent 01352bb55f
commit c599732e1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 50 additions and 10 deletions

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import copy import copy
import os
import pathlib import pathlib
from io import BytesIO, StringIO from io import BytesIO, StringIO
from typing import Any, Dict, Iterable, List, Optional, Tuple, TypedDict, cast from typing import Any, Dict, Iterable, List, Optional, Tuple, TypedDict, cast
@ -173,7 +172,7 @@ class HTMLSectionSplitter:
def __init__( def __init__(
self, self,
headers_to_split_on: List[Tuple[str, str]], headers_to_split_on: List[Tuple[str, str]],
xslt_path: str = "xsl/converting_to_header.xslt", xslt_path: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Create a new HTMLSectionSplitter. """Create a new HTMLSectionSplitter.
@ -183,10 +182,17 @@ class HTMLSectionSplitter:
(arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4, (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4,
h5, h6 e.g. [("h1", "Header 1"), ("h2", "Header 2"]. h5, h6 e.g. [("h1", "Header 1"), ("h2", "Header 2"].
xslt_path: path to xslt file for document transformation. xslt_path: path to xslt file for document transformation.
Uses a default if not passed.
Needed for html contents that using different format and layouts. Needed for html contents that using different format and layouts.
""" """
self.headers_to_split_on = dict(headers_to_split_on) self.headers_to_split_on = dict(headers_to_split_on)
self.xslt_path = xslt_path
if xslt_path is None:
self.xslt_path = (
pathlib.Path(__file__).parent / "xsl/converting_to_header.xslt"
).absolute()
else:
self.xslt_path = pathlib.Path(xslt_path).absolute()
self.kwargs = kwargs self.kwargs = kwargs
def split_documents(self, documents: Iterable[Document]) -> List[Document]: def split_documents(self, documents: Iterable[Document]) -> List[Document]:
@ -284,13 +290,7 @@ class HTMLSectionSplitter:
parser = etree.HTMLParser() parser = etree.HTMLParser()
tree = etree.parse(StringIO(html_content), parser) tree = etree.parse(StringIO(html_content), parser)
# document transformation for "structure-aware" chunking is handled with xsl. xslt_tree = etree.parse(self.xslt_path)
# this is needed for htmls files that using different font sizes and layouts
# check to see if self.xslt_path is a relative path or absolute path
if not os.path.isabs(self.xslt_path):
xslt_path = pathlib.Path(__file__).parent / self.xslt_path
xslt_tree = etree.parse(xslt_path)
transform = etree.XSLT(xslt_tree) transform = etree.XSLT(xslt_tree)
result = transform(tree) result = transform(tree)
return str(result) return str(result)

View File

@ -0,0 +1,9 @@
<?xml version="1.0"?>
<xsl:stylesheet version="1.0"
xmlns:xsl="http://www.w3.org/1999/XSL/Transform">
<xsl:template match="node()|@*">
<xsl:copy>
<xsl:apply-templates select="node()|@*" />
</xsl:copy>
</xsl:template>
</xsl:stylesheet>

View File

@ -1619,6 +1619,37 @@ def test_happy_path_splitting_based_on_header_with_whitespace_chars() -> None:
assert docs[2].metadata["Header 2"] == "Baz" assert docs[2].metadata["Header 2"] == "Baz"
@pytest.mark.requires("lxml")
@pytest.mark.requires("bs4")
def test_section_splitter_accepts_a_relative_path() -> None:
html_string = """<html><body><p>Foo</p></body></html>"""
test_file = Path("tests/test_data/test_splitter.xslt")
assert test_file.is_file()
sec_splitter = HTMLSectionSplitter(
headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")],
xslt_path=test_file.as_posix(),
)
sec_splitter.split_text(html_string)
@pytest.mark.requires("lxml")
@pytest.mark.requires("bs4")
def test_section_splitter_accepts_an_absolute_path() -> None:
html_string = """<html><body><p>Foo</p></body></html>"""
test_file = Path("tests/test_data/test_splitter.xslt").absolute()
assert test_file.is_absolute()
assert test_file.is_file()
sec_splitter = HTMLSectionSplitter(
headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")],
xslt_path=test_file.as_posix(),
)
sec_splitter.split_text(html_string)
def test_split_json() -> None: def test_split_json() -> None:
"""Test json text splitter""" """Test json text splitter"""
max_chunk = 800 max_chunk = 800