core[patch]: XMLOutputParser fix to handle changes to xml standard library (#19612)

The newest Python micro releases changed the behavior of the standard library's XML parsing, which broke streaming in the XMLOutputParser. This commit fixes the parsing code so that it works when there is trailing junk after the XML content.
This commit is contained in:
Eugene Yurtsev 2024-03-27 09:25:28 -04:00 committed by GitHub
parent 3a7d2cf443
commit 8ab7bb3166
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 127 additions and 55 deletions

View File

@ -1,4 +1,5 @@
import re import re
import xml
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
@ -81,33 +82,52 @@ class XMLOutputParser(BaseTransformOutputParser):
continue continue
# feed buffer to parser # feed buffer to parser
parser.feed(buffer) parser.feed(buffer)
buffer = "" buffer = ""
# yield all events # yield all events
for event, elem in parser.read_events(): try:
if event == "start": for event, elem in parser.read_events():
# update current path if event == "start":
current_path.append(elem.tag) # update current path
current_path_has_children = False current_path.append(elem.tag)
elif event == "end": current_path_has_children = False
# remove last element from current path elif event == "end":
current_path.pop() # remove last element from current path
# yield element #
if not current_path_has_children: current_path.pop()
yield nested_element(current_path, elem) # yield element
# prevent yielding of parent element if not current_path_has_children:
if current_path: yield nested_element(current_path, elem)
current_path_has_children = True # prevent yielding of parent element
else: if current_path:
xml_started = False current_path_has_children = True
else:
xml_started = False
except xml.etree.ElementTree.ParseError:
# This might be junk at the end of the XML input.
# Let's check whether the current path is empty.
if not current_path:
# If it is empty, we can ignore this error.
break
else:
raise
# close parser # close parser
parser.close() try:
parser.close()
except xml.etree.ElementTree.ParseError:
# Ignore. This will ignore any incomplete XML at the end of the input
pass
async def _atransform( async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]] self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]: ) -> AsyncIterator[AddableDict]:
xml_start_re = re.compile(r"<[a-zA-Z:_]")
parser = ET.XMLPullParser(["start", "end"]) parser = ET.XMLPullParser(["start", "end"])
xml_started = False
current_path: List[str] = [] current_path: List[str] = []
current_path_has_children = False current_path_has_children = False
buffer = ""
async for chunk in input: async for chunk in input:
if isinstance(chunk, BaseMessage): if isinstance(chunk, BaseMessage):
# extract text # extract text
@ -115,24 +135,54 @@ class XMLOutputParser(BaseTransformOutputParser):
if not isinstance(chunk_content, str): if not isinstance(chunk_content, str):
continue continue
chunk = chunk_content chunk = chunk_content
# pass chunk to parser # add chunk to buffer of unprocessed text
parser.feed(chunk) buffer += chunk
# if xml string hasn't started yet, continue to next chunk
if not xml_started:
if match := xml_start_re.search(buffer):
# if xml string has started, remove all text before it
buffer = buffer[match.start() :]
xml_started = True
else:
continue
# feed buffer to parser
parser.feed(buffer)
buffer = ""
# yield all events # yield all events
for event, elem in parser.read_events(): try:
if event == "start": for event, elem in parser.read_events():
# update current path if event == "start":
current_path.append(elem.tag) # update current path
current_path_has_children = False current_path.append(elem.tag)
elif event == "end": current_path_has_children = False
# remove last element from current path elif event == "end":
current_path.pop() # remove last element from current path
# yield element #
if not current_path_has_children: current_path.pop()
yield nested_element(current_path, elem) # yield element
# prevent yielding of parent element if not current_path_has_children:
current_path_has_children = True yield nested_element(current_path, elem)
# prevent yielding of parent element
if current_path:
current_path_has_children = True
else:
xml_started = False
except xml.etree.ElementTree.ParseError:
# This might be junk at the end of the XML input.
# Let's check whether the current path is empty.
if not current_path:
# If it is empty, we can ignore this error.
break
else:
raise
# close parser # close parser
parser.close() try:
parser.close()
except xml.etree.ElementTree.ParseError:
# Ignore. This will ignore any incomplete XML at the end of the input
pass
def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]: def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
"""Converts xml tree to python dictionary.""" """Converts xml tree to python dictionary."""

View File

@ -1,10 +1,12 @@
"""Test XMLOutputParser""" """Test XMLOutputParser"""
from typing import AsyncIterator, Iterable
import pytest import pytest
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.xml import XMLOutputParser from langchain_core.output_parsers.xml import XMLOutputParser
DEF_RESULT_ENCODING = """<?xml version="1.0" encoding="UTF-8"?> DATA = """
<foo> <foo>
<bar> <bar>
<baz></baz> <baz></baz>
@ -13,6 +15,25 @@ DEF_RESULT_ENCODING = """<?xml version="1.0" encoding="UTF-8"?>
<baz>tag</baz> <baz>tag</baz>
</foo>""" </foo>"""
WITH_XML_HEADER = f"""<?xml version="1.0" encoding="UTF-8"?>
{DATA}"""
IN_XML_TAGS_WITH_XML_HEADER = f"""
```xml
{WITH_XML_HEADER}
```
"""
IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK = f"""
Some random text
```xml
{WITH_XML_HEADER}
```
More random text
"""
DEF_RESULT_EXPECTED = { DEF_RESULT_EXPECTED = {
"foo": [ "foo": [
{"bar": [{"baz": None}, {"baz": "slim.shady"}]}, {"bar": [{"baz": None}, {"baz": "slim.shady"}]},
@ -24,23 +45,13 @@ DEF_RESULT_EXPECTED = {
@pytest.mark.parametrize( @pytest.mark.parametrize(
"result", "result",
[ [
DEF_RESULT_ENCODING, DATA, # has no xml header
DEF_RESULT_ENCODING[DEF_RESULT_ENCODING.find("\n") :], WITH_XML_HEADER,
f""" IN_XML_TAGS_WITH_XML_HEADER,
```xml IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK,
{DEF_RESULT_ENCODING}
```
""",
f"""
Some random text
```xml
{DEF_RESULT_ENCODING}
```
More random text
""",
], ],
) )
def test_xml_output_parser(result: str) -> None: async def test_xml_output_parser(result: str) -> None:
"""Test XMLOutputParser.""" """Test XMLOutputParser."""
xml_parser = XMLOutputParser() xml_parser = XMLOutputParser()
@ -48,12 +59,23 @@ def test_xml_output_parser(result: str) -> None:
xml_result = xml_parser.parse(result) xml_result = xml_parser.parse(result)
assert DEF_RESULT_EXPECTED == xml_result assert DEF_RESULT_EXPECTED == xml_result
# TODO(Eugene): Fix this test for newer python version assert list(xml_parser.transform(iter(result))) == [
# assert list(xml_parser.transform(iter(result))) == [ {"foo": [{"bar": [{"baz": None}]}]},
# {"foo": [{"bar": [{"baz": None}]}]}, {"foo": [{"bar": [{"baz": "slim.shady"}]}]},
# {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, {"foo": [{"baz": "tag"}]},
# {"foo": [{"baz": "tag"}]}, ]
# ]
async def _as_iter(iterable: Iterable[str]) -> AsyncIterator[str]:
for item in iterable:
yield item
chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))]
assert list(chunks) == [
{"foo": [{"bar": [{"baz": None}]}]},
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
{"foo": [{"baz": "tag"}]},
]
@pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"]) @pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"])