mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-10 15:33:11 +00:00
Add XMLOutputParser (#10051)
**Description** Adds a new output parser that lets the LLM's output be formatted as XML. It seems to be particularly useful together with the Claude model. Addresses [issue 9820](https://github.com/langchain-ai/langchain/issues/9820). **Twitter handle** @deepsense_ai @matt_wosinski
This commit is contained in:
@@ -28,6 +28,7 @@ from langchain.output_parsers.regex import RegexParser
|
||||
from langchain.output_parsers.regex_dict import RegexDictParser
|
||||
from langchain.output_parsers.retry import RetryOutputParser, RetryWithErrorOutputParser
|
||||
from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser
|
||||
from langchain.output_parsers.xml import XMLOutputParser
|
||||
|
||||
__all__ = [
|
||||
"BooleanOutputParser",
|
||||
@@ -46,4 +47,5 @@ __all__ = [
|
||||
"RetryOutputParser",
|
||||
"RetryWithErrorOutputParser",
|
||||
"StructuredOutputParser",
|
||||
"XMLOutputParser",
|
||||
]
|
||||
|
@@ -25,3 +25,19 @@ Here is the output schema:
|
||||
```
|
||||
{schema}
|
||||
```"""
|
||||
|
||||
|
||||
XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
|
||||
1. Output should conform to the tags below.
|
||||
2. If tags are not given, make them on your own.
|
||||
3. Remember to always open and close all the tags.
|
||||
|
||||
As an example, for the tags ["foo", "bar", "baz"]:
|
||||
1. String "<foo>\n <bar>\n <baz></baz>\n </bar>\n</foo>" is a well-formatted instance of the schema.
|
||||
2. String "<foo>\n <bar>\n </foo>" is a badly-formatted instance.
|
||||
3. String "<foo>\n <tag>\n </tag>\n</foo>" is a badly-formatted instance.
|
||||
|
||||
Here are the output tags:
|
||||
```
|
||||
{tags}
|
||||
```"""
|
||||
|
45
libs/langchain/langchain/output_parsers/xml.py
Normal file
45
libs/langchain/langchain/output_parsers/xml.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.output_parsers.format_instructions import XML_FORMAT_INSTRUCTIONS
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
|
||||
class XMLOutputParser(BaseOutputParser):
    """Parse model output formatted as XML into nested dictionaries.

    Each element becomes ``{tag: [children...]}``. A child without
    sub-elements is represented as ``{tag: text}`` (``text`` is ``None`` for an
    empty element); a child with sub-elements is converted recursively.
    """

    # Tag names passed to the format instructions. ``None`` is formatted as-is,
    # in which case the prompt tells the model to make up its own tags.
    tags: Optional[List[str]] = None
    # Matches an XML declaration that carries an ``encoding`` attribute, e.g.
    # ``<?xml version="1.0" encoding="UTF-8"?>``. Group 1 is the declaration
    # body, group 2 is everything after it. The pattern is anchored on
    # ``<?xml`` so an ordinary element that merely mentions "encoding" is not
    # mistaken for the declaration and cut out of the document.
    encoding_matcher: re.Pattern = re.compile(
        r"<(\?xml[^>]*encoding[^>]*\?)>\s*(.*)", re.MULTILINE | re.DOTALL
    )

    def get_format_instructions(self) -> str:
        """Return the XML prompt instructions with ``self.tags`` filled in."""
        return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)

    def parse(self, text: str) -> Dict[str, List[Any]]:
        """Convert an XML string, optionally code-fenced, into a dictionary.

        Args:
            text: Raw model output. Surrounding whitespace, surrounding
                backticks, a leading ``xml`` fence tag and an XML declaration
                with an ``encoding`` attribute are removed before parsing.

        Returns:
            ``{root_tag: [children...]}`` as built by ``_root_to_dict``.

        Raises:
            ValueError: If the cleaned text does not start with ``<`` and end
                with ``>``.
            xml.etree.ElementTree.ParseError: Propagated unchanged from
                ``ET.fromstring`` when the text looks like XML but is not
                well-formed.
        """
        # Whitespace must go first: a trailing newline after the closing fence
        # would otherwise shield the backticks from ``strip("`")``.
        text = text.strip().strip("`")
        # Remove the fence's language tag as a *prefix*. ``str.strip("xml")``
        # treats its argument as a character set and would eat any run of
        # 'x', 'm' or 'l' from either end of the text.
        if text.startswith("xml"):
            text = text[len("xml") :]
        encoding_match = self.encoding_matcher.search(text)
        if encoding_match:
            # Keep only the document that follows the declaration.
            text = encoding_match.group(2)
        text = text.strip()
        if not (text.startswith("<") and text.endswith(">")):
            raise ValueError(f"Could not parse output: {text}")
        root = ET.fromstring(text)
        return self._root_to_dict(root)

    def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
        """Convert an element tree to a Python dictionary.

        Only child elements are collected; the text of ``root`` itself is not
        included, so an element without children maps to an empty list.
        """
        children: List[Any] = [
            # ``len(element)`` is the number of direct sub-elements.
            {child.tag: child.text} if len(child) == 0 else self._root_to_dict(child)
            for child in root
        ]
        return {root.tag: children}

    @property
    def _type(self) -> str:
        """Identifier string for this parser type."""
        return "xml"
|
@@ -0,0 +1,44 @@
|
||||
"""Test XMLOutputParser"""
|
||||
import pytest
|
||||
|
||||
from langchain.output_parsers.xml import XMLOutputParser
|
||||
|
||||
# Sample XML document that starts with an XML declaration carrying an
# ``encoding`` attribute; used as parser input in the tests below.
DEF_RESULT_ENCODING = """<?xml version="1.0" encoding="UTF-8"?>
<foo>
<bar>
<baz></baz>
<baz>slim.shady</baz>
</bar>
<baz>tag</baz>
</foo>"""

# Dictionary the parser is expected to produce for the document above: an
# empty element maps to ``None`` and nested elements become nested lists.
DEF_RESULT_EXPECTED = {
    "foo": [
        {"bar": [{"baz": None}, {"baz": "slim.shady"}]},
        {"baz": "tag"},
    ],
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "result",
    [
        # Full document, including the XML declaration line.
        DEF_RESULT_ENCODING,
        # Same document from its first newline onwards (declaration removed).
        "".join(DEF_RESULT_ENCODING.partition("\n")[1:]),
    ],
)
def test_xml_output_parser(result: str) -> None:
    """Check that well-formed XML is parsed into the expected dictionary."""
    parsed = XMLOutputParser().parse(result)
    assert parsed == DEF_RESULT_EXPECTED
|
||||
|
||||
|
||||
@pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"])
def test_xml_output_parser_fail(result: str) -> None:
    """Test XMLOutputParser where complete output is not in XML format."""

    xml_parser = XMLOutputParser()

    with pytest.raises(ValueError) as exc_info:
        xml_parser.parse(result)
    # ``exc_info`` is pytest's ExceptionInfo wrapper; the raised exception is
    # on ``.value``, so check the message there rather than in the wrapper's
    # own string form.
    assert "Could not parse output" in str(exc_info.value)
|
Reference in New Issue
Block a user