From 727d5023ce88e18e3074ef620a98137d26ff92a3 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 25 Mar 2024 16:21:52 -0400 Subject: [PATCH] core[patch]: Use defusedxml in XMLOutputParser (#19526) This mitigates a security concern for users still using older versions of libexpat, which allow an attacker to compromise the availability of the system if the attacker manages to surface a malicious payload to this XMLParser. --- .../core/langchain_core/output_parsers/xml.py | 106 ++++++++++++------ libs/core/poetry.lock | 48 ++++---- libs/core/pyproject.toml | 1 + .../output_parsers/test_xml_parser.py | 59 +++++++++- 4 files changed, 150 insertions(+), 64 deletions(-) diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index f74bd4c1050..9871ecee91d 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -1,6 +1,7 @@ import re -import xml.etree.ElementTree as ET from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union +from xml.etree import ElementTree as ET +from xml.etree.ElementTree import TreeBuilder from langchain_core.exceptions import OutputParserException from langchain_core.messages import BaseMessage @@ -35,6 +36,10 @@ class XMLOutputParser(BaseTransformOutputParser): return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) def parse(self, text: str) -> Dict[str, List[Any]]: + # Imports are temporarily placed here to avoid issue with caching on CI + # likely if you're reading this you can move them to the top of the file + from defusedxml import ElementTree as DET # type: ignore[import] + # Try to find XML string within triple backticks match = re.search(r"```(xml)?(.*)```", text, re.DOTALL) if match is not None: @@ -46,18 +51,24 @@ class XMLOutputParser(BaseTransformOutputParser): text = text.strip() try: - root = ET.fromstring(text) + root = DET.fromstring(text) return self._root_to_dict(root) - except ET.ParseError as e: + 
except (DET.ParseError, DET.EntitiesForbidden) as e: msg = f"Failed to parse XML format from completion {text}. Got: {e}" raise OutputParserException(msg, llm_output=text) from e def _transform( self, input: Iterator[Union[str, BaseMessage]] ) -> Iterator[AddableDict]: + # Imports are temporarily placed here to avoid issue with caching on CI + # likely if you're reading this you can move them to the top of the file + from defusedxml.ElementTree import DefusedXMLParser # type: ignore[import] + + parser = ET.XMLPullParser( + ["start", "end"], _parser=DefusedXMLParser(target=TreeBuilder()) + ) xml_start_re = re.compile(r"<[a-zA-Z:_]") - parser = ET.XMLPullParser(["start", "end"]) xml_started = False current_path: List[str] = [] current_path_has_children = False @@ -83,6 +94,61 @@ class XMLOutputParser(BaseTransformOutputParser): parser.feed(buffer) buffer = "" # yield all events + + for event, elem in parser.read_events(): + if event == "start": + # update current path + current_path.append(elem.tag) + current_path_has_children = False + elif event == "end": + # remove last element from current path + current_path.pop() + # yield element + if not current_path_has_children: + yield nested_element(current_path, elem) + # prevent yielding of parent element + if current_path: + current_path_has_children = True + else: + xml_started = False + # close parser + parser.close() + + async def _atransform( + self, input: AsyncIterator[Union[str, BaseMessage]] + ) -> AsyncIterator[AddableDict]: + # Imports are temporarily placed here to avoid issue with caching on CI + # likely if you're reading this you can move them to the top of the file + from defusedxml.ElementTree import DefusedXMLParser # type: ignore[import] + + _parser = DefusedXMLParser(target=TreeBuilder()) + parser = ET.XMLPullParser(["start", "end"], _parser=_parser) + xml_start_re = re.compile(r"<[a-zA-Z:_]") + xml_started = False + current_path: List[str] = [] + current_path_has_children = False + buffer = "" + 
async for chunk in input: + if isinstance(chunk, BaseMessage): + # extract text + chunk_content = chunk.content + if not isinstance(chunk_content, str): + continue + chunk = chunk_content + # add chunk to buffer of unprocessed text + buffer += chunk + # if xml string hasn't started yet, continue to next chunk + if not xml_started: + if match := xml_start_re.search(buffer): + # if xml string has started, remove all text before it + buffer = buffer[match.start() :] + xml_started = True + else: + continue + # feed buffer to parser + parser.feed(buffer) + buffer = "" + # yield all events for event, elem in parser.read_events(): if event == "start": # update current path @@ -102,38 +168,6 @@ class XMLOutputParser(BaseTransformOutputParser): # close parser parser.close() - async def _atransform( - self, input: AsyncIterator[Union[str, BaseMessage]] - ) -> AsyncIterator[AddableDict]: - parser = ET.XMLPullParser(["start", "end"]) - current_path: List[str] = [] - current_path_has_children = False - async for chunk in input: - if isinstance(chunk, BaseMessage): - # extract text - chunk_content = chunk.content - if not isinstance(chunk_content, str): - continue - chunk = chunk_content - # pass chunk to parser - parser.feed(chunk) - # yield all events - for event, elem in parser.read_events(): - if event == "start": - # update current path - current_path.append(elem.tag) - current_path_has_children = False - elif event == "end": - # remove last element from current path - current_path.pop() - # yield element - if not current_path_has_children: - yield nested_element(current_path, elem) - # prevent yielding of parent element - current_path_has_children = True - # close parser - parser.close() - def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]: """Converts xml tree to python dictionary.""" result: Dict[str, List[Any]] = {root.tag: []} diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index 4d9ab6daf26..997d7416722 100644 --- a/libs/core/poetry.lock +++ 
b/libs/core/poetry.lock @@ -660,13 +660,13 @@ files = [ [[package]] name = "importlib-metadata" -version = "7.0.2" +version = "7.1.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"}, - {file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"}, + {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, + {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, ] [package.dependencies] @@ -675,17 +675,17 @@ zipp = ">=0.5" [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "importlib-resources" -version = "6.3.1" +version = "6.4.0" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.3.1-py3-none-any.whl", hash = "sha256:4811639ca7fa830abdb8e9ca0a104dc6ad13de691d9fe0d3173a71304f068159"}, - {file = "importlib_resources-6.3.1.tar.gz", hash = "sha256:29a3d16556e330c3c8fb8202118c5ff41241cc34cbfb25989bbad226d99b7995"}, + {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = 
"sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"}, + {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"}, ] [package.dependencies] @@ -693,7 +693,7 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["jaraco.collections", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] +testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] [[package]] name = "iniconfig" @@ -1020,13 +1020,13 @@ test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout" [[package]] name = "jupyter-events" -version = "0.9.1" +version = "0.10.0" description = "Jupyter Event System library" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_events-0.9.1-py3-none-any.whl", hash = "sha256:e51f43d2c25c2ddf02d7f7a5045f71fc1d5cb5ad04ef6db20da961c077654b9b"}, - {file = "jupyter_events-0.9.1.tar.gz", hash = "sha256:a52e86f59eb317ee71ff2d7500c94b963b8a24f0b7a1517e2e653e24258e15c7"}, + {file = "jupyter_events-0.10.0-py3-none-any.whl", hash = "sha256:4b72130875e59d57716d327ea70d3ebc3af1944d3717e5a498b8a06c6c159960"}, + {file = "jupyter_events-0.10.0.tar.gz", hash = "sha256:670b8229d3cc882ec782144ed22e0d29e1c2d639263f92ca8383e66682845e22"}, ] [package.dependencies] @@ -1216,13 +1216,13 @@ url = "../text-splitters" [[package]] name = "langsmith" -version = "0.1.27" +version = "0.1.31" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.27-py3-none-any.whl", hash = "sha256:d223176952b1525c958189ab1b894f5bd9891ec9177222f7a978aeee4bf1cc95"}, - {file = "langsmith-0.1.27.tar.gz", hash = "sha256:e0a339d976362051adf3fdbc43fcc7c00bb4615a401321ad7e556bd2dab556c0"}, + {file = "langsmith-0.1.31-py3-none-any.whl", hash = "sha256:5211a9dc00831db307eb843485a97096484b697b5d2cd1efaac34228e97ca087"}, + {file = "langsmith-0.1.31.tar.gz", hash = "sha256:efd54ccd44be7fda911bfdc0ead340473df2fdd07345c7252901834d0c4aa37e"}, ] [package.dependencies] @@ -1406,13 +1406,13 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= [[package]] name = "nbconvert" -version = "7.16.2" +version = "7.16.3" description = "Converting Jupyter Notebooks (.ipynb files) to other formats. Output formats include asciidoc, html, latex, markdown, pdf, py, rst, script. nbconvert can be used both as a Python library (`import nbconvert`) or as a command line tool (invoked as `jupyter nbconvert ...`)." 
optional = false python-versions = ">=3.8" files = [ - {file = "nbconvert-7.16.2-py3-none-any.whl", hash = "sha256:0c01c23981a8de0220255706822c40b751438e32467d6a686e26be08ba784382"}, - {file = "nbconvert-7.16.2.tar.gz", hash = "sha256:8310edd41e1c43947e4ecf16614c61469ebc024898eb808cce0999860fc9fb16"}, + {file = "nbconvert-7.16.3-py3-none-any.whl", hash = "sha256:ddeff14beeeedf3dd0bc506623e41e4507e551736de59df69a91f86700292b3b"}, + {file = "nbconvert-7.16.3.tar.gz", hash = "sha256:a6733b78ce3d47c3f85e504998495b07e6ea9cf9bf6ec1c98dda63ec6ad19142"}, ] [package.dependencies] @@ -1439,7 +1439,7 @@ docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sp qtpdf = ["nbconvert[qtpng]"] qtpng = ["pyqtwebengine (>=5.15)"] serve = ["tornado (>=6.1)"] -test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest"] +test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest (>=7)"] webpdf = ["playwright"] [[package]] @@ -1997,17 +1997,17 @@ testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy [[package]] name = "pytest-mock" -version = "3.12.0" +version = "3.14.0" description = "Thin-wrapper around the mock package for easier use with pytest" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"}, - {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"}, + {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"}, + {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, ] [package.dependencies] -pytest = ">=5.0" +pytest = ">=6.2.5" [package.extras] dev = ["pre-commit", "pytest-asyncio", "tox"] @@ -2966,4 +2966,4 @@ extended-testing = ["jinja2"] [metadata] lock-version = "2.0" python-versions = 
">=3.8.1,<4.0" -content-hash = "ca611429e3dd84ce6dac7ef69d7d9b4da78bf467356946e37016b821e5fe752e" +content-hash = "a13a0a8454b242106bb681fa74e1f1320a0198f2e07b35d29d985b03a310cf67" diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 7a0bf312fe0..ca9465bff41 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -19,6 +19,7 @@ PyYAML = ">=5.3" requests = "^2" packaging = "^23.2" jinja2 = { version = "^3", optional = true } +defusedxml = "^0.7" [tool.poetry.group.lint] optional = true diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py index 65b095f308e..17ef1a558e6 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py @@ -1,4 +1,7 @@ """Test XMLOutputParser""" +from typing import AsyncIterator +from xml.etree.ElementTree import ParseError + import pytest from langchain_core.exceptions import OutputParserException @@ -40,19 +43,29 @@ More random text """, ], ) -def test_xml_output_parser(result: str) -> None: +async def test_xml_output_parser(result: str) -> None: """Test XMLOutputParser.""" xml_parser = XMLOutputParser() - - xml_result = xml_parser.parse(result) - assert DEF_RESULT_EXPECTED == xml_result + assert DEF_RESULT_EXPECTED == xml_parser.parse(result) + assert DEF_RESULT_EXPECTED == (await xml_parser.aparse(result)) assert list(xml_parser.transform(iter(result))) == [ {"foo": [{"bar": [{"baz": None}]}]}, {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, {"foo": [{"baz": "tag"}]}, ] + async def _as_iter(string: str) -> AsyncIterator[str]: + for c in string: + yield c + + chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))] + assert chunks == [ + {"foo": [{"bar": [{"baz": None}]}]}, + {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, + {"foo": [{"baz": "tag"}]}, + ] + @pytest.mark.parametrize("result", ["foo>", " None: @@ -63,3 +76,41 @@ def 
test_xml_output_parser_fail(result: str) -> None: with pytest.raises(OutputParserException) as e: xml_parser.parse(result) assert "Failed to parse" in str(e) + + +MALICIOUS_XML = """ + + + + + + + + + + +]> +&lol9;""" + + +async def tests_billion_laughs_attack() -> None: + parser = XMLOutputParser() + with pytest.raises(OutputParserException): + parser.parse(MALICIOUS_XML) + + with pytest.raises(OutputParserException): + await parser.aparse(MALICIOUS_XML) + + with pytest.raises(ParseError): + # Right now raises undefined entity error + assert list(parser.transform(iter(MALICIOUS_XML))) == [ + {"foo": [{"bar": [{"baz": None}]}]} + ] + + async def _as_iter(string: str) -> AsyncIterator[str]: + for c in string: + yield c + + with pytest.raises(ParseError): + chunks = [chunk async for chunk in parser.atransform(_as_iter(MALICIOUS_XML))] + assert chunks == [{"foo": [{"bar": [{"baz": None}]}]}]