diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index 40d72953d48..704c67b8e79 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -1,7 +1,8 @@ import re import xml import xml.etree.ElementTree as ET -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union +from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union +from xml.etree.ElementTree import TreeBuilder from langchain_core.exceptions import OutputParserException from langchain_core.messages import BaseMessage @@ -24,6 +25,105 @@ Here are the output tags: ```""" # noqa: E501 +class _StreamingParser: + """Streaming parser for XML. + + This implementation is pulled into a class to avoid implementation + drift between transform and atransform of the XMLOutputParser. + """ + + def __init__(self, parser: Literal["defusedxml", "xml"]) -> None: + """Initialize the streaming parser. + + Args: + parser: Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'. + See documentation in XMLOutputParser for more information. + """ + if parser == "defusedxml": + try: + from defusedxml import ElementTree as DET # type: ignore + except ImportError: + raise ImportError( + "defusedxml is not installed. " + "Please install it to use the defusedxml parser." + "You can install it with `pip install defusedxml` " + ) + _parser = DET.DefusedXMLParser(target=TreeBuilder()) + else: + _parser = None + self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser) + self.xml_start_re = re.compile(r"<[a-zA-Z:_]") + self.current_path: List[str] = [] + self.current_path_has_children = False + self.buffer = "" + self.xml_started = False + + def parse(self, chunk: Union[str, BaseMessage]) -> Iterator[AddableDict]: + """Parse a chunk of text. + + Args: + chunk: A chunk of text to parse. This can be a string or a BaseMessage. + + Yields: + AddableDict: A dictionary representing the parsed XML element. + """ + if isinstance(chunk, BaseMessage): + # extract text + chunk_content = chunk.content + if not isinstance(chunk_content, str): + # ignore non-string messages (e.g., function calls) + return + chunk = chunk_content + # add chunk to buffer of unprocessed text + self.buffer += chunk + # if xml string hasn't started yet, continue to next chunk + if not self.xml_started: + if match := self.xml_start_re.search(self.buffer): + # if xml string has started, remove all text before it + self.buffer = self.buffer[match.start() :] + self.xml_started = True + else: + return + # feed buffer to parser + self.pull_parser.feed(self.buffer) + self.buffer = "" + # yield all events + try: + for event, elem in self.pull_parser.read_events(): + if event == "start": + # update current path + self.current_path.append(elem.tag) + self.current_path_has_children = False + elif event == "end": + # remove last element from current path + # + self.current_path.pop() + # yield element + if not self.current_path_has_children: + yield nested_element(self.current_path, elem) + # prevent yielding of parent element + if self.current_path: + self.current_path_has_children = True + else: + self.xml_started = False + except xml.etree.ElementTree.ParseError: + # This might be junk at the end of the XML input. + # Let's check whether the current path is empty. + if not self.current_path: + # If it is empty, we can ignore this error. + return + else: + raise + + def close(self) -> None: + """Close the parser.""" + try: + self.pull_parser.close() + except xml.etree.ElementTree.ParseError: + # Ignore. This will ignore any incomplete XML at the end of the input + pass + + class XMLOutputParser(BaseTransformOutputParser): """Parse an output using xml format.""" @@ -31,12 +131,48 @@ class XMLOutputParser(BaseTransformOutputParser): encoding_matcher: re.Pattern = re.compile( r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL ) + parser: Literal["defusedxml", "xml"] = "defusedxml" + """Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'. + + * 'defusedxml' is the default parser and is used to prevent XML vulnerabilities + present in some distributions of Python's standard library xml. + `defusedxml` is a wrapper around the standard library parser that + sets up the parser with secure defaults. + * 'xml' is the standard library parser. + + Use `xml` only if you are sure that your distribution of the standard library + is not vulnerable to XML vulnerabilities. + + Please review the following resources for more information: + + * https://docs.python.org/3/library/xml.html#xml-vulnerabilities + * https://github.com/tiran/defusedxml + + The standard library relies on libexpat for parsing XML: + https://github.com/libexpat/libexpat + """ def get_format_instructions(self) -> str: return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) def parse(self, text: str) -> Dict[str, List[Any]]: # Try to find XML string within triple backticks + # Imports are temporarily placed here to avoid issue with caching on CI + # likely if you're reading this you can move them to the top of the file + if self.parser == "defusedxml": + try: + from defusedxml import ElementTree as DET # type: ignore + except ImportError: + raise ImportError( + "defusedxml is not installed. " + "Please install it to use the defusedxml parser." + "You can install it with `pip install defusedxml`" + "See https://github.com/tiran/defusedxml for more details" + ) + _ET = DET # Use the defusedxml parser + else: + _ET = ET # Use the standard library parser + match = re.search(r"```(xml)?(.*)```", text, re.DOTALL) if match is not None: # If match found, use the content within the backticks @@ -57,132 +193,19 @@ class XMLOutputParser(BaseTransformOutputParser): def _transform( self, input: Iterator[Union[str, BaseMessage]] ) -> Iterator[AddableDict]: - xml_start_re = re.compile(r"<[a-zA-Z:_]") - parser = ET.XMLPullParser(["start", "end"]) - xml_started = False - current_path: List[str] = [] - current_path_has_children = False - buffer = "" + streaming_parser = _StreamingParser(self.parser) for chunk in input: - if isinstance(chunk, BaseMessage): - # extract text - chunk_content = chunk.content - if not isinstance(chunk_content, str): - continue - chunk = chunk_content - # add chunk to buffer of unprocessed text - buffer += chunk - # if xml string hasn't started yet, continue to next chunk - if not xml_started: - if match := xml_start_re.search(buffer): - # if xml string has started, remove all text before it - buffer = buffer[match.start() :] - xml_started = True - else: - continue - # feed buffer to parser - parser.feed(buffer) - - buffer = "" - # yield all events - try: - for event, elem in parser.read_events(): - if event == "start": - # update current path - current_path.append(elem.tag) - current_path_has_children = False - elif event == "end": - # remove last element from current path - # - current_path.pop() - # yield element - if not current_path_has_children: - yield nested_element(current_path, elem) - # prevent yielding of parent element - if current_path: - current_path_has_children = True - else: - xml_started = False - except xml.etree.ElementTree.ParseError: - # This might be junk at the end of the XML input. - # Let's check whether the current path is empty. - if not current_path: - # If it is empty, we can ignore this error. - break - else: - raise - - # close parser - try: - parser.close() - except xml.etree.ElementTree.ParseError: - # Ignore. This will ignore any incomplete XML at the end of the input - pass + yield from streaming_parser.parse(chunk) + streaming_parser.close() async def _atransform( self, input: AsyncIterator[Union[str, BaseMessage]] ) -> AsyncIterator[AddableDict]: - xml_start_re = re.compile(r"<[a-zA-Z:_]") - parser = ET.XMLPullParser(["start", "end"]) - xml_started = False - current_path: List[str] = [] - current_path_has_children = False - buffer = "" + streaming_parser = _StreamingParser(self.parser) async for chunk in input: - if isinstance(chunk, BaseMessage): - # extract text - chunk_content = chunk.content - if not isinstance(chunk_content, str): - continue - chunk = chunk_content - # add chunk to buffer of unprocessed text - buffer += chunk - # if xml string hasn't started yet, continue to next chunk - if not xml_started: - if match := xml_start_re.search(buffer): - # if xml string has started, remove all text before it - buffer = buffer[match.start() :] - xml_started = True - else: - continue - # feed buffer to parser - parser.feed(buffer) - - buffer = "" - # yield all events - try: - for event, elem in parser.read_events(): - if event == "start": - # update current path - current_path.append(elem.tag) - current_path_has_children = False - elif event == "end": - # remove last element from current path - # - current_path.pop() - # yield element - if not current_path_has_children: - yield nested_element(current_path, elem) - # prevent yielding of parent element - if current_path: - current_path_has_children = True - else: - xml_started = False - except xml.etree.ElementTree.ParseError: - # This might be junk at the end of the XML input. - # Let's check whether the current path is empty. - if not current_path: - # If it is empty, we can ignore this error. - break - else: - raise - - # close parser - try: - parser.close() - except xml.etree.ElementTree.ParseError: - # Ignore. This will ignore any incomplete XML at the end of the input - pass + for output in streaming_parser.parse(chunk): + yield output + streaming_parser.close() def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]: """Converts xml tree to python dictionary.""" diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py index 48e7372b98a..c30d09ea1b1 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py @@ -1,4 +1,5 @@ """Test XMLOutputParser""" +import importlib from typing import AsyncIterator, Iterable import pytest @@ -42,24 +43,12 @@ DEF_RESULT_EXPECTED = { } -@pytest.mark.parametrize( - "result", - [ - DATA, # has no xml header - WITH_XML_HEADER, - IN_XML_TAGS_WITH_XML_HEADER, - IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK, - ], -) -async def test_xml_output_parser(result: str) -> None: - """Test XMLOutputParser.""" +async def _test_parser(parser: XMLOutputParser, content: str) -> None: + """Test parser.""" + xml_content = parser.parse(content) + assert DEF_RESULT_EXPECTED == xml_content - xml_parser = XMLOutputParser() - - xml_result = xml_parser.parse(result) - assert DEF_RESULT_EXPECTED == xml_result - - assert list(xml_parser.transform(iter(result))) == [ + assert list(parser.transform(iter(content))) == [ {"foo": [{"bar": [{"baz": None}]}]}, {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, {"foo": [{"baz": "tag"}]}, @@ -69,7 +58,7 @@ async def test_xml_output_parser(result: str) -> None: for item in iterable: yield item - chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))] + chunks = [chunk async for chunk in parser.atransform(_as_iter(content))] assert list(chunks) == [ {"foo": [{"bar": [{"baz": None}]}]}, @@ -78,12 +67,72 @@ async def test_xml_output_parser(result: str) -> None: ] +@pytest.mark.parametrize( + "content", + [ + DATA, # has no xml header + WITH_XML_HEADER, + IN_XML_TAGS_WITH_XML_HEADER, + IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK, + ], +) +async def test_xml_output_parser(content: str) -> None: + """Test XMLOutputParser.""" + xml_parser = XMLOutputParser(parser="xml") + await _test_parser(xml_parser, content) + + +@pytest.mark.skipif( + importlib.util.find_spec("defusedxml") is None, + reason="defusedxml is not installed", +) +@pytest.mark.parametrize( + "content", + [ + DATA, # has no xml header + WITH_XML_HEADER, + IN_XML_TAGS_WITH_XML_HEADER, + IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK, + ], +) +async def test_xml_output_parser_defused(content: str) -> None: + """Test XMLOutputParser.""" + xml_parser = XMLOutputParser(parser="defusedxml") + await _test_parser(xml_parser, content) + + @pytest.mark.parametrize("result", ["foo>", " None: """Test XMLOutputParser where complete output is not in XML format.""" - xml_parser = XMLOutputParser() + xml_parser = XMLOutputParser(parser="xml") with pytest.raises(OutputParserException) as e: xml_parser.parse(result) assert "Failed to parse" in str(e) + + +MALICIOUS_XML = """ + + + + + + + + + + +]> +&lol9;""" + + +async def tests_billion_laughs_attack() -> None: + # Testing with standard XML parser since it's safe to use in + # newer versions of Python + parser = XMLOutputParser(parser="xml") + with pytest.raises(OutputParserException): + parser.parse(MALICIOUS_XML) + + with pytest.raises(OutputParserException): + await parser.aparse(MALICIOUS_XML)