From 36ceffd2cd7040e23513f5b6797274608c2ef872 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 28 Dec 2023 16:37:15 -0800 Subject: [PATCH] =?UTF-8?q?Strip=20code=20block=20fences=20and=20extra=20t?= =?UTF-8?q?ext=20from=20xml=20when=20doing=20streaming=20=E2=80=A6=20(#152?= =?UTF-8?q?93)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …parse --- .../core/langchain_core/output_parsers/xml.py | 23 ++++++++++++++++--- .../output_parsers/test_xml_parser.py | 17 +++++++++++++- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index 9a93023ce12..c0e3e72baf4 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -55,9 +55,12 @@ class XMLOutputParser(BaseTransformOutputParser): def _transform( self, input: Iterator[Union[str, BaseMessage]] ) -> Iterator[AddableDict]: + xml_start_re = re.compile(r"<[a-zA-Z:_]") parser = ET.XMLPullParser(["start", "end"]) + xml_started = False current_path: List[str] = [] current_path_has_children = False + buffer = "" for chunk in input: if isinstance(chunk, BaseMessage): # extract text @@ -65,8 +68,19 @@ class XMLOutputParser(BaseTransformOutputParser): if not isinstance(chunk_content, str): continue chunk = chunk_content - # pass chunk to parser - parser.feed(chunk) + # add chunk to buffer of unprocessed text + buffer += chunk + # if xml string hasn't started yet, continue to next chunk + if not xml_started: + if match := xml_start_re.search(buffer): + # if xml string has started, remove all text before it + buffer = buffer[match.start() :] + xml_started = True + else: + continue + # feed buffer to parser + parser.feed(buffer) + buffer = "" # yield all events for event, elem in parser.read_events(): if event == "start": @@ -80,7 +94,10 @@ class XMLOutputParser(BaseTransformOutputParser): if not current_path_has_children: 
yield nested_element(current_path, elem) # prevent yielding of parent element - current_path_has_children = True + if current_path: + current_path_has_children = True + else: + xml_started = False # close parser parser.close() diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py index fb92e96331a..697f4e4776e 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py @@ -22,7 +22,22 @@ DEF_RESULT_EXPECTED = { @pytest.mark.parametrize( "result", - [DEF_RESULT_ENCODING, DEF_RESULT_ENCODING[DEF_RESULT_ENCODING.find("\n") :]], + [ + DEF_RESULT_ENCODING, + DEF_RESULT_ENCODING[DEF_RESULT_ENCODING.find("\n") :], + f""" +```xml +{DEF_RESULT_ENCODING} +``` +""", + f""" +Some random text +```xml +{DEF_RESULT_ENCODING} +``` +More random text +""", + ], ) def test_xml_output_parser(result: str) -> None: """Test XMLOutputParser."""