Implement streaming for xml output parser (#14984)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-12-21 11:30:18 -08:00 committed by GitHub
parent 94bc3967a1
commit b471166df7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 3 deletions

View File

@ -1,13 +1,15 @@
import re import re
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Any, Dict, List, Optional from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from langchain_core.output_parsers import BaseOutputParser from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables.utils import AddableDict
from langchain.output_parsers.format_instructions import XML_FORMAT_INSTRUCTIONS from langchain.output_parsers.format_instructions import XML_FORMAT_INSTRUCTIONS
class XMLOutputParser(BaseOutputParser): class XMLOutputParser(BaseTransformOutputParser):
"""Parse an output using xml format.""" """Parse an output using xml format."""
tags: Optional[List[str]] = None tags: Optional[List[str]] = None
@ -33,6 +35,70 @@ class XMLOutputParser(BaseOutputParser):
else: else:
raise ValueError(f"Could not parse output: {text}") raise ValueError(f"Could not parse output: {text}")
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
parser = ET.XMLPullParser(["start", "end"])
current_path: List[str] = []
current_path_has_children = False
for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
chunk_content = chunk.content
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
# pass chunk to parser
parser.feed(chunk)
# yield all events
for event, elem in parser.read_events():
if event == "start":
# update current path
current_path.append(elem.tag)
current_path_has_children = False
elif event == "end":
# remove last element from current path
current_path.pop()
# yield element
if not current_path_has_children:
yield nested_element(current_path, elem)
# prevent yielding of parent element
current_path_has_children = True
# close parser
parser.close()
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]:
parser = ET.XMLPullParser(["start", "end"])
current_path: List[str] = []
current_path_has_children = False
async for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
chunk_content = chunk.content
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
# pass chunk to parser
parser.feed(chunk)
# yield all events
for event, elem in parser.read_events():
if event == "start":
# update current path
current_path.append(elem.tag)
current_path_has_children = False
elif event == "end":
# remove last element from current path
current_path.pop()
# yield element
if not current_path_has_children:
yield nested_element(current_path, elem)
# prevent yielding of parent element
current_path_has_children = True
# close parser
parser.close()
def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]: def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
"""Converts xml tree to python dictionary.""" """Converts xml tree to python dictionary."""
result: Dict[str, List[Any]] = {root.tag: []} result: Dict[str, List[Any]] = {root.tag: []}
@ -46,3 +112,11 @@ class XMLOutputParser(BaseOutputParser):
@property @property
def _type(self) -> str: def _type(self) -> str:
return "xml" return "xml"
def nested_element(path: List[str], elem: ET.Element) -> Any:
"""Get nested element from path."""
if len(path) == 0:
return AddableDict({elem.tag: elem.text})
else:
return AddableDict({path[0]: [nested_element(path[1:], elem)]})

View File

@ -31,6 +31,11 @@ def test_xml_output_parser(result: str) -> None:
xml_result = xml_parser.parse(result) xml_result = xml_parser.parse(result)
assert DEF_RESULT_EXPECTED == xml_result assert DEF_RESULT_EXPECTED == xml_result
assert list(xml_parser.transform(iter(result))) == [
{"foo": [{"bar": [{"baz": None}]}]},
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
{"foo": [{"baz": "tag"}]},
]
@pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"]) @pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"])