Compare commits

...

2 Commits

Author SHA1 Message Date
Eugene Yurtsev
94f8a15b3e x 2024-03-26 13:56:29 -04:00
Eugene Yurtsev
4e36ca782c x 2024-03-26 13:18:19 -04:00

View File

@@ -1,6 +1,6 @@
import re import re
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from xml.etree import ElementTree as ET
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
@@ -35,10 +35,6 @@ class XMLOutputParser(BaseTransformOutputParser):
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, List[Any]]: def parse(self, text: str) -> Dict[str, List[Any]]:
# Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file
from defusedxml import ElementTree as DET # type: ignore[import]
# Try to find XML string within triple backticks # Try to find XML string within triple backticks
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL) match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None: if match is not None:
@@ -50,18 +46,18 @@ class XMLOutputParser(BaseTransformOutputParser):
text = text.strip() text = text.strip()
try: try:
root = DET.fromstring(text) root = ET.fromstring(text)
return self._root_to_dict(root) return self._root_to_dict(root)
except (DET.ParseError, DET.EntitiesForbidden) as e: except ET.ParseError as e:
msg = f"Failed to parse XML format from completion {text}. Got: {e}" msg = f"Failed to parse XML format from completion {text}. Got: {e}"
raise OutputParserException(msg, llm_output=text) from e raise OutputParserException(msg, llm_output=text) from e
def _transform( def _transform(
self, input: Iterator[Union[str, BaseMessage]] self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]: ) -> Iterator[AddableDict]:
parser = ET.XMLPullParser(["start", "end"])
xml_start_re = re.compile(r"<[a-zA-Z:_]") xml_start_re = re.compile(r"<[a-zA-Z:_]")
parser = ET.XMLPullParser(["start", "end"])
xml_started = False xml_started = False
current_path: List[str] = [] current_path: List[str] = []
current_path_has_children = False current_path_has_children = False
@@ -87,7 +83,6 @@ class XMLOutputParser(BaseTransformOutputParser):
parser.feed(buffer) parser.feed(buffer)
buffer = "" buffer = ""
# yield all events # yield all events
for event, elem in parser.read_events(): for event, elem in parser.read_events():
if event == "start": if event == "start":
# update current path # update current path
@@ -111,11 +106,8 @@ class XMLOutputParser(BaseTransformOutputParser):
self, input: AsyncIterator[Union[str, BaseMessage]] self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]: ) -> AsyncIterator[AddableDict]:
parser = ET.XMLPullParser(["start", "end"]) parser = ET.XMLPullParser(["start", "end"])
xml_start_re = re.compile(r"<[a-zA-Z:_]")
xml_started = False
current_path: List[str] = [] current_path: List[str] = []
current_path_has_children = False current_path_has_children = False
buffer = ""
async for chunk in input: async for chunk in input:
if isinstance(chunk, BaseMessage): if isinstance(chunk, BaseMessage):
# extract text # extract text
@@ -123,19 +115,8 @@ class XMLOutputParser(BaseTransformOutputParser):
if not isinstance(chunk_content, str): if not isinstance(chunk_content, str):
continue continue
chunk = chunk_content chunk = chunk_content
# add chunk to buffer of unprocessed text # pass chunk to parser
buffer += chunk parser.feed(chunk)
# if xml string hasn't started yet, continue to next chunk
if not xml_started:
if match := xml_start_re.search(buffer):
# if xml string has started, remove all text before it
buffer = buffer[match.start() :]
xml_started = True
else:
continue
# feed buffer to parser
parser.feed(buffer)
buffer = ""
# yield all events # yield all events
for event, elem in parser.read_events(): for event, elem in parser.read_events():
if event == "start": if event == "start":
@@ -149,10 +130,7 @@ class XMLOutputParser(BaseTransformOutputParser):
if not current_path_has_children: if not current_path_has_children:
yield nested_element(current_path, elem) yield nested_element(current_path, elem)
# prevent yielding of parent element # prevent yielding of parent element
if current_path:
current_path_has_children = True current_path_has_children = True
else:
xml_started = False
# close parser # close parser
parser.close() parser.close()