core[patch]: fix xml output parser transform (#19530)

Previous PR passed _parser attribute which apparently is not meant to be
used by user code and causes non deterministic failures on CI when
testing the transform and a transform methods. Reverting this change
temporarily.
This commit is contained in:
Eugene Yurtsev 2024-03-25 17:34:45 -04:00 committed by GitHub
parent e6952b04d5
commit 56f4c5459b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 29 deletions

View File

@ -1,7 +1,6 @@
import re import re
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
from xml.etree.ElementTree import TreeBuilder
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
@ -61,13 +60,7 @@ class XMLOutputParser(BaseTransformOutputParser):
def _transform( def _transform(
self, input: Iterator[Union[str, BaseMessage]] self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]: ) -> Iterator[AddableDict]:
# Imports are temporarily placed here to avoid issue with caching on CI parser = ET.XMLPullParser(["start", "end"])
# likely if you're reading this you can move them to the top of the file
from defusedxml.ElementTree import DefusedXMLParser # type: ignore[import]
parser = ET.XMLPullParser(
["start", "end"], _parser=DefusedXMLParser(target=TreeBuilder())
)
xml_start_re = re.compile(r"<[a-zA-Z:_]") xml_start_re = re.compile(r"<[a-zA-Z:_]")
xml_started = False xml_started = False
current_path: List[str] = [] current_path: List[str] = []
@ -117,12 +110,7 @@ class XMLOutputParser(BaseTransformOutputParser):
async def _atransform( async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]] self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]: ) -> AsyncIterator[AddableDict]:
# Imports are temporarily placed here to avoid issue with caching on CI parser = ET.XMLPullParser(["start", "end"])
# likely if you're reading this you can move them to the top of the file
from defusedxml.ElementTree import DefusedXMLParser # type: ignore[import]
_parser = DefusedXMLParser(target=TreeBuilder())
parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
xml_start_re = re.compile(r"<[a-zA-Z:_]") xml_start_re = re.compile(r"<[a-zA-Z:_]")
xml_started = False xml_started = False
current_path: List[str] = [] current_path: List[str] = []

View File

@ -1,6 +1,5 @@
"""Test XMLOutputParser""" """Test XMLOutputParser"""
from typing import AsyncIterator from typing import AsyncIterator
from xml.etree.ElementTree import ParseError
import pytest import pytest
@ -100,17 +99,3 @@ async def tests_billion_laughs_attack() -> None:
with pytest.raises(OutputParserException): with pytest.raises(OutputParserException):
await parser.aparse(MALICIOUS_XML) await parser.aparse(MALICIOUS_XML)
with pytest.raises(ParseError):
# Right now raises undefined entity error
assert list(parser.transform(iter(MALICIOUS_XML))) == [
{"foo": [{"bar": [{"baz": None}]}]}
]
async def _as_iter(string: str) -> AsyncIterator[str]:
for c in string:
yield c
with pytest.raises(ParseError):
chunks = [chunk async for chunk in parser.atransform(_as_iter(MALICIOUS_XML))]
assert chunks == [{"foo": [{"bar": [{"baz": None}]}]}]