diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index f74bd4c1050..9871ecee91d 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -1,6 +1,7 @@ import re -import xml.etree.ElementTree as ET from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union +from xml.etree import ElementTree as ET +from xml.etree.ElementTree import TreeBuilder from langchain_core.exceptions import OutputParserException from langchain_core.messages import BaseMessage @@ -35,6 +36,10 @@ class XMLOutputParser(BaseTransformOutputParser): return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) def parse(self, text: str) -> Dict[str, List[Any]]: + # Imports are temporarily placed here to avoid issue with caching on CI + # likely if you're reading this you can move them to the top of the file + from defusedxml import ElementTree as DET # type: ignore[import] + # Try to find XML string within triple backticks match = re.search(r"```(xml)?(.*)```", text, re.DOTALL) if match is not None: @@ -46,18 +51,24 @@ class XMLOutputParser(BaseTransformOutputParser): text = text.strip() try: - root = ET.fromstring(text) + root = DET.fromstring(text) return self._root_to_dict(root) - except ET.ParseError as e: + except (DET.ParseError, DET.EntitiesForbidden) as e: msg = f"Failed to parse XML format from completion {text}. Got: {e}" raise OutputParserException(msg, llm_output=text) from e def _transform( self, input: Iterator[Union[str, BaseMessage]] ) -> Iterator[AddableDict]: + # Imports are temporarily placed here to avoid issue with caching on CI + # likely if you're reading this you can move them to the top of the file + from defusedxml.ElementTree import DefusedXMLParser # type: ignore[import] + + parser = ET.XMLPullParser( + ["start", "end"], _parser=DefusedXMLParser(target=TreeBuilder()) + ) xml_start_re = re.compile(r"<[a-zA-Z:_]") - parser = ET.XMLPullParser(["start", "end"]) xml_started = False current_path: List[str] = [] current_path_has_children = False @@ -83,6 +94,61 @@ class XMLOutputParser(BaseTransformOutputParser): parser.feed(buffer) buffer = "" # yield all events + + for event, elem in parser.read_events(): + if event == "start": + # update current path + current_path.append(elem.tag) + current_path_has_children = False + elif event == "end": + # remove last element from current path + current_path.pop() + # yield element + if not current_path_has_children: + yield nested_element(current_path, elem) + # prevent yielding of parent element + if current_path: + current_path_has_children = True + else: + xml_started = False + # close parser + parser.close() + + async def _atransform( + self, input: AsyncIterator[Union[str, BaseMessage]] + ) -> AsyncIterator[AddableDict]: + # Imports are temporarily placed here to avoid issue with caching on CI + # likely if you're reading this you can move them to the top of the file + from defusedxml.ElementTree import DefusedXMLParser # type: ignore[import] + + _parser = DefusedXMLParser(target=TreeBuilder()) + parser = ET.XMLPullParser(["start", "end"], _parser=_parser) + xml_start_re = re.compile(r"<[a-zA-Z:_]") + xml_started = False + current_path: List[str] = [] + current_path_has_children = False + buffer = "" + async for chunk in input: + if isinstance(chunk, BaseMessage): + # extract text + chunk_content = chunk.content + if not isinstance(chunk_content, str): + continue + chunk = chunk_content + # add chunk to buffer of unprocessed text + buffer += chunk + # if xml string hasn't started yet, continue to next chunk + if not xml_started: + if match := xml_start_re.search(buffer): + # if xml string has started, remove all text before it + buffer = buffer[match.start() :] + xml_started = True + else: + continue + # feed buffer to parser + parser.feed(buffer) + buffer = "" + # yield all events for event, elem in parser.read_events(): if event == "start": # update current path @@ -102,38 +168,6 @@ class XMLOutputParser(BaseTransformOutputParser): # close parser parser.close() - async def _atransform( - self, input: AsyncIterator[Union[str, BaseMessage]] - ) -> AsyncIterator[AddableDict]: - parser = ET.XMLPullParser(["start", "end"]) - current_path: List[str] = [] - current_path_has_children = False - async for chunk in input: - if isinstance(chunk, BaseMessage): - # extract text - chunk_content = chunk.content - if not isinstance(chunk_content, str): - continue - chunk = chunk_content - # pass chunk to parser - parser.feed(chunk) - # yield all events - for event, elem in parser.read_events(): - if event == "start": - # update current path - current_path.append(elem.tag) - current_path_has_children = False - elif event == "end": - # remove last element from current path - current_path.pop() - # yield element - if not current_path_has_children: - yield nested_element(current_path, elem) - # prevent yielding of parent element - current_path_has_children = True - # close parser - parser.close() - def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]: """Converts xml tree to python dictionary.""" result: Dict[str, List[Any]] = {root.tag: []} diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index 4d9ab6daf26..997d7416722 100644 --- a/libs/core/poetry.lock +++ b/libs/core/poetry.lock @@ -660,13 +660,13 @@ files = [ [[package]] name = "importlib-metadata" -version = "7.0.2" +version = "7.1.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"}, - {file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"}, + {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, + {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, ] [package.dependencies] @@ -675,17 +675,17 @@ zipp = ">=0.5" [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "importlib-resources" -version = "6.3.1" +version = "6.4.0" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.3.1-py3-none-any.whl", hash = "sha256:4811639ca7fa830abdb8e9ca0a104dc6ad13de691d9fe0d3173a71304f068159"}, - {file = "importlib_resources-6.3.1.tar.gz", hash = "sha256:29a3d16556e330c3c8fb8202118c5ff41241cc34cbfb25989bbad226d99b7995"}, + {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"}, + {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"}, ] [package.dependencies] @@ -693,7 +693,7 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["jaraco.collections", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] +testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] [[package]] name = "iniconfig" @@ -1020,13 +1020,13 @@ test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout" [[package]] name = "jupyter-events" -version = "0.9.1" +version = "0.10.0" description = "Jupyter Event System library" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_events-0.9.1-py3-none-any.whl", hash = "sha256:e51f43d2c25c2ddf02d7f7a5045f71fc1d5cb5ad04ef6db20da961c077654b9b"}, - {file = "jupyter_events-0.9.1.tar.gz", hash = "sha256:a52e86f59eb317ee71ff2d7500c94b963b8a24f0b7a1517e2e653e24258e15c7"}, + {file = "jupyter_events-0.10.0-py3-none-any.whl", hash = "sha256:4b72130875e59d57716d327ea70d3ebc3af1944d3717e5a498b8a06c6c159960"}, + {file = "jupyter_events-0.10.0.tar.gz", hash = "sha256:670b8229d3cc882ec782144ed22e0d29e1c2d639263f92ca8383e66682845e22"}, ] [package.dependencies] @@ -1216,13 +1216,13 @@ url = "../text-splitters" [[package]] name = "langsmith" -version = "0.1.27" +version = "0.1.31" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.27-py3-none-any.whl", hash = "sha256:d223176952b1525c958189ab1b894f5bd9891ec9177222f7a978aeee4bf1cc95"}, - {file = "langsmith-0.1.27.tar.gz", hash = "sha256:e0a339d976362051adf3fdbc43fcc7c00bb4615a401321ad7e556bd2dab556c0"}, + {file = "langsmith-0.1.31-py3-none-any.whl", hash = "sha256:5211a9dc00831db307eb843485a97096484b697b5d2cd1efaac34228e97ca087"}, + {file = "langsmith-0.1.31.tar.gz", hash = "sha256:efd54ccd44be7fda911bfdc0ead340473df2fdd07345c7252901834d0c4aa37e"}, ] [package.dependencies] @@ -1406,13 +1406,13 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= [[package]] name = "nbconvert" -version = "7.16.2" +version = "7.16.3" description = "Converting Jupyter Notebooks (.ipynb files) to other formats. Output formats include asciidoc, html, latex, markdown, pdf, py, rst, script. nbconvert can be used both as a Python library (`import nbconvert`) or as a command line tool (invoked as `jupyter nbconvert ...`)." optional = false python-versions = ">=3.8" files = [ - {file = "nbconvert-7.16.2-py3-none-any.whl", hash = "sha256:0c01c23981a8de0220255706822c40b751438e32467d6a686e26be08ba784382"}, - {file = "nbconvert-7.16.2.tar.gz", hash = "sha256:8310edd41e1c43947e4ecf16614c61469ebc024898eb808cce0999860fc9fb16"}, + {file = "nbconvert-7.16.3-py3-none-any.whl", hash = "sha256:ddeff14beeeedf3dd0bc506623e41e4507e551736de59df69a91f86700292b3b"}, + {file = "nbconvert-7.16.3.tar.gz", hash = "sha256:a6733b78ce3d47c3f85e504998495b07e6ea9cf9bf6ec1c98dda63ec6ad19142"}, ] [package.dependencies] @@ -1439,7 +1439,7 @@ docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sp qtpdf = ["nbconvert[qtpng]"] qtpng = ["pyqtwebengine (>=5.15)"] serve = ["tornado (>=6.1)"] -test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest"] +test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest (>=7)"] webpdf = ["playwright"] [[package]] @@ -1997,17 +1997,17 @@ testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy [[package]] name = "pytest-mock" -version = "3.12.0" +version = "3.14.0" description = "Thin-wrapper around the mock package for easier use with pytest" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"}, - {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"}, + {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"}, + {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, ] [package.dependencies] -pytest = ">=5.0" +pytest = ">=6.2.5" [package.extras] dev = ["pre-commit", "pytest-asyncio", "tox"] @@ -2966,4 +2966,4 @@ extended-testing = ["jinja2"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "ca611429e3dd84ce6dac7ef69d7d9b4da78bf467356946e37016b821e5fe752e" +content-hash = "a13a0a8454b242106bb681fa74e1f1320a0198f2e07b35d29d985b03a310cf67" diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 7a0bf312fe0..ca9465bff41 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -19,6 +19,7 @@ PyYAML = ">=5.3" requests = "^2" packaging = "^23.2" jinja2 = { version = "^3", optional = true } +defusedxml = "^0.7" [tool.poetry.group.lint] optional = true diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py index 65b095f308e..17ef1a558e6 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py @@ -1,4 +1,7 @@ """Test XMLOutputParser""" +from typing import AsyncIterator +from xml.etree.ElementTree import ParseError + import pytest from langchain_core.exceptions import OutputParserException @@ -40,19 +43,29 @@ More random text """, ], ) -def test_xml_output_parser(result: str) -> None: +async def test_xml_output_parser(result: str) -> None: """Test XMLOutputParser.""" xml_parser = XMLOutputParser() - - xml_result = xml_parser.parse(result) - assert DEF_RESULT_EXPECTED == xml_result + assert DEF_RESULT_EXPECTED == xml_parser.parse(result) + assert DEF_RESULT_EXPECTED == (await xml_parser.aparse(result)) assert list(xml_parser.transform(iter(result))) == [ {"foo": [{"bar": [{"baz": None}]}]}, {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, {"foo": [{"baz": "tag"}]}, ] + async def _as_iter(string: str) -> AsyncIterator[str]: + for c in string: + yield c + + chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))] + assert chunks == [ + {"foo": [{"bar": [{"baz": None}]}]}, + {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, + {"foo": [{"baz": "tag"}]}, + ] + @pytest.mark.parametrize("result", ["foo>", " None: @@ -63,3 +76,41 @@ def test_xml_output_parser_fail(result: str) -> None: with pytest.raises(OutputParserException) as e: xml_parser.parse(result) assert "Failed to parse" in str(e) + + +MALICIOUS_XML = """ + + + + + + + + + + +]> +&lol9;""" + + +async def tests_billion_laughs_attack() -> None: + parser = XMLOutputParser() + with pytest.raises(OutputParserException): + parser.parse(MALICIOUS_XML) + + with pytest.raises(OutputParserException): + await parser.aparse(MALICIOUS_XML) + + with pytest.raises(ParseError): + # Right now raises undefined entity error + assert list(parser.transform(iter(MALICIOUS_XML))) == [ + {"foo": [{"bar": [{"baz": None}]}]} + ] + + async def _as_iter(string: str) -> AsyncIterator[str]: + for c in string: + yield c + + with pytest.raises(ParseError): + chunks = [chunk async for chunk in parser.atransform(_as_iter(MALICIOUS_XML))] + assert chunks == [{"foo": [{"bar": [{"baz": None}]}]}]