Strip code block fences and extra test from xml when doing streaming … (#15293)

…parse

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos
2023-12-28 16:37:15 -08:00
committed by GitHub
parent ec72225265
commit 36ceffd2cd
2 changed files with 36 additions and 4 deletions

View File

@@ -55,9 +55,12 @@ class XMLOutputParser(BaseTransformOutputParser):
def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
xml_start_re = re.compile(r"<[a-zA-Z:_]")
parser = ET.XMLPullParser(["start", "end"])
xml_started = False
current_path: List[str] = []
current_path_has_children = False
buffer = ""
for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
@@ -65,8 +68,19 @@ class XMLOutputParser(BaseTransformOutputParser):
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
# pass chunk to parser
parser.feed(chunk)
# add chunk to buffer of unprocessed text
buffer += chunk
# if xml string hasn't started yet, continue to next chunk
if not xml_started:
if match := xml_start_re.search(buffer):
# if xml string has started, remove all text before it
buffer = buffer[match.start() :]
xml_started = True
else:
continue
# feed buffer to parser
parser.feed(buffer)
buffer = ""
# yield all events
for event, elem in parser.read_events():
if event == "start":
@@ -80,7 +94,10 @@ class XMLOutputParser(BaseTransformOutputParser):
if not current_path_has_children:
yield nested_element(current_path, elem)
# prevent yielding of parent element
current_path_has_children = True
if current_path:
current_path_has_children = True
else:
xml_started = False
# close parser
parser.close()

View File

@@ -22,7 +22,22 @@ DEF_RESULT_EXPECTED = {
@pytest.mark.parametrize(
"result",
[DEF_RESULT_ENCODING, DEF_RESULT_ENCODING[DEF_RESULT_ENCODING.find("\n") :]],
[
DEF_RESULT_ENCODING,
DEF_RESULT_ENCODING[DEF_RESULT_ENCODING.find("\n") :],
f"""
```xml
{DEF_RESULT_ENCODING}
```
""",
f"""
Some random text
```xml
{DEF_RESULT_ENCODING}
```
More random text
""",
],
)
def test_xml_output_parser(result: str) -> None:
"""Test XMLOutputParser."""