diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index c0e3e72baf4..f74bd4c1050 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -2,6 +2,7 @@ import re import xml.etree.ElementTree as ET from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union +from langchain_core.exceptions import OutputParserException from langchain_core.messages import BaseMessage from langchain_core.output_parsers.transform import BaseTransformOutputParser from langchain_core.runnables.utils import AddableDict @@ -44,13 +45,13 @@ class XMLOutputParser(BaseTransformOutputParser): text = encoding_match.group(2) text = text.strip() - if (text.startswith("<") or text.startswith("\n<")) and ( - text.endswith(">") or text.endswith(">\n") - ): + try: root = ET.fromstring(text) return self._root_to_dict(root) - else: - raise ValueError(f"Could not parse output: {text}") + + except ET.ParseError as e: + msg = f"Failed to parse XML format from completion {text}. Got: {e}" + raise OutputParserException(msg, llm_output=text) from e def _transform( self, input: Iterator[Union[str, BaseMessage]] diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py index 697f4e4776e..65b095f308e 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py @@ -1,6 +1,7 @@ """Test XMLOutputParser""" import pytest +from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers.xml import XMLOutputParser DEF_RESULT_ENCODING = """ @@ -59,6 +60,6 @@ def test_xml_output_parser_fail(result: str) -> None: xml_parser = XMLOutputParser() - with pytest.raises(ValueError) as e: + with pytest.raises(OutputParserException) as e: xml_parser.parse(result) - assert "Could not parse output" in str(e) + assert "Failed to parse" in str(e)