mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 12:59:07 +00:00
core[patch]: Patch XML vulnerability in XMLOutputParser (CVE-2024-1455) (#19653)
Patch potential XML vulnerability CVE-2024-1455 This patches a potential XML vulnerability in the XMLOutputParser in langchain-core. The vulnerability in some situations could lead to a denial of service attack. At risk are users that: 1) Running older distributions of python that have older version of libexpat 2) Are using XMLOutputParser with an agent 3) Accept inputs from untrusted sources with this agent (e.g., endpoint on the web that allows an untrusted user to interact wiith the parser)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import re
|
||||
import xml
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union
|
||||
from xml.etree.ElementTree import TreeBuilder
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import BaseMessage
|
||||
@@ -24,6 +25,105 @@ Here are the output tags:
|
||||
```""" # noqa: E501
|
||||
|
||||
|
||||
class _StreamingParser:
|
||||
"""Streaming parser for XML.
|
||||
|
||||
This implementation is pulled into a class to avoid implementation
|
||||
drift between transform and atransform of the XMLOutputParser.
|
||||
"""
|
||||
|
||||
def __init__(self, parser: Literal["defusedxml", "xml"]) -> None:
|
||||
"""Initialize the streaming parser.
|
||||
|
||||
Args:
|
||||
parser: Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'.
|
||||
See documentation in XMLOutputParser for more information.
|
||||
"""
|
||||
if parser == "defusedxml":
|
||||
try:
|
||||
from defusedxml import ElementTree as DET # type: ignore
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"defusedxml is not installed. "
|
||||
"Please install it to use the defusedxml parser."
|
||||
"You can install it with `pip install defusedxml` "
|
||||
)
|
||||
_parser = DET.DefusedXMLParser(target=TreeBuilder())
|
||||
else:
|
||||
_parser = None
|
||||
self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
|
||||
self.xml_start_re = re.compile(r"<[a-zA-Z:_]")
|
||||
self.current_path: List[str] = []
|
||||
self.current_path_has_children = False
|
||||
self.buffer = ""
|
||||
self.xml_started = False
|
||||
|
||||
def parse(self, chunk: Union[str, BaseMessage]) -> Iterator[AddableDict]:
|
||||
"""Parse a chunk of text.
|
||||
|
||||
Args:
|
||||
chunk: A chunk of text to parse. This can be a string or a BaseMessage.
|
||||
|
||||
Yields:
|
||||
AddableDict: A dictionary representing the parsed XML element.
|
||||
"""
|
||||
if isinstance(chunk, BaseMessage):
|
||||
# extract text
|
||||
chunk_content = chunk.content
|
||||
if not isinstance(chunk_content, str):
|
||||
# ignore non-string messages (e.g., function calls)
|
||||
return
|
||||
chunk = chunk_content
|
||||
# add chunk to buffer of unprocessed text
|
||||
self.buffer += chunk
|
||||
# if xml string hasn't started yet, continue to next chunk
|
||||
if not self.xml_started:
|
||||
if match := self.xml_start_re.search(self.buffer):
|
||||
# if xml string has started, remove all text before it
|
||||
self.buffer = self.buffer[match.start() :]
|
||||
self.xml_started = True
|
||||
else:
|
||||
return
|
||||
# feed buffer to parser
|
||||
self.pull_parser.feed(self.buffer)
|
||||
self.buffer = ""
|
||||
# yield all events
|
||||
try:
|
||||
for event, elem in self.pull_parser.read_events():
|
||||
if event == "start":
|
||||
# update current path
|
||||
self.current_path.append(elem.tag)
|
||||
self.current_path_has_children = False
|
||||
elif event == "end":
|
||||
# remove last element from current path
|
||||
#
|
||||
self.current_path.pop()
|
||||
# yield element
|
||||
if not self.current_path_has_children:
|
||||
yield nested_element(self.current_path, elem)
|
||||
# prevent yielding of parent element
|
||||
if self.current_path:
|
||||
self.current_path_has_children = True
|
||||
else:
|
||||
self.xml_started = False
|
||||
except xml.etree.ElementTree.ParseError:
|
||||
# This might be junk at the end of the XML input.
|
||||
# Let's check whether the current path is empty.
|
||||
if not self.current_path:
|
||||
# If it is empty, we can ignore this error.
|
||||
return
|
||||
else:
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the parser."""
|
||||
try:
|
||||
self.pull_parser.close()
|
||||
except xml.etree.ElementTree.ParseError:
|
||||
# Ignore. This will ignore any incomplete XML at the end of the input
|
||||
pass
|
||||
|
||||
|
||||
class XMLOutputParser(BaseTransformOutputParser):
|
||||
"""Parse an output using xml format."""
|
||||
|
||||
@@ -31,12 +131,48 @@ class XMLOutputParser(BaseTransformOutputParser):
|
||||
encoding_matcher: re.Pattern = re.compile(
|
||||
r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
|
||||
)
|
||||
parser: Literal["defusedxml", "xml"] = "defusedxml"
|
||||
"""Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'.
|
||||
|
||||
* 'defusedxml' is the default parser and is used to prevent XML vulnerabilities
|
||||
present in some distributions of Python's standard library xml.
|
||||
`defusedxml` is a wrapper around the standard library parser that
|
||||
sets up the parser with secure defaults.
|
||||
* 'xml' is the standard library parser.
|
||||
|
||||
Use `xml` only if you are sure that your distribution of the standard library
|
||||
is not vulnerable to XML vulnerabilities.
|
||||
|
||||
Please review the following resources for more information:
|
||||
|
||||
* https://docs.python.org/3/library/xml.html#xml-vulnerabilities
|
||||
* https://github.com/tiran/defusedxml
|
||||
|
||||
The standard library relies on libexpat for parsing XML:
|
||||
https://github.com/libexpat/libexpat
|
||||
"""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
|
||||
|
||||
def parse(self, text: str) -> Dict[str, List[Any]]:
|
||||
# Try to find XML string within triple backticks
|
||||
# Imports are temporarily placed here to avoid issue with caching on CI
|
||||
# likely if you're reading this you can move them to the top of the file
|
||||
if self.parser == "defusedxml":
|
||||
try:
|
||||
from defusedxml import ElementTree as DET # type: ignore
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"defusedxml is not installed. "
|
||||
"Please install it to use the defusedxml parser."
|
||||
"You can install it with `pip install defusedxml`"
|
||||
"See https://github.com/tiran/defusedxml for more details"
|
||||
)
|
||||
_ET = DET # Use the defusedxml parser
|
||||
else:
|
||||
_ET = ET # Use the standard library parser
|
||||
|
||||
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
|
||||
if match is not None:
|
||||
# If match found, use the content within the backticks
|
||||
@@ -57,132 +193,19 @@ class XMLOutputParser(BaseTransformOutputParser):
|
||||
def _transform(
|
||||
self, input: Iterator[Union[str, BaseMessage]]
|
||||
) -> Iterator[AddableDict]:
|
||||
xml_start_re = re.compile(r"<[a-zA-Z:_]")
|
||||
parser = ET.XMLPullParser(["start", "end"])
|
||||
xml_started = False
|
||||
current_path: List[str] = []
|
||||
current_path_has_children = False
|
||||
buffer = ""
|
||||
streaming_parser = _StreamingParser(self.parser)
|
||||
for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
# extract text
|
||||
chunk_content = chunk.content
|
||||
if not isinstance(chunk_content, str):
|
||||
continue
|
||||
chunk = chunk_content
|
||||
# add chunk to buffer of unprocessed text
|
||||
buffer += chunk
|
||||
# if xml string hasn't started yet, continue to next chunk
|
||||
if not xml_started:
|
||||
if match := xml_start_re.search(buffer):
|
||||
# if xml string has started, remove all text before it
|
||||
buffer = buffer[match.start() :]
|
||||
xml_started = True
|
||||
else:
|
||||
continue
|
||||
# feed buffer to parser
|
||||
parser.feed(buffer)
|
||||
|
||||
buffer = ""
|
||||
# yield all events
|
||||
try:
|
||||
for event, elem in parser.read_events():
|
||||
if event == "start":
|
||||
# update current path
|
||||
current_path.append(elem.tag)
|
||||
current_path_has_children = False
|
||||
elif event == "end":
|
||||
# remove last element from current path
|
||||
#
|
||||
current_path.pop()
|
||||
# yield element
|
||||
if not current_path_has_children:
|
||||
yield nested_element(current_path, elem)
|
||||
# prevent yielding of parent element
|
||||
if current_path:
|
||||
current_path_has_children = True
|
||||
else:
|
||||
xml_started = False
|
||||
except xml.etree.ElementTree.ParseError:
|
||||
# This might be junk at the end of the XML input.
|
||||
# Let's check whether the current path is empty.
|
||||
if not current_path:
|
||||
# If it is empty, we can ignore this error.
|
||||
break
|
||||
else:
|
||||
raise
|
||||
|
||||
# close parser
|
||||
try:
|
||||
parser.close()
|
||||
except xml.etree.ElementTree.ParseError:
|
||||
# Ignore. This will ignore any incomplete XML at the end of the input
|
||||
pass
|
||||
yield from streaming_parser.parse(chunk)
|
||||
streaming_parser.close()
|
||||
|
||||
async def _atransform(
|
||||
self, input: AsyncIterator[Union[str, BaseMessage]]
|
||||
) -> AsyncIterator[AddableDict]:
|
||||
xml_start_re = re.compile(r"<[a-zA-Z:_]")
|
||||
parser = ET.XMLPullParser(["start", "end"])
|
||||
xml_started = False
|
||||
current_path: List[str] = []
|
||||
current_path_has_children = False
|
||||
buffer = ""
|
||||
streaming_parser = _StreamingParser(self.parser)
|
||||
async for chunk in input:
|
||||
if isinstance(chunk, BaseMessage):
|
||||
# extract text
|
||||
chunk_content = chunk.content
|
||||
if not isinstance(chunk_content, str):
|
||||
continue
|
||||
chunk = chunk_content
|
||||
# add chunk to buffer of unprocessed text
|
||||
buffer += chunk
|
||||
# if xml string hasn't started yet, continue to next chunk
|
||||
if not xml_started:
|
||||
if match := xml_start_re.search(buffer):
|
||||
# if xml string has started, remove all text before it
|
||||
buffer = buffer[match.start() :]
|
||||
xml_started = True
|
||||
else:
|
||||
continue
|
||||
# feed buffer to parser
|
||||
parser.feed(buffer)
|
||||
|
||||
buffer = ""
|
||||
# yield all events
|
||||
try:
|
||||
for event, elem in parser.read_events():
|
||||
if event == "start":
|
||||
# update current path
|
||||
current_path.append(elem.tag)
|
||||
current_path_has_children = False
|
||||
elif event == "end":
|
||||
# remove last element from current path
|
||||
#
|
||||
current_path.pop()
|
||||
# yield element
|
||||
if not current_path_has_children:
|
||||
yield nested_element(current_path, elem)
|
||||
# prevent yielding of parent element
|
||||
if current_path:
|
||||
current_path_has_children = True
|
||||
else:
|
||||
xml_started = False
|
||||
except xml.etree.ElementTree.ParseError:
|
||||
# This might be junk at the end of the XML input.
|
||||
# Let's check whether the current path is empty.
|
||||
if not current_path:
|
||||
# If it is empty, we can ignore this error.
|
||||
break
|
||||
else:
|
||||
raise
|
||||
|
||||
# close parser
|
||||
try:
|
||||
parser.close()
|
||||
except xml.etree.ElementTree.ParseError:
|
||||
# Ignore. This will ignore any incomplete XML at the end of the input
|
||||
pass
|
||||
for output in streaming_parser.parse(chunk):
|
||||
yield output
|
||||
streaming_parser.close()
|
||||
|
||||
def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
|
||||
"""Converts xml tree to python dictionary."""
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""Test XMLOutputParser"""
|
||||
import importlib
|
||||
from typing import AsyncIterator, Iterable
|
||||
|
||||
import pytest
|
||||
@@ -42,24 +43,12 @@ DEF_RESULT_EXPECTED = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"result",
|
||||
[
|
||||
DATA, # has no xml header
|
||||
WITH_XML_HEADER,
|
||||
IN_XML_TAGS_WITH_XML_HEADER,
|
||||
IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK,
|
||||
],
|
||||
)
|
||||
async def test_xml_output_parser(result: str) -> None:
|
||||
"""Test XMLOutputParser."""
|
||||
async def _test_parser(parser: XMLOutputParser, content: str) -> None:
|
||||
"""Test parser."""
|
||||
xml_content = parser.parse(content)
|
||||
assert DEF_RESULT_EXPECTED == xml_content
|
||||
|
||||
xml_parser = XMLOutputParser()
|
||||
|
||||
xml_result = xml_parser.parse(result)
|
||||
assert DEF_RESULT_EXPECTED == xml_result
|
||||
|
||||
assert list(xml_parser.transform(iter(result))) == [
|
||||
assert list(parser.transform(iter(content))) == [
|
||||
{"foo": [{"bar": [{"baz": None}]}]},
|
||||
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
|
||||
{"foo": [{"baz": "tag"}]},
|
||||
@@ -69,7 +58,7 @@ async def test_xml_output_parser(result: str) -> None:
|
||||
for item in iterable:
|
||||
yield item
|
||||
|
||||
chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))]
|
||||
chunks = [chunk async for chunk in parser.atransform(_as_iter(content))]
|
||||
|
||||
assert list(chunks) == [
|
||||
{"foo": [{"bar": [{"baz": None}]}]},
|
||||
@@ -78,12 +67,72 @@ async def test_xml_output_parser(result: str) -> None:
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content",
|
||||
[
|
||||
DATA, # has no xml header
|
||||
WITH_XML_HEADER,
|
||||
IN_XML_TAGS_WITH_XML_HEADER,
|
||||
IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK,
|
||||
],
|
||||
)
|
||||
async def test_xml_output_parser(content: str) -> None:
|
||||
"""Test XMLOutputParser."""
|
||||
xml_parser = XMLOutputParser(parser="xml")
|
||||
await _test_parser(xml_parser, content)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("defusedxml") is None,
|
||||
reason="defusedxml is not installed",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"content",
|
||||
[
|
||||
DATA, # has no xml header
|
||||
WITH_XML_HEADER,
|
||||
IN_XML_TAGS_WITH_XML_HEADER,
|
||||
IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK,
|
||||
],
|
||||
)
|
||||
async def test_xml_output_parser_defused(content: str) -> None:
|
||||
"""Test XMLOutputParser."""
|
||||
xml_parser = XMLOutputParser(parser="defusedxml")
|
||||
await _test_parser(xml_parser, content)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"])
|
||||
def test_xml_output_parser_fail(result: str) -> None:
|
||||
"""Test XMLOutputParser where complete output is not in XML format."""
|
||||
|
||||
xml_parser = XMLOutputParser()
|
||||
xml_parser = XMLOutputParser(parser="xml")
|
||||
|
||||
with pytest.raises(OutputParserException) as e:
|
||||
xml_parser.parse(result)
|
||||
assert "Failed to parse" in str(e)
|
||||
|
||||
|
||||
MALICIOUS_XML = """<?xml version="1.0"?>
|
||||
<!DOCTYPE lolz [<!ENTITY lol "lol"><!ELEMENT lolz (#PCDATA)>
|
||||
<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
|
||||
<!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">
|
||||
<!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">
|
||||
<!ENTITY lol4 "&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;">
|
||||
<!ENTITY lol5 "&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;">
|
||||
<!ENTITY lol6 "&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;">
|
||||
<!ENTITY lol7 "&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;">
|
||||
<!ENTITY lol8 "&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;">
|
||||
<!ENTITY lol9 "&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;">
|
||||
]>
|
||||
<lolz>&lol9;</lolz>"""
|
||||
|
||||
|
||||
async def tests_billion_laughs_attack() -> None:
|
||||
# Testing with standard XML parser since it's safe to use in
|
||||
# newer versions of Python
|
||||
parser = XMLOutputParser(parser="xml")
|
||||
with pytest.raises(OutputParserException):
|
||||
parser.parse(MALICIOUS_XML)
|
||||
|
||||
with pytest.raises(OutputParserException):
|
||||
await parser.aparse(MALICIOUS_XML)
|
||||
|
Reference in New Issue
Block a user