diff --git a/libs/text-splitters/langchain_text_splitters/html.py b/libs/text-splitters/langchain_text_splitters/html.py index a917f6a9c34..06026bf31c1 100644 --- a/libs/text-splitters/langchain_text_splitters/html.py +++ b/libs/text-splitters/langchain_text_splitters/html.py @@ -309,7 +309,6 @@ class HTMLSectionSplitter: def __init__( self, headers_to_split_on: List[Tuple[str, str]], - xslt_path: Optional[str] = None, **kwargs: Any, ) -> None: """Create a new HTMLSectionSplitter. @@ -318,20 +317,13 @@ class HTMLSectionSplitter: headers_to_split_on: list of tuples of headers we want to track mapped to (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4, h5, h6 e.g. [("h1", "Header 1"), ("h2", "Header 2"]. - xslt_path: path to xslt file for document transformation. - Uses a default if not passed. - Needed for html contents that using different format and layouts. **kwargs (Any): Additional optional arguments for customizations. """ self.headers_to_split_on = dict(headers_to_split_on) - - if xslt_path is None: - self.xslt_path = ( - pathlib.Path(__file__).parent / "xsl/converting_to_header.xslt" - ).absolute() - else: - self.xslt_path = pathlib.Path(xslt_path).absolute() + self.xslt_path = ( + pathlib.Path(__file__).parent / "xsl/converting_to_header.xslt" + ).absolute() self.kwargs = kwargs def split_documents(self, documents: Iterable[Document]) -> List[Document]: @@ -457,11 +449,20 @@ class HTMLSectionSplitter: "Unable to import lxml, please install with `pip install lxml`." ) from e # use lxml library to parse html document and return xml ElementTree - parser = etree.HTMLParser() - tree = etree.parse(StringIO(html_content), parser) + # Create secure parsers to prevent XXE attacks + html_parser = etree.HTMLParser(no_network=True) + xslt_parser = etree.XMLParser( + resolve_entities=False, no_network=True, load_dtd=False + ) - xslt_tree = etree.parse(self.xslt_path) - transform = etree.XSLT(xslt_tree) + # Apply XSLT access control to prevent file/network access + # DENY_ALL is a predefined access control that blocks all file/network access + # Type ignore needed due to incomplete lxml type stubs + ac = etree.XSLTAccessControl.DENY_ALL # type: ignore[attr-defined] + + tree = etree.parse(StringIO(html_content), html_parser) + xslt_tree = etree.parse(self.xslt_path, xslt_parser) + transform = etree.XSLT(xslt_tree, access_control=ac) result = transform(tree) return str(result) diff --git a/libs/text-splitters/tests/unit_tests/test_html_security.py b/libs/text-splitters/tests/unit_tests/test_html_security.py new file mode 100644 index 00000000000..794c9d91b20 --- /dev/null +++ b/libs/text-splitters/tests/unit_tests/test_html_security.py @@ -0,0 +1,130 @@ +"""Security tests for HTML splitters to prevent XXE attacks.""" + +import pytest + +from langchain_text_splitters.html import HTMLSectionSplitter + + +@pytest.mark.requires("lxml", "bs4") +class TestHTMLSectionSplitterSecurity: + """Security tests for HTMLSectionSplitter to ensure XXE prevention.""" + + def test_xxe_entity_attack_blocked(self) -> None: + """Test that external entity attacks are blocked.""" + # Create HTML content to process + html_content = """

Test content

""" + + # Since xslt_path parameter is removed, this attack vector is eliminated + # The splitter should use only the default XSLT + splitter = HTMLSectionSplitter(headers_to_split_on=[("h1", "Header 1")]) + + # Process the HTML - should not contain any external entity content + result = splitter.split_text(html_content) + + # Verify that no external entity content is present + all_content = " ".join([doc.page_content for doc in result]) + assert "root:" not in all_content # /etc/passwd content + assert "XXE Attack Result" not in all_content + + def test_xxe_document_function_blocked(self) -> None: + """Test that XSLT document() function attacks are blocked.""" + # Even if someone modifies the default XSLT internally, + # the secure parser configuration should block document() attacks + + html_content = ( + """

Test Header

Test content

""" + ) + + splitter = HTMLSectionSplitter(headers_to_split_on=[("h1", "Header 1")]) + + # Process the HTML safely + result = splitter.split_text(html_content) + + # Should process normally without any security issues + assert len(result) > 0 + assert any("Test content" in doc.page_content for doc in result) + + def test_secure_parser_configuration(self) -> None: + """Test that parsers are configured with security settings.""" + # This test verifies our security hardening is in place + html_content = """

Test

""" + + splitter = HTMLSectionSplitter(headers_to_split_on=[("h1", "Header 1")]) + + # The convert_possible_tags_to_header method should use secure parsers + result = splitter.convert_possible_tags_to_header(html_content) + + # Result should be valid transformed HTML + assert result is not None + assert isinstance(result, str) + + def test_no_network_access(self) -> None: + """Test that network access is blocked in parsers.""" + # Create HTML that might trigger network access + html_with_external_ref = """ + +]> + + +

Test

+

&external;

+ +""" + + splitter = HTMLSectionSplitter(headers_to_split_on=[("h1", "Header 1")]) + + # Process the HTML - should not make network requests + result = splitter.split_text(html_with_external_ref) + + # Verify no external content is included + all_content = " ".join([doc.page_content for doc in result]) + assert "attacker.com" not in all_content + + def test_dtd_processing_disabled(self) -> None: + """Test that DTD processing is disabled.""" + # HTML with DTD that attempts to define entities + html_with_dtd = """ + + + + +]> + + +

Header

+

&test;

+ +""" + + splitter = HTMLSectionSplitter(headers_to_split_on=[("h1", "Header 1")]) + + # Process the HTML - entities should not be resolved + result = splitter.split_text(html_with_dtd) + + # The entity should not be expanded + all_content = " ".join([doc.page_content for doc in result]) + assert "This is a test entity" not in all_content + + def test_safe_default_xslt_usage(self) -> None: + """Test that the default XSLT file is used safely.""" + # Test with HTML that has font-size styling (what the default XSLT handles) + html_with_font_size = """ + + Large Header +

Content under large text

+ Small Header +

Content under small text

+ +""" + + splitter = HTMLSectionSplitter(headers_to_split_on=[("h1", "Header 1")]) + + # Process the HTML using the default XSLT + result = splitter.split_text(html_with_font_size) + + # Should successfully process the content + assert len(result) > 0 + # Large font text should be converted to header + assert any("Large Header" in str(doc.metadata.values()) for doc in result) diff --git a/libs/text-splitters/tests/unit_tests/test_text_splitters.py b/libs/text-splitters/tests/unit_tests/test_text_splitters.py index 9aeb4b1520a..f66fbcdd306 100644 --- a/libs/text-splitters/tests/unit_tests/test_text_splitters.py +++ b/libs/text-splitters/tests/unit_tests/test_text_splitters.py @@ -3,7 +3,6 @@ import random import re import string -from pathlib import Path from typing import Any, Callable, List, Tuple import pytest @@ -2865,37 +2864,6 @@ def test_happy_path_splitting_based_on_header_with_whitespace_chars() -> None: assert docs[2].metadata["Header 2"] == "Baz" -@pytest.mark.requires("bs4") -@pytest.mark.requires("lxml") -def test_section_splitter_accepts_a_relative_path() -> None: - html_string = """

Foo

""" - test_file = Path("tests/test_data/test_splitter.xslt") - assert test_file.is_file() - - sec_splitter = HTMLSectionSplitter( - headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")], - xslt_path=test_file.as_posix(), - ) - - sec_splitter.split_text(html_string) - - -@pytest.mark.requires("bs4") -@pytest.mark.requires("lxml") -def test_section_splitter_accepts_an_absolute_path() -> None: - html_string = """

Foo

""" - test_file = Path("tests/test_data/test_splitter.xslt").absolute() - assert test_file.is_absolute() - assert test_file.is_file() - - sec_splitter = HTMLSectionSplitter( - headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")], - xslt_path=test_file.as_posix(), - ) - - sec_splitter.split_text(html_string) - - @pytest.mark.requires("bs4") @pytest.mark.requires("lxml") def test_happy_path_splitting_with_duplicate_header_tag() -> None: