diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py
index f74bd4c1050..9871ecee91d 100644
--- a/libs/core/langchain_core/output_parsers/xml.py
+++ b/libs/core/langchain_core/output_parsers/xml.py
@@ -1,6 +1,7 @@
import re
-import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from xml.etree import ElementTree as ET
+from xml.etree.ElementTree import TreeBuilder
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
@@ -35,6 +36,10 @@ class XMLOutputParser(BaseTransformOutputParser):
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, List[Any]]:
+ # Imports are temporarily placed here to avoid issue with caching on CI
+ # likely if you're reading this you can move them to the top of the file
+ from defusedxml import ElementTree as DET # type: ignore[import]
+
# Try to find XML string within triple backticks
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None:
@@ -46,18 +51,24 @@ class XMLOutputParser(BaseTransformOutputParser):
text = text.strip()
try:
- root = ET.fromstring(text)
+ root = DET.fromstring(text)
return self._root_to_dict(root)
- except ET.ParseError as e:
+ except (DET.ParseError, DET.EntitiesForbidden) as e:
msg = f"Failed to parse XML format from completion {text}. Got: {e}"
raise OutputParserException(msg, llm_output=text) from e
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
+ # Imports are temporarily placed here to avoid issue with caching on CI
+ # likely if you're reading this you can move them to the top of the file
+ from defusedxml.ElementTree import DefusedXMLParser # type: ignore[import]
+
+ parser = ET.XMLPullParser(
+ ["start", "end"], _parser=DefusedXMLParser(target=TreeBuilder())
+ )
xml_start_re = re.compile(r"<[a-zA-Z:_]")
- parser = ET.XMLPullParser(["start", "end"])
xml_started = False
current_path: List[str] = []
current_path_has_children = False
@@ -83,6 +94,61 @@ class XMLOutputParser(BaseTransformOutputParser):
parser.feed(buffer)
buffer = ""
# yield all events
+
+ for event, elem in parser.read_events():
+ if event == "start":
+ # update current path
+ current_path.append(elem.tag)
+ current_path_has_children = False
+ elif event == "end":
+ # remove last element from current path
+ current_path.pop()
+ # yield element
+ if not current_path_has_children:
+ yield nested_element(current_path, elem)
+ # prevent yielding of parent element
+ if current_path:
+ current_path_has_children = True
+ else:
+ xml_started = False
+ # close parser
+ parser.close()
+
+ async def _atransform(
+ self, input: AsyncIterator[Union[str, BaseMessage]]
+ ) -> AsyncIterator[AddableDict]:
+ # Imports are temporarily placed here to avoid issue with caching on CI
+ # likely if you're reading this you can move them to the top of the file
+ from defusedxml.ElementTree import DefusedXMLParser # type: ignore[import]
+
+ _parser = DefusedXMLParser(target=TreeBuilder())
+ parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
+ xml_start_re = re.compile(r"<[a-zA-Z:_]")
+ xml_started = False
+ current_path: List[str] = []
+ current_path_has_children = False
+ buffer = ""
+ async for chunk in input:
+ if isinstance(chunk, BaseMessage):
+ # extract text
+ chunk_content = chunk.content
+ if not isinstance(chunk_content, str):
+ continue
+ chunk = chunk_content
+ # add chunk to buffer of unprocessed text
+ buffer += chunk
+ # if xml string hasn't started yet, continue to next chunk
+ if not xml_started:
+ if match := xml_start_re.search(buffer):
+ # if xml string has started, remove all text before it
+ buffer = buffer[match.start() :]
+ xml_started = True
+ else:
+ continue
+ # feed buffer to parser
+ parser.feed(buffer)
+ buffer = ""
+ # yield all events
for event, elem in parser.read_events():
if event == "start":
# update current path
@@ -102,38 +168,6 @@ class XMLOutputParser(BaseTransformOutputParser):
# close parser
parser.close()
- async def _atransform(
- self, input: AsyncIterator[Union[str, BaseMessage]]
- ) -> AsyncIterator[AddableDict]:
- parser = ET.XMLPullParser(["start", "end"])
- current_path: List[str] = []
- current_path_has_children = False
- async for chunk in input:
- if isinstance(chunk, BaseMessage):
- # extract text
- chunk_content = chunk.content
- if not isinstance(chunk_content, str):
- continue
- chunk = chunk_content
- # pass chunk to parser
- parser.feed(chunk)
- # yield all events
- for event, elem in parser.read_events():
- if event == "start":
- # update current path
- current_path.append(elem.tag)
- current_path_has_children = False
- elif event == "end":
- # remove last element from current path
- current_path.pop()
- # yield element
- if not current_path_has_children:
- yield nested_element(current_path, elem)
- # prevent yielding of parent element
- current_path_has_children = True
- # close parser
- parser.close()
-
def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
"""Converts xml tree to python dictionary."""
result: Dict[str, List[Any]] = {root.tag: []}
diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock
index 4d9ab6daf26..997d7416722 100644
--- a/libs/core/poetry.lock
+++ b/libs/core/poetry.lock
@@ -660,13 +660,13 @@ files = [
[[package]]
name = "importlib-metadata"
-version = "7.0.2"
+version = "7.1.0"
description = "Read metadata from Python packages"
optional = false
python-versions = ">=3.8"
files = [
- {file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"},
- {file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"},
+ {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"},
+ {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"},
]
[package.dependencies]
@@ -675,17 +675,17 @@ zipp = ">=0.5"
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
perf = ["ipython"]
-testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
+testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
[[package]]
name = "importlib-resources"
-version = "6.3.1"
+version = "6.4.0"
description = "Read resources from Python packages"
optional = false
python-versions = ">=3.8"
files = [
- {file = "importlib_resources-6.3.1-py3-none-any.whl", hash = "sha256:4811639ca7fa830abdb8e9ca0a104dc6ad13de691d9fe0d3173a71304f068159"},
- {file = "importlib_resources-6.3.1.tar.gz", hash = "sha256:29a3d16556e330c3c8fb8202118c5ff41241cc34cbfb25989bbad226d99b7995"},
+ {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"},
+ {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"},
]
[package.dependencies]
@@ -693,7 +693,7 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
-testing = ["jaraco.collections", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"]
+testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"]
[[package]]
name = "iniconfig"
@@ -1020,13 +1020,13 @@ test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout"
[[package]]
name = "jupyter-events"
-version = "0.9.1"
+version = "0.10.0"
description = "Jupyter Event System library"
optional = false
python-versions = ">=3.8"
files = [
- {file = "jupyter_events-0.9.1-py3-none-any.whl", hash = "sha256:e51f43d2c25c2ddf02d7f7a5045f71fc1d5cb5ad04ef6db20da961c077654b9b"},
- {file = "jupyter_events-0.9.1.tar.gz", hash = "sha256:a52e86f59eb317ee71ff2d7500c94b963b8a24f0b7a1517e2e653e24258e15c7"},
+ {file = "jupyter_events-0.10.0-py3-none-any.whl", hash = "sha256:4b72130875e59d57716d327ea70d3ebc3af1944d3717e5a498b8a06c6c159960"},
+ {file = "jupyter_events-0.10.0.tar.gz", hash = "sha256:670b8229d3cc882ec782144ed22e0d29e1c2d639263f92ca8383e66682845e22"},
]
[package.dependencies]
@@ -1216,13 +1216,13 @@ url = "../text-splitters"
[[package]]
name = "langsmith"
-version = "0.1.27"
+version = "0.1.31"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
-python-versions = ">=3.8.1,<4.0"
+python-versions = "<4.0,>=3.8.1"
files = [
- {file = "langsmith-0.1.27-py3-none-any.whl", hash = "sha256:d223176952b1525c958189ab1b894f5bd9891ec9177222f7a978aeee4bf1cc95"},
- {file = "langsmith-0.1.27.tar.gz", hash = "sha256:e0a339d976362051adf3fdbc43fcc7c00bb4615a401321ad7e556bd2dab556c0"},
+ {file = "langsmith-0.1.31-py3-none-any.whl", hash = "sha256:5211a9dc00831db307eb843485a97096484b697b5d2cd1efaac34228e97ca087"},
+ {file = "langsmith-0.1.31.tar.gz", hash = "sha256:efd54ccd44be7fda911bfdc0ead340473df2fdd07345c7252901834d0c4aa37e"},
]
[package.dependencies]
@@ -1406,13 +1406,13 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=
[[package]]
name = "nbconvert"
-version = "7.16.2"
+version = "7.16.3"
description = "Converting Jupyter Notebooks (.ipynb files) to other formats. Output formats include asciidoc, html, latex, markdown, pdf, py, rst, script. nbconvert can be used both as a Python library (`import nbconvert`) or as a command line tool (invoked as `jupyter nbconvert ...`)."
optional = false
python-versions = ">=3.8"
files = [
- {file = "nbconvert-7.16.2-py3-none-any.whl", hash = "sha256:0c01c23981a8de0220255706822c40b751438e32467d6a686e26be08ba784382"},
- {file = "nbconvert-7.16.2.tar.gz", hash = "sha256:8310edd41e1c43947e4ecf16614c61469ebc024898eb808cce0999860fc9fb16"},
+ {file = "nbconvert-7.16.3-py3-none-any.whl", hash = "sha256:ddeff14beeeedf3dd0bc506623e41e4507e551736de59df69a91f86700292b3b"},
+ {file = "nbconvert-7.16.3.tar.gz", hash = "sha256:a6733b78ce3d47c3f85e504998495b07e6ea9cf9bf6ec1c98dda63ec6ad19142"},
]
[package.dependencies]
@@ -1439,7 +1439,7 @@ docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sp
qtpdf = ["nbconvert[qtpng]"]
qtpng = ["pyqtwebengine (>=5.15)"]
serve = ["tornado (>=6.1)"]
-test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest"]
+test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest (>=7)"]
webpdf = ["playwright"]
[[package]]
@@ -1997,17 +1997,17 @@ testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy
[[package]]
name = "pytest-mock"
-version = "3.12.0"
+version = "3.14.0"
description = "Thin-wrapper around the mock package for easier use with pytest"
optional = false
python-versions = ">=3.8"
files = [
- {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"},
- {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"},
+ {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
+ {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
]
[package.dependencies]
-pytest = ">=5.0"
+pytest = ">=6.2.5"
[package.extras]
dev = ["pre-commit", "pytest-asyncio", "tox"]
@@ -2966,4 +2966,4 @@ extended-testing = ["jinja2"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
-content-hash = "ca611429e3dd84ce6dac7ef69d7d9b4da78bf467356946e37016b821e5fe752e"
+content-hash = "a13a0a8454b242106bb681fa74e1f1320a0198f2e07b35d29d985b03a310cf67"
diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml
index 7a0bf312fe0..ca9465bff41 100644
--- a/libs/core/pyproject.toml
+++ b/libs/core/pyproject.toml
@@ -19,6 +19,7 @@ PyYAML = ">=5.3"
requests = "^2"
packaging = "^23.2"
jinja2 = { version = "^3", optional = true }
+defusedxml = "^0.7"
[tool.poetry.group.lint]
optional = true
diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
index 65b095f308e..17ef1a558e6 100644
--- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
+++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
@@ -1,4 +1,7 @@
"""Test XMLOutputParser"""
+from typing import AsyncIterator
+from xml.etree.ElementTree import ParseError
+
import pytest
from langchain_core.exceptions import OutputParserException
@@ -40,19 +43,29 @@ More random text
""",
],
)
-def test_xml_output_parser(result: str) -> None:
+async def test_xml_output_parser(result: str) -> None:
"""Test XMLOutputParser."""
xml_parser = XMLOutputParser()
-
- xml_result = xml_parser.parse(result)
- assert DEF_RESULT_EXPECTED == xml_result
+ assert DEF_RESULT_EXPECTED == xml_parser.parse(result)
+ assert DEF_RESULT_EXPECTED == (await xml_parser.aparse(result))
assert list(xml_parser.transform(iter(result))) == [
{"foo": [{"bar": [{"baz": None}]}]},
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
{"foo": [{"baz": "tag"}]},
]
+ async def _as_iter(string: str) -> AsyncIterator[str]:
+ for c in string:
+ yield c
+
+ chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))]
+ assert chunks == [
+ {"foo": [{"bar": [{"baz": None}]}]},
+ {"foo": [{"bar": [{"baz": "slim.shady"}]}]},
+ {"foo": [{"baz": "tag"}]},
+ ]
+
@pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo"])
def test_xml_output_parser_fail(result: str) -> None:
@@ -63,3 +76,41 @@ def test_xml_output_parser_fail(result: str) -> None:
with pytest.raises(OutputParserException) as e:
xml_parser.parse(result)
assert "Failed to parse" in str(e)
+
+
+MALICIOUS_XML = """<?xml version="1.0"?>
+<!DOCTYPE lolz [<!ENTITY lol "lol"><!ELEMENT lolz (#PCDATA)>
+<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
+<!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">
+<!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">
+<!ENTITY lol4 "&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;">
+<!ENTITY lol5 "&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;">
+<!ENTITY lol6 "&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;">
+<!ENTITY lol7 "&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;">
+<!ENTITY lol8 "&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;">
+<!ENTITY lol9 "&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;">
+]>
+<lolz>&lol9;</lolz>"""
+
+
+async def tests_billion_laughs_attack() -> None:
+ parser = XMLOutputParser()
+ with pytest.raises(OutputParserException):
+ parser.parse(MALICIOUS_XML)
+
+ with pytest.raises(OutputParserException):
+ await parser.aparse(MALICIOUS_XML)
+
+ with pytest.raises(ParseError):
+ # Right now raises undefined entity error
+ assert list(parser.transform(iter(MALICIOUS_XML))) == [
+ {"foo": [{"bar": [{"baz": None}]}]}
+ ]
+
+ async def _as_iter(string: str) -> AsyncIterator[str]:
+ for c in string:
+ yield c
+
+ with pytest.raises(ParseError):
+ chunks = [chunk async for chunk in parser.atransform(_as_iter(MALICIOUS_XML))]
+ assert chunks == [{"foo": [{"bar": [{"baz": None}]}]}]