diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py
index 5a89c5f763c..f74bd4c1050 100644
--- a/libs/core/langchain_core/output_parsers/xml.py
+++ b/libs/core/langchain_core/output_parsers/xml.py
@@ -1,6 +1,6 @@
import re
+import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
-from xml.etree import ElementTree as ET
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
@@ -35,10 +35,6 @@ class XMLOutputParser(BaseTransformOutputParser):
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, List[Any]]:
- # Imports are temporarily placed here to avoid issue with caching on CI
- # likely if you're reading this you can move them to the top of the file
- from defusedxml import ElementTree as DET # type: ignore[import]
-
# Try to find XML string within triple backticks
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None:
@@ -50,18 +46,18 @@ class XMLOutputParser(BaseTransformOutputParser):
text = text.strip()
try:
- root = DET.fromstring(text)
+ root = ET.fromstring(text)
return self._root_to_dict(root)
- except (DET.ParseError, DET.EntitiesForbidden) as e:
+ except ET.ParseError as e:
msg = f"Failed to parse XML format from completion {text}. Got: {e}"
raise OutputParserException(msg, llm_output=text) from e
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
- parser = ET.XMLPullParser(["start", "end"])
xml_start_re = re.compile(r"<[a-zA-Z:_]")
+ parser = ET.XMLPullParser(["start", "end"])
xml_started = False
current_path: List[str] = []
current_path_has_children = False
@@ -87,7 +83,6 @@ class XMLOutputParser(BaseTransformOutputParser):
parser.feed(buffer)
buffer = ""
# yield all events
-
for event, elem in parser.read_events():
if event == "start":
# update current path
@@ -111,11 +106,8 @@ class XMLOutputParser(BaseTransformOutputParser):
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]:
parser = ET.XMLPullParser(["start", "end"])
- xml_start_re = re.compile(r"<[a-zA-Z:_]")
- xml_started = False
current_path: List[str] = []
current_path_has_children = False
- buffer = ""
async for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
@@ -123,19 +115,8 @@ class XMLOutputParser(BaseTransformOutputParser):
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
- # add chunk to buffer of unprocessed text
- buffer += chunk
- # if xml string hasn't started yet, continue to next chunk
- if not xml_started:
- if match := xml_start_re.search(buffer):
- # if xml string has started, remove all text before it
- buffer = buffer[match.start() :]
- xml_started = True
- else:
- continue
- # feed buffer to parser
- parser.feed(buffer)
- buffer = ""
+ # pass chunk to parser
+ parser.feed(chunk)
# yield all events
for event, elem in parser.read_events():
if event == "start":
@@ -149,10 +130,7 @@ class XMLOutputParser(BaseTransformOutputParser):
if not current_path_has_children:
yield nested_element(current_path, elem)
# prevent yielding of parent element
- if current_path:
- current_path_has_children = True
- else:
- xml_started = False
+ current_path_has_children = True
# close parser
parser.close()
diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock
index fc812ce7349..9495b9d4a81 100644
--- a/libs/core/poetry.lock
+++ b/libs/core/poetry.lock
@@ -2966,4 +2966,4 @@ extended-testing = ["jinja2"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
-content-hash = "2f61e22c118e13c40a1b7980afe06a37a6349ee239c948b9c49e8b1dc06facc1"
+content-hash = "203d96b330412ce9defad6739381e4031fc9e995c2d9e0a61a905fc79fff11dd"
diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml
index daa6b9bdb87..7f476fca7ff 100644
--- a/libs/core/pyproject.toml
+++ b/libs/core/pyproject.toml
@@ -18,7 +18,6 @@ PyYAML = ">=5.3"
requests = "^2"
packaging = "^23.2"
jinja2 = { version = "^3", optional = true }
-defusedxml = "^0.7"
[tool.poetry.group.lint]
optional = true
diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
index 7ba68f42a4d..65b095f308e 100644
--- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
+++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
@@ -1,6 +1,4 @@
"""Test XMLOutputParser"""
-from typing import AsyncIterator
-
import pytest
from langchain_core.exceptions import OutputParserException
@@ -42,29 +40,19 @@ More random text
""",
],
)
-async def test_xml_output_parser(result: str) -> None:
+def test_xml_output_parser(result: str) -> None:
"""Test XMLOutputParser."""
xml_parser = XMLOutputParser()
- assert DEF_RESULT_EXPECTED == xml_parser.parse(result)
- assert DEF_RESULT_EXPECTED == (await xml_parser.aparse(result))
+
+ xml_result = xml_parser.parse(result)
+ assert DEF_RESULT_EXPECTED == xml_result
assert list(xml_parser.transform(iter(result))) == [
{"foo": [{"bar": [{"baz": None}]}]},
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
{"foo": [{"baz": "tag"}]},
]
- async def _as_iter(string: str) -> AsyncIterator[str]:
- for c in string:
- yield c
-
- chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))]
- assert chunks == [
- {"foo": [{"bar": [{"baz": None}]}]},
- {"foo": [{"bar": [{"baz": "slim.shady"}]}]},
- {"foo": [{"baz": "tag"}]},
- ]
-
@pytest.mark.parametrize("result", ["foo>", " None:
@@ -75,27 +63,3 @@ def test_xml_output_parser_fail(result: str) -> None:
with pytest.raises(OutputParserException) as e:
xml_parser.parse(result)
assert "Failed to parse" in str(e)
-
-
-MALICIOUS_XML = """
-
-
-
-
-
-
-
-
-
-
-]>
-&lol9;"""
-
-
-async def tests_billion_laughs_attack() -> None:
- parser = XMLOutputParser()
- with pytest.raises(OutputParserException):
- parser.parse(MALICIOUS_XML)
-
- with pytest.raises(OutputParserException):
- await parser.aparse(MALICIOUS_XML)