core[patch]: Use defusedxml in XMLOutputParser (#19526)

This mitigates a security concern for users still running older versions of libexpat: an attacker who manages to surface a malicious payload to this XMLParser could otherwise compromise the availability of the system.
Eugene Yurtsev 2024-03-25 16:21:52 -04:00 committed by GitHub
parent e1a6341940
commit 727d5023ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 150 additions and 64 deletions
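For context on the commit message above, a minimal sketch (not part of this change; the payload and print are invented for illustration) of the entity-expansion class of attack being mitigated. defusedxml refuses to expand declared entities and raises EntitiesForbidden instead of expanding them in memory:

# Illustrative only; not code from this commit.
from defusedxml import ElementTree as DET  # assumes defusedxml >= 0.7 is installed

payload = """<?xml version="1.0"?>
<!DOCTYPE lolz [<!ENTITY lol "lol">
<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
]>
<lolz>&lol1;</lolz>"""

try:
    DET.fromstring(payload)
except DET.EntitiesForbidden:
    # defusedxml rejects documents that declare entities rather than expanding them
    print("entity expansion blocked")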


@@ -1,6 +1,7 @@
 import re
-import xml.etree.ElementTree as ET
 from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from xml.etree import ElementTree as ET
+from xml.etree.ElementTree import TreeBuilder
 
 from langchain_core.exceptions import OutputParserException
 from langchain_core.messages import BaseMessage
@@ -35,6 +36,10 @@ class XMLOutputParser(BaseTransformOutputParser):
         return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
 
     def parse(self, text: str) -> Dict[str, List[Any]]:
+        # Imports are temporarily placed here to avoid issue with caching on CI
+        # likely if you're reading this you can move them to the top of the file
+        from defusedxml import ElementTree as DET  # type: ignore[import]
+
         # Try to find XML string within triple backticks
         match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
         if match is not None:
@@ -46,18 +51,24 @@ class XMLOutputParser(BaseTransformOutputParser):
         text = text.strip()
         try:
-            root = ET.fromstring(text)
+            root = DET.fromstring(text)
             return self._root_to_dict(root)
-        except ET.ParseError as e:
+        except (DET.ParseError, DET.EntitiesForbidden) as e:
             msg = f"Failed to parse XML format from completion {text}. Got: {e}"
             raise OutputParserException(msg, llm_output=text) from e
 
     def _transform(
         self, input: Iterator[Union[str, BaseMessage]]
     ) -> Iterator[AddableDict]:
+        # Imports are temporarily placed here to avoid issue with caching on CI
+        # likely if you're reading this you can move them to the top of the file
+        from defusedxml.ElementTree import DefusedXMLParser  # type: ignore[import]
+
+        parser = ET.XMLPullParser(
+            ["start", "end"], _parser=DefusedXMLParser(target=TreeBuilder())
+        )
         xml_start_re = re.compile(r"<[a-zA-Z:_]")
-        parser = ET.XMLPullParser(["start", "end"])
         xml_started = False
         current_path: List[str] = []
         current_path_has_children = False
@@ -83,6 +94,61 @@ class XMLOutputParser(BaseTransformOutputParser):
             parser.feed(buffer)
             buffer = ""
             # yield all events
+            for event, elem in parser.read_events():
+                if event == "start":
+                    # update current path
+                    current_path.append(elem.tag)
+                    current_path_has_children = False
+                elif event == "end":
+                    # remove last element from current path
+                    current_path.pop()
+                    # yield element
+                    if not current_path_has_children:
+                        yield nested_element(current_path, elem)
+                    # prevent yielding of parent element
+                    if current_path:
+                        current_path_has_children = True
+                    else:
+                        xml_started = False
+        # close parser
+        parser.close()
+
+    async def _atransform(
+        self, input: AsyncIterator[Union[str, BaseMessage]]
+    ) -> AsyncIterator[AddableDict]:
+        # Imports are temporarily placed here to avoid issue with caching on CI
+        # likely if you're reading this you can move them to the top of the file
+        from defusedxml.ElementTree import DefusedXMLParser  # type: ignore[import]
+
+        _parser = DefusedXMLParser(target=TreeBuilder())
+        parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
+        xml_start_re = re.compile(r"<[a-zA-Z:_]")
+        xml_started = False
+        current_path: List[str] = []
+        current_path_has_children = False
+        buffer = ""
+        async for chunk in input:
+            if isinstance(chunk, BaseMessage):
+                # extract text
+                chunk_content = chunk.content
+                if not isinstance(chunk_content, str):
+                    continue
+                chunk = chunk_content
+            # add chunk to buffer of unprocessed text
+            buffer += chunk
+            # if xml string hasn't started yet, continue to next chunk
+            if not xml_started:
+                if match := xml_start_re.search(buffer):
+                    # if xml string has started, remove all text before it
+                    buffer = buffer[match.start() :]
+                    xml_started = True
+                else:
+                    continue
+            # feed buffer to parser
+            parser.feed(buffer)
+            buffer = ""
+            # yield all events
             for event, elem in parser.read_events():
                 if event == "start":
                     # update current path
@@ -102,38 +168,6 @@ class XMLOutputParser(BaseTransformOutputParser):
         # close parser
         parser.close()
 
-    async def _atransform(
-        self, input: AsyncIterator[Union[str, BaseMessage]]
-    ) -> AsyncIterator[AddableDict]:
-        parser = ET.XMLPullParser(["start", "end"])
-        current_path: List[str] = []
-        current_path_has_children = False
-        async for chunk in input:
-            if isinstance(chunk, BaseMessage):
-                # extract text
-                chunk_content = chunk.content
-                if not isinstance(chunk_content, str):
-                    continue
-                chunk = chunk_content
-            # pass chunk to parser
-            parser.feed(chunk)
-            # yield all events
-            for event, elem in parser.read_events():
-                if event == "start":
-                    # update current path
-                    current_path.append(elem.tag)
-                    current_path_has_children = False
-                elif event == "end":
-                    # remove last element from current path
-                    current_path.pop()
-                    # yield element
-                    if not current_path_has_children:
-                        yield nested_element(current_path, elem)
-                    # prevent yielding of parent element
-                    current_path_has_children = True
-        # close parser
-        parser.close()
-
     def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
         """Converts xml tree to python dictionary."""
         result: Dict[str, List[Any]] = {root.tag: []}
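For readers tracing the streaming path, here is a rough standalone sketch of the pull-parser wiring the hunks above introduce (the chunk values are invented, and defusedxml is assumed to be installed): a DefusedXMLParser is passed through XMLPullParser's private _parser argument so chunked input is parsed incrementally with entity expansion forbidden.

# Standalone illustration of the _transform/_atransform parser setup; not code from this commit.
from xml.etree import ElementTree as ET
from xml.etree.ElementTree import TreeBuilder

from defusedxml.ElementTree import DefusedXMLParser

parser = ET.XMLPullParser(
    ["start", "end"], _parser=DefusedXMLParser(target=TreeBuilder())
)
for chunk in ["<foo><bar>", "baz</bar>", "</foo>"]:  # hypothetical streamed chunks
    parser.feed(chunk)
    for event, elem in parser.read_events():
        print(event, elem.tag)  # start foo, start bar, end bar, end foo
parser.close()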

libs/core/poetry.lock generated

@@ -660,13 +660,13 @@ files = [
 
 [[package]]
 name = "importlib-metadata"
-version = "7.0.2"
+version = "7.1.0"
 description = "Read metadata from Python packages"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"},
-    {file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"},
+    {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"},
+    {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"},
 ]
 
 [package.dependencies]
@@ -675,17 +675,17 @@ zipp = ">=0.5"
 
 [package.extras]
 docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
 perf = ["ipython"]
-testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
+testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
 
 [[package]]
 name = "importlib-resources"
-version = "6.3.1"
+version = "6.4.0"
 description = "Read resources from Python packages"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "importlib_resources-6.3.1-py3-none-any.whl", hash = "sha256:4811639ca7fa830abdb8e9ca0a104dc6ad13de691d9fe0d3173a71304f068159"},
-    {file = "importlib_resources-6.3.1.tar.gz", hash = "sha256:29a3d16556e330c3c8fb8202118c5ff41241cc34cbfb25989bbad226d99b7995"},
+    {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"},
+    {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"},
 ]
 
 [package.dependencies]
@@ -693,7 +693,7 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
 
 [package.extras]
 docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
-testing = ["jaraco.collections", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"]
+testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"]
 
 [[package]]
 name = "iniconfig"
@@ -1020,13 +1020,13 @@ test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout"
 
 [[package]]
 name = "jupyter-events"
-version = "0.9.1"
+version = "0.10.0"
 description = "Jupyter Event System library"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "jupyter_events-0.9.1-py3-none-any.whl", hash = "sha256:e51f43d2c25c2ddf02d7f7a5045f71fc1d5cb5ad04ef6db20da961c077654b9b"},
-    {file = "jupyter_events-0.9.1.tar.gz", hash = "sha256:a52e86f59eb317ee71ff2d7500c94b963b8a24f0b7a1517e2e653e24258e15c7"},
+    {file = "jupyter_events-0.10.0-py3-none-any.whl", hash = "sha256:4b72130875e59d57716d327ea70d3ebc3af1944d3717e5a498b8a06c6c159960"},
+    {file = "jupyter_events-0.10.0.tar.gz", hash = "sha256:670b8229d3cc882ec782144ed22e0d29e1c2d639263f92ca8383e66682845e22"},
 ]
 
 [package.dependencies]
@@ -1216,13 +1216,13 @@ url = "../text-splitters"
 
 [[package]]
 name = "langsmith"
-version = "0.1.27"
+version = "0.1.31"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
-python-versions = ">=3.8.1,<4.0"
+python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "langsmith-0.1.27-py3-none-any.whl", hash = "sha256:d223176952b1525c958189ab1b894f5bd9891ec9177222f7a978aeee4bf1cc95"},
-    {file = "langsmith-0.1.27.tar.gz", hash = "sha256:e0a339d976362051adf3fdbc43fcc7c00bb4615a401321ad7e556bd2dab556c0"},
+    {file = "langsmith-0.1.31-py3-none-any.whl", hash = "sha256:5211a9dc00831db307eb843485a97096484b697b5d2cd1efaac34228e97ca087"},
+    {file = "langsmith-0.1.31.tar.gz", hash = "sha256:efd54ccd44be7fda911bfdc0ead340473df2fdd07345c7252901834d0c4aa37e"},
 ]
 
 [package.dependencies]
@@ -1406,13 +1406,13 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=
 
 [[package]]
 name = "nbconvert"
-version = "7.16.2"
+version = "7.16.3"
 description = "Converting Jupyter Notebooks (.ipynb files) to other formats. Output formats include asciidoc, html, latex, markdown, pdf, py, rst, script. nbconvert can be used both as a Python library (`import nbconvert`) or as a command line tool (invoked as `jupyter nbconvert ...`)."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "nbconvert-7.16.2-py3-none-any.whl", hash = "sha256:0c01c23981a8de0220255706822c40b751438e32467d6a686e26be08ba784382"},
-    {file = "nbconvert-7.16.2.tar.gz", hash = "sha256:8310edd41e1c43947e4ecf16614c61469ebc024898eb808cce0999860fc9fb16"},
+    {file = "nbconvert-7.16.3-py3-none-any.whl", hash = "sha256:ddeff14beeeedf3dd0bc506623e41e4507e551736de59df69a91f86700292b3b"},
+    {file = "nbconvert-7.16.3.tar.gz", hash = "sha256:a6733b78ce3d47c3f85e504998495b07e6ea9cf9bf6ec1c98dda63ec6ad19142"},
 ]
 
 [package.dependencies]
@@ -1439,7 +1439,7 @@ docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sp
 qtpdf = ["nbconvert[qtpng]"]
 qtpng = ["pyqtwebengine (>=5.15)"]
 serve = ["tornado (>=6.1)"]
-test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest"]
+test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest (>=7)"]
 webpdf = ["playwright"]
 
 [[package]]
@@ -1997,17 +1997,17 @@ testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy
 
 [[package]]
 name = "pytest-mock"
-version = "3.12.0"
+version = "3.14.0"
 description = "Thin-wrapper around the mock package for easier use with pytest"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"},
-    {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"},
+    {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
+    {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
 ]
 
 [package.dependencies]
-pytest = ">=5.0"
+pytest = ">=6.2.5"
 
 [package.extras]
 dev = ["pre-commit", "pytest-asyncio", "tox"]
@@ -2966,4 +2966,4 @@ extended-testing = ["jinja2"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "ca611429e3dd84ce6dac7ef69d7d9b4da78bf467356946e37016b821e5fe752e"
+content-hash = "a13a0a8454b242106bb681fa74e1f1320a0198f2e07b35d29d985b03a310cf67"


@@ -19,6 +19,7 @@ PyYAML = ">=5.3"
 requests = "^2"
 packaging = "^23.2"
 jinja2 = { version = "^3", optional = true }
+defusedxml = "^0.7"
 
 [tool.poetry.group.lint]
 optional = true


@@ -1,4 +1,7 @@
 """Test XMLOutputParser"""
+from typing import AsyncIterator
+from xml.etree.ElementTree import ParseError
+
 import pytest
 
 from langchain_core.exceptions import OutputParserException
@@ -40,19 +43,29 @@ More random text
     """,
     ],
 )
-def test_xml_output_parser(result: str) -> None:
+async def test_xml_output_parser(result: str) -> None:
     """Test XMLOutputParser."""
     xml_parser = XMLOutputParser()
-    xml_result = xml_parser.parse(result)
-    assert DEF_RESULT_EXPECTED == xml_result
+    assert DEF_RESULT_EXPECTED == xml_parser.parse(result)
+    assert DEF_RESULT_EXPECTED == (await xml_parser.aparse(result))
+
     assert list(xml_parser.transform(iter(result))) == [
         {"foo": [{"bar": [{"baz": None}]}]},
         {"foo": [{"bar": [{"baz": "slim.shady"}]}]},
         {"foo": [{"baz": "tag"}]},
     ]
 
+    async def _as_iter(string: str) -> AsyncIterator[str]:
+        for c in string:
+            yield c
+
+    chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))]
+
+    assert chunks == [
+        {"foo": [{"bar": [{"baz": None}]}]},
+        {"foo": [{"bar": [{"baz": "slim.shady"}]}]},
+        {"foo": [{"baz": "tag"}]},
+    ]
+
 
 @pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"])
 def test_xml_output_parser_fail(result: str) -> None:
@@ -63,3 +76,41 @@ def test_xml_output_parser_fail(result: str) -> None:
     with pytest.raises(OutputParserException) as e:
         xml_parser.parse(result)
     assert "Failed to parse" in str(e)
+
+
+MALICIOUS_XML = """<?xml version="1.0"?>
+<!DOCTYPE lolz [<!ENTITY lol "lol"><!ELEMENT lolz (#PCDATA)>
+<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
+<!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">
+<!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">
+<!ENTITY lol4 "&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;">
+<!ENTITY lol5 "&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;">
+<!ENTITY lol6 "&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;">
+<!ENTITY lol7 "&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;">
+<!ENTITY lol8 "&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;">
+<!ENTITY lol9 "&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;">
+]>
+<lolz>&lol9;</lolz>"""
+
+
+async def tests_billion_laughs_attack() -> None:
+    parser = XMLOutputParser()
+    with pytest.raises(OutputParserException):
+        parser.parse(MALICIOUS_XML)
+
+    with pytest.raises(OutputParserException):
+        await parser.aparse(MALICIOUS_XML)
+
+    with pytest.raises(ParseError):
+        # Right now raises undefined entity error
+        assert list(parser.transform(iter(MALICIOUS_XML))) == [
+            {"foo": [{"bar": [{"baz": None}]}]}
+        ]
+
+    async def _as_iter(string: str) -> AsyncIterator[str]:
+        for c in string:
+            yield c
+
+    with pytest.raises(ParseError):
+        chunks = [chunk async for chunk in parser.atransform(_as_iter(MALICIOUS_XML))]
+        assert chunks == [{"foo": [{"bar": [{"baz": None}]}]}]