From 8bc5cdccee85f385ab79629ad9db31e02bf53b19 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 26 Mar 2024 15:13:09 -0400 Subject: [PATCH] core[patch]: Reverting changes with defusedXML (#19604) DefusedXML is causing parsing errors on previously functional code with the 0.7.x versions. These do not seem to support newer version of python well. 0.8.x has only been released as rc, so we're not going to to use it in the core package --- .../core/langchain_core/output_parsers/xml.py | 36 +++------------ libs/core/poetry.lock | 2 +- libs/core/pyproject.toml | 1 - .../output_parsers/test_xml_parser.py | 44 ++----------------- 4 files changed, 12 insertions(+), 71 deletions(-) diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index 5a89c5f763c..f74bd4c1050 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -1,6 +1,6 @@ import re +import xml.etree.ElementTree as ET from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union -from xml.etree import ElementTree as ET from langchain_core.exceptions import OutputParserException from langchain_core.messages import BaseMessage @@ -35,10 +35,6 @@ class XMLOutputParser(BaseTransformOutputParser): return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) def parse(self, text: str) -> Dict[str, List[Any]]: - # Imports are temporarily placed here to avoid issue with caching on CI - # likely if you're reading this you can move them to the top of the file - from defusedxml import ElementTree as DET # type: ignore[import] - # Try to find XML string within triple backticks match = re.search(r"```(xml)?(.*)```", text, re.DOTALL) if match is not None: @@ -50,18 +46,18 @@ class XMLOutputParser(BaseTransformOutputParser): text = text.strip() try: - root = DET.fromstring(text) + root = ET.fromstring(text) return self._root_to_dict(root) - except (DET.ParseError, DET.EntitiesForbidden) as e: + except ET.ParseError as e: msg = f"Failed to parse XML format from completion {text}. Got: {e}" raise OutputParserException(msg, llm_output=text) from e def _transform( self, input: Iterator[Union[str, BaseMessage]] ) -> Iterator[AddableDict]: - parser = ET.XMLPullParser(["start", "end"]) xml_start_re = re.compile(r"<[a-zA-Z:_]") + parser = ET.XMLPullParser(["start", "end"]) xml_started = False current_path: List[str] = [] current_path_has_children = False @@ -87,7 +83,6 @@ class XMLOutputParser(BaseTransformOutputParser): parser.feed(buffer) buffer = "" # yield all events - for event, elem in parser.read_events(): if event == "start": # update current path @@ -111,11 +106,8 @@ class XMLOutputParser(BaseTransformOutputParser): self, input: AsyncIterator[Union[str, BaseMessage]] ) -> AsyncIterator[AddableDict]: parser = ET.XMLPullParser(["start", "end"]) - xml_start_re = re.compile(r"<[a-zA-Z:_]") - xml_started = False current_path: List[str] = [] current_path_has_children = False - buffer = "" async for chunk in input: if isinstance(chunk, BaseMessage): # extract text @@ -123,19 +115,8 @@ class XMLOutputParser(BaseTransformOutputParser): if not isinstance(chunk_content, str): continue chunk = chunk_content - # add chunk to buffer of unprocessed text - buffer += chunk - # if xml string hasn't started yet, continue to next chunk - if not xml_started: - if match := xml_start_re.search(buffer): - # if xml string has started, remove all text before it - buffer = buffer[match.start() :] - xml_started = True - else: - continue - # feed buffer to parser - parser.feed(buffer) - buffer = "" + # pass chunk to parser + parser.feed(chunk) # yield all events for event, elem in parser.read_events(): if event == "start": @@ -149,10 +130,7 @@ class XMLOutputParser(BaseTransformOutputParser): if not current_path_has_children: yield nested_element(current_path, elem) # prevent yielding of parent element - if current_path: - current_path_has_children = True - else: - xml_started = False + current_path_has_children = True # close parser parser.close() diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index fc812ce7349..9495b9d4a81 100644 --- a/libs/core/poetry.lock +++ b/libs/core/poetry.lock @@ -2966,4 +2966,4 @@ extended-testing = ["jinja2"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "2f61e22c118e13c40a1b7980afe06a37a6349ee239c948b9c49e8b1dc06facc1" +content-hash = "203d96b330412ce9defad6739381e4031fc9e995c2d9e0a61a905fc79fff11dd" diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index daa6b9bdb87..7f476fca7ff 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -18,7 +18,6 @@ PyYAML = ">=5.3" requests = "^2" packaging = "^23.2" jinja2 = { version = "^3", optional = true } -defusedxml = "^0.7" [tool.poetry.group.lint] optional = true diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py index 7ba68f42a4d..65b095f308e 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py @@ -1,6 +1,4 @@ """Test XMLOutputParser""" -from typing import AsyncIterator - import pytest from langchain_core.exceptions import OutputParserException @@ -42,29 +40,19 @@ More random text """, ], ) -async def test_xml_output_parser(result: str) -> None: +def test_xml_output_parser(result: str) -> None: """Test XMLOutputParser.""" xml_parser = XMLOutputParser() - assert DEF_RESULT_EXPECTED == xml_parser.parse(result) - assert DEF_RESULT_EXPECTED == (await xml_parser.aparse(result)) + + xml_result = xml_parser.parse(result) + assert DEF_RESULT_EXPECTED == xml_result assert list(xml_parser.transform(iter(result))) == [ {"foo": [{"bar": [{"baz": None}]}]}, {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, {"foo": [{"baz": "tag"}]}, ] - async def _as_iter(string: str) -> AsyncIterator[str]: - for c in string: - yield c - - chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))] - assert chunks == [ - {"foo": [{"bar": [{"baz": None}]}]}, - {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, - {"foo": [{"baz": "tag"}]}, - ] - @pytest.mark.parametrize("result", ["foo>", " None: @@ -75,27 +63,3 @@ def test_xml_output_parser_fail(result: str) -> None: with pytest.raises(OutputParserException) as e: xml_parser.parse(result) assert "Failed to parse" in str(e) - - -MALICIOUS_XML = """ - - - - - - - - - - -]> -&lol9;""" - - -async def tests_billion_laughs_attack() -> None: - parser = XMLOutputParser() - with pytest.raises(OutputParserException): - parser.parse(MALICIOUS_XML) - - with pytest.raises(OutputParserException): - await parser.aparse(MALICIOUS_XML)