Compare commits

...

7 Commits

Author SHA1 Message Date
Mason Daugherty
d34c1052d9 . 2025-07-06 23:10:42 -04:00
Mason Daugherty
f7dcc834a8 . 2025-07-06 23:09:45 -04:00
Mason Daugherty
2a2c9cd5ba refactor: streamline XML formatting and parsing functions for improved security and clarity 2025-07-06 23:08:17 -04:00
Mason Daugherty
87f00e84ad sort 2025-07-06 22:40:07 -04:00
Mason Daugherty
57fd8887d8 lint 2025-07-06 22:17:37 -04:00
Eugene Yurtsev
18824b5761 x 2025-07-03 18:38:36 -04:00
Eugene Yurtsev
5d6f03cc34 x 2025-07-03 18:34:25 -04:00
4 changed files with 225 additions and 87 deletions

View File

@@ -1,21 +1,41 @@
import xml.sax.saxutils
from langchain_core.agents import AgentAction
def format_xml(
    intermediate_steps: list[tuple[AgentAction, str]],
    *,
    escape_xml: bool = True,
) -> str:
    """Format the intermediate steps as XML.

    Each step is rendered as consecutive ``<tool>``, ``<tool_input>`` and
    ``<observation>`` elements; the rendered steps are concatenated without any
    separator.

    Args:
        intermediate_steps: The intermediate steps, each a tuple of
            (AgentAction, observation).
        escape_xml: If True, all XML special characters in the tool name,
            tool input, and observation will be escaped (e.g., ``<`` becomes
            ``&lt;``). If False, the values are inserted verbatim.

    Returns:
        A string of concatenated XML blocks representing the intermediate steps
        (an empty string when there are no steps).
    """
    # ``xml.sax.saxutils.escape`` only handles ``&``, ``<`` and ``>`` by default;
    # quotes are escaped through the extra entity table.
    entities = {"'": "&apos;", '"': "&quot;"}

    def _render(value: object) -> str:
        """Stringify ``value`` and escape it when escaping is enabled."""
        text = str(value)
        if escape_xml:
            return xml.sax.saxutils.escape(text, entities)
        return text

    # Every value goes through ``_render`` so no raw, unescaped copy of a step
    # is ever emitted next to its escaped form.
    return "".join(
        f"<tool>{_render(action.tool)}</tool>"
        f"<tool_input>{_render(action.tool_input)}</tool_input>"
        f"<observation>{_render(observation)}</observation>"
        for action, observation in intermediate_steps
    )

View File

@@ -1,48 +1,115 @@
import re
import xml.sax.saxutils
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish
from pydantic import Field
from langchain.agents import AgentOutputParser
class XMLAgentOutputParser(AgentOutputParser):
    """Parses tool invocations and final answers from XML-formatted agent output.

    Tag content is located with a non-greedy regular expression and, when
    ``unescape_xml`` is enabled, decoded with standard XML entity unescaping.
    It is designed to work with the corresponding ``format_xml`` function.

    Args:
        unescape_xml: If True, the parser will unescape XML special characters in the
            content of tags. This should be enabled if the agent's output was formatted
            with XML escaping.

            If False, the parser will return the raw content as is, which may include
            XML special characters like `<`, `>`, and `&`.

    Expected formats:
        Tool invocation (returns AgentAction):
            <tool>search</tool>
            <tool_input>what is 2 + 2</tool_input>

        Final answer (returns AgentFinish):
            <final_answer>The answer is 4</final_answer>

    The trailing ``</tool_input>`` or ``</final_answer>`` closing tag may be
    missing (this happens when the closing tag is used as a stop token); in that
    case everything after the single opening tag is taken as the content.

    Raises:
        ValueError: If the input doesn't match either expected XML format or
            contains malformed XML structure.
    """

    unescape_xml: bool = Field(default=True)
    """If True, the parser will unescape XML special characters in the content
    of tags. This should be enabled if the agent's output was formatted
    with XML escaping.

    If False, the parser will return the raw content as is,
    which may include XML special characters like `<`, `>`, and `&`.
    """

    def _extract_tag_content(
        self,
        tag: str,
        text: str,
        *,
        required: bool,
        allow_unclosed: bool = False,
    ) -> Union[str, None]:
        """Extract content from an XML tag, ensuring it appears at most once.

        Args:
            tag: The name of the XML tag (e.g., ``'tool'``).
            text: The text to parse.
            required: If True, a ValueError will be raised if the tag is not found.
            allow_unclosed: If True and no complete ``<tag>...</tag>`` block exists,
                accept a single opening tag with no closing tag anywhere in
                ``text`` and return everything after it.

        Returns:
            The content of the tag (unescaped when ``unescape_xml`` is True),
            or None if not found and not required.

        Raises:
            ValueError: If the tag appears more than once, or if it is required
                but not found.
        """
        opening, closing = f"<{tag}>", f"</{tag}>"
        # Escape the tag so it is always matched literally by the regex.
        pattern = f"{re.escape(opening)}(.*?){re.escape(closing)}"
        matches = re.findall(pattern, text, re.DOTALL)

        if len(matches) > 1:
            raise ValueError(
                f"Malformed XML: Found {len(matches)} <{tag}> blocks. Expected 0 or 1."
            )

        if matches:
            content = matches[0]
        elif allow_unclosed and text.count(opening) == 1 and closing not in text:
            # The closing tag was cut off (e.g. consumed as a stop token):
            # the remainder of the text is the content.
            content = text.split(opening, 1)[1]
        elif required:
            raise ValueError(f"Malformed XML: Missing required <{tag}> block.")
        else:
            return None

        if self.unescape_xml:
            # ``unescape`` handles ``&lt;``, ``&gt;`` and ``&amp;`` itself;
            # quotes need the extra entity table.
            entities = {"&apos;": "'", "&quot;": '"'}
            return xml.sax.saxutils.unescape(content, entities)
        return content

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Parse the given text into an AgentAction or AgentFinish object.

        Args:
            text: The raw agent output.

        Returns:
            An AgentAction when a complete ``<tool>`` block is present, otherwise
            an AgentFinish when a ``<final_answer>`` tag is present. ``text`` is
            stored unchanged as the ``log`` of the result.

        Raises:
            ValueError: If neither format is found or the XML is malformed.
        """
        # Check for a tool invocation
        if "<tool>" in text and "</tool>" in text:
            tool = self._extract_tag_content("tool", text, required=True)
            if tool is None:
                raise ValueError("Tool content should not be None when required=True")

            # Tool input is optional, and its closing tag may have been cut off.
            tool_input = (
                self._extract_tag_content(
                    "tool_input", text, required=False, allow_unclosed=True
                )
                or ""
            )
            return AgentAction(tool=tool, tool_input=tool_input, log=text)

        # Check for a final answer (the closing tag may have been cut off).
        if "<final_answer>" in text:
            answer = self._extract_tag_content(
                "final_answer", text, required=True, allow_unclosed=True
            )
            return AgentFinish(return_values={"output": answer}, log=text)

        # If neither format is found, raise an error
        raise ValueError(
            "Could not parse LLM output. Expected a tool invocation with <tool> "
            "and <tool_input> tags, or a final answer with <final_answer> tags."
        )

    def get_format_instructions(self) -> str:
        """Not implemented for this parser; always raises NotImplementedError."""
        raise NotImplementedError

View File

@@ -3,39 +3,32 @@ from langchain_core.agents import AgentAction
from langchain.agents.format_scratchpad.xml import format_xml
def test_format_single_action() -> None:
    """Tests formatting of a single agent action and observation."""
    agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
    observation = "Observation1"
    # Use the named observation rather than repeating the literal.
    intermediate_steps = [(agent_action, observation)]

    result = format_xml(intermediate_steps, escape_xml=False)

    expected = (
        "<tool>Tool1</tool><tool_input>Input1</tool_input>"
        "<observation>Observation1</observation>"
    )
    assert result == expected
def test_format_xml_escaping() -> None:
    """Tests that XML special characters in content are properly escaped."""
    agent_action = AgentAction(
        tool="<tool> & 'some_tool'", tool_input='<query> with "quotes"', log=""
    )
    observation = "Observed > 5 items"
    intermediate_steps = [(agent_action, observation)]

    result = format_xml(intermediate_steps, escape_xml=True)

    expected = (
        "<tool>&lt;tool&gt; &amp; &apos;some_tool&apos;</tool>"
        "<tool_input>&lt;query&gt; with &quot;quotes&quot;</tool_input>"
        "<observation>Observed &gt; 5 items</observation>"
    )
    assert result == expected


def test_multiple_agent_actions_observations() -> None:
    """Tests that several steps are concatenated in order with no separator."""
    agent_action1 = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
    agent_action2 = AgentAction(tool="Tool2", tool_input="Input2", log="Log2")
    observation1 = "Observation1"
    observation2 = "Observation2"
    intermediate_steps = [(agent_action1, observation1), (agent_action2, observation2)]

    result = format_xml(intermediate_steps)

    expected_result = (
        "<tool>Tool1</tool><tool_input>Input1"
        "</tool_input><observation>Observation1</observation><tool>"
        "Tool2</tool><tool_input>Input2</tool_input><observation>"
        "Observation2</observation>"
    )
    assert result == expected_result


def test_empty_list_agent_actions() -> None:
    """Tests that no intermediate steps produce an empty string."""
    result = format_xml([])
    assert result == ""

View File

@@ -1,34 +1,92 @@
import pytest
from langchain_core.agents import AgentAction, AgentFinish
from langchain.agents.format_scratchpad.xml import format_xml
from langchain.agents.output_parsers.xml import XMLAgentOutputParser
def test_parser_tool_usage() -> None:
    """Tests parsing a standard tool invocation."""
    parser = XMLAgentOutputParser(unescape_xml=False)

    _input = "<tool>search</tool><tool_input>foo</tool_input>"
    output = parser.invoke(_input)

    # The raw input is kept verbatim as the action's log.
    expected = AgentAction(tool="search", tool_input="foo", log=_input)
    assert output == expected
def test_parser_final_answer() -> None:
    """Check that a plain ``<final_answer>`` block parses into an AgentFinish."""
    raw = "<final_answer>bar</final_answer>"

    finish = XMLAgentOutputParser(unescape_xml=False).invoke(raw)

    # The answer text is returned under "output"; the raw input becomes the log.
    assert finish == AgentFinish(return_values={"output": "bar"}, log=raw)
def test_parser_with_escaped_content() -> None:
    """Check that standard XML entities inside tags are decoded by the parser."""
    escaped_tool = "<tool>&lt;tool&gt; &amp; &apos;some_tool&apos;</tool>"
    escaped_input = "<tool_input>&lt;query&gt; with &quot;quotes&quot;</tool_input>"
    raw = escaped_tool + escaped_input

    action = XMLAgentOutputParser(unescape_xml=True).invoke(raw)

    # Tag contents come back decoded; the log keeps the escaped input verbatim.
    assert action == AgentAction(
        tool="<tool> & 'some_tool'",
        tool_input='<query> with "quotes"',
        log=raw,
    )
def test_parser_final_answer_escaped() -> None:
    """Check that entities inside a final answer are decoded."""
    raw = "<final_answer>The answer is &gt; 42.</final_answer>"

    finish = XMLAgentOutputParser(unescape_xml=True).invoke(raw)

    # ``&gt;`` is decoded in the output value, while the log stays untouched.
    wanted = AgentFinish(return_values={"output": "The answer is > 42."}, log=raw)
    assert finish == wanted
def test_parser_error_on_multiple_tool_tags() -> None:
    """Tests that the parser raises an error for multiple tool tags."""
    parser = XMLAgentOutputParser()

    # Two complete <tool> blocks are ambiguous and must be rejected.
    _input = "<tool>tool1</tool><tool>tool2</tool><tool_input>input</tool_input>"
    with pytest.raises(ValueError, match="Found 2 <tool> blocks"):
        parser.invoke(_input)
def test_parser_error_on_missing_required_tag() -> None:
    """Tests that the parser raises an error if a required tag is missing."""
    parser = XMLAgentOutputParser()

    # A tool input without any <tool> or <final_answer> tag matches no format.
    _input = "<tool_input>some_input</tool_input>"
    with pytest.raises(ValueError, match="Could not parse LLM output"):
        parser.invoke(_input)
def test_integration_format_and_parse() -> None:
    """Round-trip check: output of ``format_xml`` is recovered by the parser."""
    original = AgentAction(tool="<special-tool>", tool_input="input with &", log="")

    # Format a single step with escaping enabled.
    rendered = format_xml([(original, "observation with >")], escape_xml=True)

    # Keep only the tool call, i.e. everything before the observation element.
    tool_call = rendered.partition("<observation>")[0]

    # Parsing the escaped tool call must give back the original tool and input.
    recovered = XMLAgentOutputParser().invoke(tool_call)

    assert recovered == AgentAction(
        tool="<special-tool>", tool_input="input with &", log=tool_call
    )