From e8339b1d831199b6c67182e54623999c96fc3b13 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 27 Mar 2024 12:41:52 -0400 Subject: [PATCH] core[patch]: Patch XML vulnerability in XMLOutputParser (CVE-2024-1455) (#19653) Patch potential XML vulnerability CVE-2024-1455 This patches a potential XML vulnerability in the XMLOutputParser in langchain-core. The vulnerability in some situations could lead to a denial of service attack. At risk are users that: 1) Running older distributions of python that have older version of libexpat 2) Are using XMLOutputParser with an agent 3) Accept inputs from untrusted sources with this agent (e.g., endpoint on the web that allows an untrusted user to interact wiith the parser) --- .../core/langchain_core/output_parsers/xml.py | 265 ++++++++++-------- .../output_parsers/test_xml_parser.py | 87 ++++-- 2 files changed, 212 insertions(+), 140 deletions(-) diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index 40d72953d48..704c67b8e79 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -1,7 +1,8 @@ import re import xml import xml.etree.ElementTree as ET -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union +from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union +from xml.etree.ElementTree import TreeBuilder from langchain_core.exceptions import OutputParserException from langchain_core.messages import BaseMessage @@ -24,6 +25,105 @@ Here are the output tags: ```""" # noqa: E501 +class _StreamingParser: + """Streaming parser for XML. + + This implementation is pulled into a class to avoid implementation + drift between transform and atransform of the XMLOutputParser. + """ + + def __init__(self, parser: Literal["defusedxml", "xml"]) -> None: + """Initialize the streaming parser. + + Args: + parser: Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'. + See documentation in XMLOutputParser for more information. + """ + if parser == "defusedxml": + try: + from defusedxml import ElementTree as DET # type: ignore + except ImportError: + raise ImportError( + "defusedxml is not installed. " + "Please install it to use the defusedxml parser." + "You can install it with `pip install defusedxml` " + ) + _parser = DET.DefusedXMLParser(target=TreeBuilder()) + else: + _parser = None + self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser) + self.xml_start_re = re.compile(r"<[a-zA-Z:_]") + self.current_path: List[str] = [] + self.current_path_has_children = False + self.buffer = "" + self.xml_started = False + + def parse(self, chunk: Union[str, BaseMessage]) -> Iterator[AddableDict]: + """Parse a chunk of text. + + Args: + chunk: A chunk of text to parse. This can be a string or a BaseMessage. + + Yields: + AddableDict: A dictionary representing the parsed XML element. + """ + if isinstance(chunk, BaseMessage): + # extract text + chunk_content = chunk.content + if not isinstance(chunk_content, str): + # ignore non-string messages (e.g., function calls) + return + chunk = chunk_content + # add chunk to buffer of unprocessed text + self.buffer += chunk + # if xml string hasn't started yet, continue to next chunk + if not self.xml_started: + if match := self.xml_start_re.search(self.buffer): + # if xml string has started, remove all text before it + self.buffer = self.buffer[match.start() :] + self.xml_started = True + else: + return + # feed buffer to parser + self.pull_parser.feed(self.buffer) + self.buffer = "" + # yield all events + try: + for event, elem in self.pull_parser.read_events(): + if event == "start": + # update current path + self.current_path.append(elem.tag) + self.current_path_has_children = False + elif event == "end": + # remove last element from current path + # + self.current_path.pop() + # yield element + if not self.current_path_has_children: + yield nested_element(self.current_path, elem) + # prevent yielding of parent element + if self.current_path: + self.current_path_has_children = True + else: + self.xml_started = False + except xml.etree.ElementTree.ParseError: + # This might be junk at the end of the XML input. + # Let's check whether the current path is empty. + if not self.current_path: + # If it is empty, we can ignore this error. + return + else: + raise + + def close(self) -> None: + """Close the parser.""" + try: + self.pull_parser.close() + except xml.etree.ElementTree.ParseError: + # Ignore. This will ignore any incomplete XML at the end of the input + pass + + class XMLOutputParser(BaseTransformOutputParser): """Parse an output using xml format.""" @@ -31,12 +131,48 @@ class XMLOutputParser(BaseTransformOutputParser): encoding_matcher: re.Pattern = re.compile( r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL ) + parser: Literal["defusedxml", "xml"] = "defusedxml" + """Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'. + + * 'defusedxml' is the default parser and is used to prevent XML vulnerabilities + present in some distributions of Python's standard library xml. + `defusedxml` is a wrapper around the standard library parser that + sets up the parser with secure defaults. + * 'xml' is the standard library parser. + + Use `xml` only if you are sure that your distribution of the standard library + is not vulnerable to XML vulnerabilities. + + Please review the following resources for more information: + + * https://docs.python.org/3/library/xml.html#xml-vulnerabilities + * https://github.com/tiran/defusedxml + + The standard library relies on libexpat for parsing XML: + https://github.com/libexpat/libexpat + """ def get_format_instructions(self) -> str: return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) def parse(self, text: str) -> Dict[str, List[Any]]: # Try to find XML string within triple backticks + # Imports are temporarily placed here to avoid issue with caching on CI + # likely if you're reading this you can move them to the top of the file + if self.parser == "defusedxml": + try: + from defusedxml import ElementTree as DET # type: ignore + except ImportError: + raise ImportError( + "defusedxml is not installed. " + "Please install it to use the defusedxml parser." + "You can install it with `pip install defusedxml`" + "See https://github.com/tiran/defusedxml for more details" + ) + _ET = DET # Use the defusedxml parser + else: + _ET = ET # Use the standard library parser + match = re.search(r"```(xml)?(.*)```", text, re.DOTALL) if match is not None: # If match found, use the content within the backticks @@ -57,132 +193,19 @@ class XMLOutputParser(BaseTransformOutputParser): def _transform( self, input: Iterator[Union[str, BaseMessage]] ) -> Iterator[AddableDict]: - xml_start_re = re.compile(r"<[a-zA-Z:_]") - parser = ET.XMLPullParser(["start", "end"]) - xml_started = False - current_path: List[str] = [] - current_path_has_children = False - buffer = "" + streaming_parser = _StreamingParser(self.parser) for chunk in input: - if isinstance(chunk, BaseMessage): - # extract text - chunk_content = chunk.content - if not isinstance(chunk_content, str): - continue - chunk = chunk_content - # add chunk to buffer of unprocessed text - buffer += chunk - # if xml string hasn't started yet, continue to next chunk - if not xml_started: - if match := xml_start_re.search(buffer): - # if xml string has started, remove all text before it - buffer = buffer[match.start() :] - xml_started = True - else: - continue - # feed buffer to parser - parser.feed(buffer) - - buffer = "" - # yield all events - try: - for event, elem in parser.read_events(): - if event == "start": - # update current path - current_path.append(elem.tag) - current_path_has_children = False - elif event == "end": - # remove last element from current path - # - current_path.pop() - # yield element - if not current_path_has_children: - yield nested_element(current_path, elem) - # prevent yielding of parent element - if current_path: - current_path_has_children = True - else: - xml_started = False - except xml.etree.ElementTree.ParseError: - # This might be junk at the end of the XML input. - # Let's check whether the current path is empty. - if not current_path: - # If it is empty, we can ignore this error. - break - else: - raise - - # close parser - try: - parser.close() - except xml.etree.ElementTree.ParseError: - # Ignore. This will ignore any incomplete XML at the end of the input - pass + yield from streaming_parser.parse(chunk) + streaming_parser.close() async def _atransform( self, input: AsyncIterator[Union[str, BaseMessage]] ) -> AsyncIterator[AddableDict]: - xml_start_re = re.compile(r"<[a-zA-Z:_]") - parser = ET.XMLPullParser(["start", "end"]) - xml_started = False - current_path: List[str] = [] - current_path_has_children = False - buffer = "" + streaming_parser = _StreamingParser(self.parser) async for chunk in input: - if isinstance(chunk, BaseMessage): - # extract text - chunk_content = chunk.content - if not isinstance(chunk_content, str): - continue - chunk = chunk_content - # add chunk to buffer of unprocessed text - buffer += chunk - # if xml string hasn't started yet, continue to next chunk - if not xml_started: - if match := xml_start_re.search(buffer): - # if xml string has started, remove all text before it - buffer = buffer[match.start() :] - xml_started = True - else: - continue - # feed buffer to parser - parser.feed(buffer) - - buffer = "" - # yield all events - try: - for event, elem in parser.read_events(): - if event == "start": - # update current path - current_path.append(elem.tag) - current_path_has_children = False - elif event == "end": - # remove last element from current path - # - current_path.pop() - # yield element - if not current_path_has_children: - yield nested_element(current_path, elem) - # prevent yielding of parent element - if current_path: - current_path_has_children = True - else: - xml_started = False - except xml.etree.ElementTree.ParseError: - # This might be junk at the end of the XML input. - # Let's check whether the current path is empty. - if not current_path: - # If it is empty, we can ignore this error. - break - else: - raise - - # close parser - try: - parser.close() - except xml.etree.ElementTree.ParseError: - # Ignore. This will ignore any incomplete XML at the end of the input - pass + for output in streaming_parser.parse(chunk): + yield output + streaming_parser.close() def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]: """Converts xml tree to python dictionary.""" diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py index 48e7372b98a..c30d09ea1b1 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py @@ -1,4 +1,5 @@ """Test XMLOutputParser""" +import importlib from typing import AsyncIterator, Iterable import pytest @@ -42,24 +43,12 @@ DEF_RESULT_EXPECTED = { } -@pytest.mark.parametrize( - "result", - [ - DATA, # has no xml header - WITH_XML_HEADER, - IN_XML_TAGS_WITH_XML_HEADER, - IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK, - ], -) -async def test_xml_output_parser(result: str) -> None: - """Test XMLOutputParser.""" +async def _test_parser(parser: XMLOutputParser, content: str) -> None: + """Test parser.""" + xml_content = parser.parse(content) + assert DEF_RESULT_EXPECTED == xml_content - xml_parser = XMLOutputParser() - - xml_result = xml_parser.parse(result) - assert DEF_RESULT_EXPECTED == xml_result - - assert list(xml_parser.transform(iter(result))) == [ + assert list(parser.transform(iter(content))) == [ {"foo": [{"bar": [{"baz": None}]}]}, {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, {"foo": [{"baz": "tag"}]}, @@ -69,7 +58,7 @@ async def test_xml_output_parser(result: str) -> None: for item in iterable: yield item - chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))] + chunks = [chunk async for chunk in parser.atransform(_as_iter(content))] assert list(chunks) == [ {"foo": [{"bar": [{"baz": None}]}]}, @@ -78,12 +67,72 @@ async def test_xml_output_parser(result: str) -> None: ] +@pytest.mark.parametrize( + "content", + [ + DATA, # has no xml header + WITH_XML_HEADER, + IN_XML_TAGS_WITH_XML_HEADER, + IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK, + ], +) +async def test_xml_output_parser(content: str) -> None: + """Test XMLOutputParser.""" + xml_parser = XMLOutputParser(parser="xml") + await _test_parser(xml_parser, content) + + +@pytest.mark.skipif( + importlib.util.find_spec("defusedxml") is None, + reason="defusedxml is not installed", +) +@pytest.mark.parametrize( + "content", + [ + DATA, # has no xml header + WITH_XML_HEADER, + IN_XML_TAGS_WITH_XML_HEADER, + IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK, + ], +) +async def test_xml_output_parser_defused(content: str) -> None: + """Test XMLOutputParser.""" + xml_parser = XMLOutputParser(parser="defusedxml") + await _test_parser(xml_parser, content) + + @pytest.mark.parametrize("result", ["foo>", " None: """Test XMLOutputParser where complete output is not in XML format.""" - xml_parser = XMLOutputParser() + xml_parser = XMLOutputParser(parser="xml") with pytest.raises(OutputParserException) as e: xml_parser.parse(result) assert "Failed to parse" in str(e) + + +MALICIOUS_XML = """ + + + + + + + + + + +]> +&lol9;""" + + +async def tests_billion_laughs_attack() -> None: + # Testing with standard XML parser since it's safe to use in + # newer versions of Python + parser = XMLOutputParser(parser="xml") + with pytest.raises(OutputParserException): + parser.parse(MALICIOUS_XML) + + with pytest.raises(OutputParserException): + await parser.aparse(MALICIOUS_XML)