Compare commits

...

2 Commits

Author SHA1 Message Date
Eugene Yurtsev
94f8a15b3e x 2024-03-26 13:56:29 -04:00
Eugene Yurtsev
4e36ca782c x 2024-03-26 13:18:19 -04:00

View File

@@ -1,6 +1,6 @@
import re
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from xml.etree import ElementTree as ET
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
@@ -35,10 +35,6 @@ class XMLOutputParser(BaseTransformOutputParser):
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, List[Any]]:
# Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file
from defusedxml import ElementTree as DET # type: ignore[import]
# Try to find XML string within triple backticks
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None:
@@ -50,18 +46,18 @@ class XMLOutputParser(BaseTransformOutputParser):
text = text.strip()
try:
root = DET.fromstring(text)
root = ET.fromstring(text)
return self._root_to_dict(root)
except (DET.ParseError, DET.EntitiesForbidden) as e:
except ET.ParseError as e:
msg = f"Failed to parse XML format from completion {text}. Got: {e}"
raise OutputParserException(msg, llm_output=text) from e
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
parser = ET.XMLPullParser(["start", "end"])
xml_start_re = re.compile(r"<[a-zA-Z:_]")
parser = ET.XMLPullParser(["start", "end"])
xml_started = False
current_path: List[str] = []
current_path_has_children = False
@@ -87,7 +83,6 @@ class XMLOutputParser(BaseTransformOutputParser):
parser.feed(buffer)
buffer = ""
# yield all events
for event, elem in parser.read_events():
if event == "start":
# update current path
@@ -111,11 +106,8 @@ class XMLOutputParser(BaseTransformOutputParser):
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]:
parser = ET.XMLPullParser(["start", "end"])
xml_start_re = re.compile(r"<[a-zA-Z:_]")
xml_started = False
current_path: List[str] = []
current_path_has_children = False
buffer = ""
async for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
@@ -123,19 +115,8 @@ class XMLOutputParser(BaseTransformOutputParser):
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
# add chunk to buffer of unprocessed text
buffer += chunk
# if xml string hasn't started yet, continue to next chunk
if not xml_started:
if match := xml_start_re.search(buffer):
# if xml string has started, remove all text before it
buffer = buffer[match.start() :]
xml_started = True
else:
continue
# feed buffer to parser
parser.feed(buffer)
buffer = ""
# pass chunk to parser
parser.feed(chunk)
# yield all events
for event, elem in parser.read_events():
if event == "start":
@@ -149,10 +130,7 @@ class XMLOutputParser(BaseTransformOutputParser):
if not current_path_has_children:
yield nested_element(current_path, elem)
# prevent yielding of parent element
if current_path:
current_path_has_children = True
else:
xml_started = False
current_path_has_children = True
# close parser
parser.close()