diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py
index 40d72953d48..704c67b8e79 100644
--- a/libs/core/langchain_core/output_parsers/xml.py
+++ b/libs/core/langchain_core/output_parsers/xml.py
@@ -1,7 +1,8 @@
import re
import xml
import xml.etree.ElementTree as ET
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union
+from xml.etree.ElementTree import TreeBuilder
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
@@ -24,6 +25,105 @@ Here are the output tags:
```""" # noqa: E501
+class _StreamingParser:
+ """Streaming parser for XML.
+
+ This implementation is pulled into a class to avoid implementation
+ drift between transform and atransform of the XMLOutputParser.
+ """
+
+ def __init__(self, parser: Literal["defusedxml", "xml"]) -> None:
+ """Initialize the streaming parser.
+
+ Args:
+ parser: Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'.
+ See documentation in XMLOutputParser for more information.
+ """
+ if parser == "defusedxml":
+ try:
+ from defusedxml import ElementTree as DET # type: ignore
+ except ImportError:
+ raise ImportError(
+ "defusedxml is not installed. "
+ "Please install it to use the defusedxml parser."
+ "You can install it with `pip install defusedxml` "
+ )
+ _parser = DET.DefusedXMLParser(target=TreeBuilder())
+ else:
+ _parser = None
+ self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
+ self.xml_start_re = re.compile(r"<[a-zA-Z:_]")
+ self.current_path: List[str] = []
+ self.current_path_has_children = False
+ self.buffer = ""
+ self.xml_started = False
+
+ def parse(self, chunk: Union[str, BaseMessage]) -> Iterator[AddableDict]:
+ """Parse a chunk of text.
+
+ Args:
+ chunk: A chunk of text to parse. This can be a string or a BaseMessage.
+
+ Yields:
+ AddableDict: A dictionary representing the parsed XML element.
+ """
+ if isinstance(chunk, BaseMessage):
+ # extract text
+ chunk_content = chunk.content
+ if not isinstance(chunk_content, str):
+ # ignore non-string messages (e.g., function calls)
+ return
+ chunk = chunk_content
+ # add chunk to buffer of unprocessed text
+ self.buffer += chunk
+ # if xml string hasn't started yet, continue to next chunk
+ if not self.xml_started:
+ if match := self.xml_start_re.search(self.buffer):
+ # if xml string has started, remove all text before it
+ self.buffer = self.buffer[match.start() :]
+ self.xml_started = True
+ else:
+ return
+ # feed buffer to parser
+ self.pull_parser.feed(self.buffer)
+ self.buffer = ""
+ # yield all events
+ try:
+ for event, elem in self.pull_parser.read_events():
+ if event == "start":
+ # update current path
+ self.current_path.append(elem.tag)
+ self.current_path_has_children = False
+ elif event == "end":
+ # remove last element from current path
+ #
+ self.current_path.pop()
+ # yield element
+ if not self.current_path_has_children:
+ yield nested_element(self.current_path, elem)
+ # prevent yielding of parent element
+ if self.current_path:
+ self.current_path_has_children = True
+ else:
+ self.xml_started = False
+ except xml.etree.ElementTree.ParseError:
+ # This might be junk at the end of the XML input.
+ # Let's check whether the current path is empty.
+ if not self.current_path:
+ # If it is empty, we can ignore this error.
+ return
+ else:
+ raise
+
+ def close(self) -> None:
+ """Close the parser."""
+ try:
+ self.pull_parser.close()
+ except xml.etree.ElementTree.ParseError:
+ # Ignore. This will ignore any incomplete XML at the end of the input
+ pass
+
+
class XMLOutputParser(BaseTransformOutputParser):
"""Parse an output using xml format."""
@@ -31,12 +131,48 @@ class XMLOutputParser(BaseTransformOutputParser):
encoding_matcher: re.Pattern = re.compile(
r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
)
+ parser: Literal["defusedxml", "xml"] = "defusedxml"
+ """Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'.
+
+ * 'defusedxml' is the default parser and is used to prevent XML vulnerabilities
+ present in some distributions of Python's standard library xml.
+ `defusedxml` is a wrapper around the standard library parser that
+ sets up the parser with secure defaults.
+ * 'xml' is the standard library parser.
+
+ Use `xml` only if you are sure that your distribution of the standard library
+ is not vulnerable to XML vulnerabilities.
+
+ Please review the following resources for more information:
+
+ * https://docs.python.org/3/library/xml.html#xml-vulnerabilities
+ * https://github.com/tiran/defusedxml
+
+ The standard library relies on libexpat for parsing XML:
+ https://github.com/libexpat/libexpat
+ """
def get_format_instructions(self) -> str:
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, List[Any]]:
# Try to find XML string within triple backticks
+ # Imports are temporarily placed here to avoid issue with caching on CI
+ # likely if you're reading this you can move them to the top of the file
+ if self.parser == "defusedxml":
+ try:
+ from defusedxml import ElementTree as DET # type: ignore
+ except ImportError:
+ raise ImportError(
+ "defusedxml is not installed. "
+ "Please install it to use the defusedxml parser."
+ "You can install it with `pip install defusedxml`"
+ "See https://github.com/tiran/defusedxml for more details"
+ )
+ _ET = DET # Use the defusedxml parser
+ else:
+ _ET = ET # Use the standard library parser
+
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None:
# If match found, use the content within the backticks
@@ -57,132 +193,19 @@ class XMLOutputParser(BaseTransformOutputParser):
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
- xml_start_re = re.compile(r"<[a-zA-Z:_]")
- parser = ET.XMLPullParser(["start", "end"])
- xml_started = False
- current_path: List[str] = []
- current_path_has_children = False
- buffer = ""
+ streaming_parser = _StreamingParser(self.parser)
for chunk in input:
- if isinstance(chunk, BaseMessage):
- # extract text
- chunk_content = chunk.content
- if not isinstance(chunk_content, str):
- continue
- chunk = chunk_content
- # add chunk to buffer of unprocessed text
- buffer += chunk
- # if xml string hasn't started yet, continue to next chunk
- if not xml_started:
- if match := xml_start_re.search(buffer):
- # if xml string has started, remove all text before it
- buffer = buffer[match.start() :]
- xml_started = True
- else:
- continue
- # feed buffer to parser
- parser.feed(buffer)
-
- buffer = ""
- # yield all events
- try:
- for event, elem in parser.read_events():
- if event == "start":
- # update current path
- current_path.append(elem.tag)
- current_path_has_children = False
- elif event == "end":
- # remove last element from current path
- #
- current_path.pop()
- # yield element
- if not current_path_has_children:
- yield nested_element(current_path, elem)
- # prevent yielding of parent element
- if current_path:
- current_path_has_children = True
- else:
- xml_started = False
- except xml.etree.ElementTree.ParseError:
- # This might be junk at the end of the XML input.
- # Let's check whether the current path is empty.
- if not current_path:
- # If it is empty, we can ignore this error.
- break
- else:
- raise
-
- # close parser
- try:
- parser.close()
- except xml.etree.ElementTree.ParseError:
- # Ignore. This will ignore any incomplete XML at the end of the input
- pass
+ yield from streaming_parser.parse(chunk)
+ streaming_parser.close()
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]:
- xml_start_re = re.compile(r"<[a-zA-Z:_]")
- parser = ET.XMLPullParser(["start", "end"])
- xml_started = False
- current_path: List[str] = []
- current_path_has_children = False
- buffer = ""
+ streaming_parser = _StreamingParser(self.parser)
async for chunk in input:
- if isinstance(chunk, BaseMessage):
- # extract text
- chunk_content = chunk.content
- if not isinstance(chunk_content, str):
- continue
- chunk = chunk_content
- # add chunk to buffer of unprocessed text
- buffer += chunk
- # if xml string hasn't started yet, continue to next chunk
- if not xml_started:
- if match := xml_start_re.search(buffer):
- # if xml string has started, remove all text before it
- buffer = buffer[match.start() :]
- xml_started = True
- else:
- continue
- # feed buffer to parser
- parser.feed(buffer)
-
- buffer = ""
- # yield all events
- try:
- for event, elem in parser.read_events():
- if event == "start":
- # update current path
- current_path.append(elem.tag)
- current_path_has_children = False
- elif event == "end":
- # remove last element from current path
- #
- current_path.pop()
- # yield element
- if not current_path_has_children:
- yield nested_element(current_path, elem)
- # prevent yielding of parent element
- if current_path:
- current_path_has_children = True
- else:
- xml_started = False
- except xml.etree.ElementTree.ParseError:
- # This might be junk at the end of the XML input.
- # Let's check whether the current path is empty.
- if not current_path:
- # If it is empty, we can ignore this error.
- break
- else:
- raise
-
- # close parser
- try:
- parser.close()
- except xml.etree.ElementTree.ParseError:
- # Ignore. This will ignore any incomplete XML at the end of the input
- pass
+ for output in streaming_parser.parse(chunk):
+ yield output
+ streaming_parser.close()
def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
"""Converts xml tree to python dictionary."""
diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
index 48e7372b98a..c30d09ea1b1 100644
--- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
+++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
@@ -1,4 +1,5 @@
"""Test XMLOutputParser"""
+import importlib
from typing import AsyncIterator, Iterable
import pytest
@@ -42,24 +43,12 @@ DEF_RESULT_EXPECTED = {
}
-@pytest.mark.parametrize(
- "result",
- [
- DATA, # has no xml header
- WITH_XML_HEADER,
- IN_XML_TAGS_WITH_XML_HEADER,
- IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK,
- ],
-)
-async def test_xml_output_parser(result: str) -> None:
- """Test XMLOutputParser."""
+async def _test_parser(parser: XMLOutputParser, content: str) -> None:
+ """Test parser."""
+ xml_content = parser.parse(content)
+ assert DEF_RESULT_EXPECTED == xml_content
- xml_parser = XMLOutputParser()
-
- xml_result = xml_parser.parse(result)
- assert DEF_RESULT_EXPECTED == xml_result
-
- assert list(xml_parser.transform(iter(result))) == [
+ assert list(parser.transform(iter(content))) == [
{"foo": [{"bar": [{"baz": None}]}]},
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
{"foo": [{"baz": "tag"}]},
@@ -69,7 +58,7 @@ async def test_xml_output_parser(result: str) -> None:
for item in iterable:
yield item
- chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))]
+ chunks = [chunk async for chunk in parser.atransform(_as_iter(content))]
assert list(chunks) == [
{"foo": [{"bar": [{"baz": None}]}]},
@@ -78,12 +67,72 @@ async def test_xml_output_parser(result: str) -> None:
]
+@pytest.mark.parametrize(
+ "content",
+ [
+ DATA, # has no xml header
+ WITH_XML_HEADER,
+ IN_XML_TAGS_WITH_XML_HEADER,
+ IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK,
+ ],
+)
+async def test_xml_output_parser(content: str) -> None:
+ """Test XMLOutputParser."""
+ xml_parser = XMLOutputParser(parser="xml")
+ await _test_parser(xml_parser, content)
+
+
+@pytest.mark.skipif(
+ importlib.util.find_spec("defusedxml") is None,
+ reason="defusedxml is not installed",
+)
+@pytest.mark.parametrize(
+ "content",
+ [
+ DATA, # has no xml header
+ WITH_XML_HEADER,
+ IN_XML_TAGS_WITH_XML_HEADER,
+ IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK,
+ ],
+)
+async def test_xml_output_parser_defused(content: str) -> None:
+ """Test XMLOutputParser."""
+ xml_parser = XMLOutputParser(parser="defusedxml")
+ await _test_parser(xml_parser, content)
+
+
@pytest.mark.parametrize("result", ["foo>", " None:
"""Test XMLOutputParser where complete output is not in XML format."""
- xml_parser = XMLOutputParser()
+ xml_parser = XMLOutputParser(parser="xml")
with pytest.raises(OutputParserException) as e:
xml_parser.parse(result)
assert "Failed to parse" in str(e)
+
+
+MALICIOUS_XML = """
+
+
+
+
+
+
+
+
+
+
+]>
+&lol9;"""
+
+
+async def tests_billion_laughs_attack() -> None:
+ # Testing with standard XML parser since it's safe to use in
+ # newer versions of Python
+ parser = XMLOutputParser(parser="xml")
+ with pytest.raises(OutputParserException):
+ parser.parse(MALICIOUS_XML)
+
+ with pytest.raises(OutputParserException):
+ await parser.aparse(MALICIOUS_XML)