Add XMLOutputParser (#10051)

**Description**
Adds new output parser, this time enabling the output of LLM to be of an
XML format. Seems to be particularly useful together with Claude model.
Addresses [issue
9820](https://github.com/langchain-ai/langchain/issues/9820).

**Twitter handle**
@deepsense_ai @matt_wosinski
This commit is contained in:
Mateusz Wosinski
2023-09-20 01:17:33 +02:00
committed by GitHub
parent d6df288380
commit 720f6dbaac
5 changed files with 465 additions and 0 deletions

View File

@@ -28,6 +28,7 @@ from langchain.output_parsers.regex import RegexParser
from langchain.output_parsers.regex_dict import RegexDictParser
from langchain.output_parsers.retry import RetryOutputParser, RetryWithErrorOutputParser
from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser
from langchain.output_parsers.xml import XMLOutputParser
__all__ = [
"BooleanOutputParser",
@@ -46,4 +47,5 @@ __all__ = [
"RetryOutputParser",
"RetryWithErrorOutputParser",
"StructuredOutputParser",
"XMLOutputParser",
]

View File

@@ -25,3 +25,19 @@ Here is the output schema:
```
{schema}
```"""
XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
1. Output should conform to the tags below.
2. If tags are not given, make them on your own.
3. Remember to always open and close all the tags.
As an example, for the tags ["foo", "bar", "baz"]:
1. String "<foo>\n <bar>\n <baz></baz>\n </bar>\n</foo>" is a well-formatted instance of the schema.
2. String "<foo>\n <bar>\n </foo>" is a badly-formatted instance.
3. String "<foo>\n <tag>\n </tag>\n</foo>" is a badly-formatted instance.
Here are the output tags:
```
{tags}
```"""

View File

@@ -0,0 +1,45 @@
import re
import xml.etree.ElementTree as ET
from typing import Any, Dict, List, Optional
from langchain.output_parsers.format_instructions import XML_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser
class XMLOutputParser(BaseOutputParser):
"""Parse an output using xml format."""
tags: Optional[List[str]] = None
encoding_matcher: re.Pattern = re.compile(
r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
)
def get_format_instructions(self) -> str:
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, List[Any]]:
text = text.strip("`").strip("xml")
encoding_match = self.encoding_matcher.search(text)
if encoding_match:
text = encoding_match.group(2)
if (text.startswith("<") or text.startswith("\n<")) and (
text.endswith(">") or text.endswith(">\n")
):
root = ET.fromstring(text)
return self._root_to_dict(root)
else:
raise ValueError(f"Could not parse output: {text}")
def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
"""Converts xml tree to python dictionary."""
result: Dict[str, List[Any]] = {root.tag: []}
for child in root:
if len(child) == 0:
result[root.tag].append({child.tag: child.text})
else:
result[root.tag].append(self._root_to_dict(child))
return result
@property
def _type(self) -> str:
return "xml"

View File

@@ -0,0 +1,44 @@
"""Test XMLOutputParser"""
import pytest
from langchain.output_parsers.xml import XMLOutputParser
DEF_RESULT_ENCODING = """<?xml version="1.0" encoding="UTF-8"?>
<foo>
<bar>
<baz></baz>
<baz>slim.shady</baz>
</bar>
<baz>tag</baz>
</foo>"""
DEF_RESULT_EXPECTED = {
"foo": [
{"bar": [{"baz": None}, {"baz": "slim.shady"}]},
{"baz": "tag"},
],
}
@pytest.mark.parametrize(
"result",
[DEF_RESULT_ENCODING, DEF_RESULT_ENCODING[DEF_RESULT_ENCODING.find("\n") :]],
)
def test_xml_output_parser(result: str) -> None:
"""Test XMLOutputParser."""
xml_parser = XMLOutputParser()
xml_result = xml_parser.parse(result)
assert DEF_RESULT_EXPECTED == xml_result
@pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"])
def test_xml_output_parser_fail(result: str) -> None:
"""Test XMLOutputParser where complete output is not in XML format."""
xml_parser = XMLOutputParser()
with pytest.raises(ValueError) as e:
xml_parser.parse(result)
assert "Could not parse output" in str(e)