refactor: streamline XML formatting and parsing functions for improved security and clarity

This commit is contained in:
Mason Daugherty 2025-07-06 23:08:17 -04:00
parent 87f00e84ad
commit 2a2c9cd5ba
No known key found for this signature in database
4 changed files with 193 additions and 233 deletions

View File

@ -1,52 +1,41 @@
from typing import Literal, Optional
import xml.sax.saxutils
from langchain_core.agents import AgentAction
def _escape(xml: str) -> str:
"""Replace XML tags with custom safe delimiters."""
replacements = {
"<tool>": "[[tool]]",
"</tool>": "[[/tool]]",
"<tool_input>": "[[tool_input]]",
"</tool_input>": "[[/tool_input]]",
"<observation>": "[[observation]]",
"</observation>": "[[/observation]]",
}
for orig, repl in replacements.items():
xml = xml.replace(orig, repl)
return xml
def format_xml(
    intermediate_steps: list[tuple[AgentAction, str]],
    *,
    escape_xml: bool = True,
) -> str:
    """Format the intermediate steps as a string of concatenated XML blocks.

    Each step is rendered as
    ``<tool>..</tool><tool_input>..</tool_input><observation>..</observation>``
    and the blocks are joined without any separator.

    Args:
        intermediate_steps: The intermediate steps, each a tuple of
            (AgentAction, observation).
        escape_xml: If True, the XML special characters ``&``, ``<``, ``>``,
            ``'`` and ``"`` in the tool name, tool input, and observation are
            replaced by their entities (e.g., ``<`` becomes ``&lt;``). If
            False, the values are inserted verbatim.

    Returns:
        The formatted steps, or an empty string when there are no steps.
    """
    # saxutils.escape only covers &, < and > by default; quotes are added here.
    quote_entities = {"'": "&apos;", '"': "&quot;"}
    blocks = []
    for action, observation in intermediate_steps:
        fields = [str(action.tool), str(action.tool_input), str(observation)]
        if escape_xml:
            fields = [
                xml.sax.saxutils.escape(field, quote_entities) for field in fields
            ]
        tool, tool_input, observation_str = fields
        blocks.append(
            f"<tool>{tool}</tool><tool_input>{tool_input}"
            f"</tool_input><observation>{observation_str}</observation>"
        )
    return "".join(blocks)

View File

@ -1,5 +1,6 @@
import re
from typing import Literal, Optional, Union
import xml.sax.saxutils
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish
from pydantic import Field
@ -7,35 +8,20 @@ from pydantic import Field
from langchain.agents import AgentOutputParser
def _unescape(text: str) -> str:
"""Convert custom tag delimiters back into XML tags."""
replacements = {
"[[tool]]": "<tool>",
"[[/tool]]": "</tool>",
"[[tool_input]]": "<tool_input>",
"[[/tool_input]]": "</tool_input>",
"[[observation]]": "<observation>",
"[[/observation]]": "</observation>",
}
for repl, orig in replacements.items():
text = text.replace(repl, orig)
return text
class XMLAgentOutputParser(AgentOutputParser):
"""Parses tool invocations and final answers from XML-formatted agent output.
This parser extracts structured information from XML tags to determine whether
an agent should perform a tool action or provide a final answer. It includes
built-in escaping support to safely handle tool names and inputs
containing XML special characters.
This parser is hardened against XML injection by using standard XML entity
decoding for content within tags. It is designed to work with the corresponding
``format_xml`` function.
Args:
escape_format: The escaping format to use when parsing XML content.
Supports 'minimal' which uses custom delimiters like [[tool]] to replace
XML tags within content, preventing parsing conflicts.
Use 'minimal' if using a corresponding encoding format that uses
the _escape function when formatting the output (e.g., with format_xml).
unescape_xml: If True, the parser will unescape XML special characters in the
content of tags. This should be enabled if the agent's output was formatted
with XML escaping.
If False, the parser will return the raw content as is, which may include
XML special characters like `<`, `>`, and `&`.
Expected formats:
Tool invocation (returns AgentAction):
@ -45,75 +31,83 @@ class XMLAgentOutputParser(AgentOutputParser):
Final answer (returns AgentFinish):
<final_answer>The answer is 4</final_answer>
Note:
Minimal escaping allows tool names containing XML tags to be safely
represented. For example, a tool named "search<tool>nested</tool>" would be
escaped as "search[[tool]]nested[[/tool]]" in the XML and automatically
unescaped during parsing.
Raises:
ValueError: If the input doesn't match either expected XML format or
contains malformed XML structure.
"""
escape_format: Optional[Literal["minimal"]] = Field(default="minimal")
"""The format to use for escaping XML characters.
unescape_xml: bool = Field(default=True)
"""If True, the parser will unescape XML special characters in the content
of tags. This should be enabled if the agent's output was formatted
with XML escaping.
minimal - uses custom delimiters to replace XML tags within content,
preventing parsing conflicts. This is the only supported format currently.
None - no escaping is applied, which may lead to parsing conflicts.
If False, the parser will return the raw content as is,
which may include XML special characters like `<`, `>`, and `&`.
"""
def _extract_tag_content(
self, tag: str, text: str, *, required: bool
) -> Union[str, None]:
"""
Extracts content from a specified XML tag, ensuring it appears at most once.
Args:
tag: The name of the XML tag (e.g., ``'tool'``).
text: The text to parse.
required: If True, a ValueError will be raised if the tag is not found.
Returns:
The unescaped content of the tag as a string, or None if not found
and not required.
Raises:
ValueError: If the tag appears more than once, or if it is required
but not found.
"""
pattern = f"<{tag}>(.*?)</{tag}>"
matches = re.findall(pattern, text, re.DOTALL)
if len(matches) > 1:
raise ValueError(
f"Malformed XML: Found {len(matches)} <{tag}> blocks. Expected 0 or 1."
)
if not matches:
if required:
raise ValueError(f"Malformed XML: Missing required <{tag}> block.")
return None
content = matches[0]
if self.unescape_xml:
entities = {"&apos;": "'", "&quot;": '"'}
return xml.sax.saxutils.unescape(content, entities)
return content
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
    """Parse the given text into an AgentAction or AgentFinish object.

    A tool invocation is checked first: when both ``<tool>`` and ``</tool>``
    occur in the text, the tool name and the optional tool input are
    extracted. Otherwise, when both ``<final_answer>`` and
    ``</final_answer>`` occur, the final answer is extracted.

    Args:
        text: The raw output to parse.

    Returns:
        An AgentAction for a tool invocation (with an empty ``tool_input``
        when no ``<tool_input>`` block is present), or an AgentFinish whose
        ``return_values["output"]`` holds the final answer. In both cases
        ``log`` is the unmodified input text.

    Raises:
        ValueError: If neither format is found, or if tag extraction fails
            (a tag repeated, or a required tag not matched).
    """
    # Check for a tool invocation
    if "<tool>" in text and "</tool>" in text:
        tool = self._extract_tag_content("tool", text, required=True)
        # Tool input is optional
        tool_input = (
            self._extract_tag_content("tool_input", text, required=False) or ""
        )
        return AgentAction(tool=tool, tool_input=tool_input, log=text)

    # Check for a final answer
    if "<final_answer>" in text and "</final_answer>" in text:
        answer = self._extract_tag_content("final_answer", text, required=True)
        return AgentFinish(return_values={"output": answer}, log=text)

    # If neither format is found, raise an error
    raise ValueError(
        "Could not parse LLM output. Expected a tool invocation with <tool> "
        "and <tool_input> tags, or a final answer with <final_answer> tags."
    )
def get_format_instructions(self) -> str:
    """Format instructions are not provided by this parser.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

View File

@ -3,78 +3,32 @@ from langchain_core.agents import AgentAction
from langchain.agents.format_scratchpad.xml import format_xml
def test_single_agent_action_observation() -> None:
# Arrange
def test_format_single_action() -> None:
"""Tests formatting of a single agent action and observation."""
agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
observation = "Observation1"
intermediate_steps = [(agent_action, observation)]
intermediate_steps = [(agent_action, "Observation1")]
# Act
result = format_xml(intermediate_steps)
expected_result = """<tool>Tool1</tool><tool_input>Input1\
</tool_input><observation>Observation1</observation>"""
# Assert
assert result == expected_result
def test_multiple_agent_actions_observations() -> None:
# Arrange
agent_action1 = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
agent_action2 = AgentAction(tool="Tool2", tool_input="Input2", log="Log2")
observation1 = "Observation1"
observation2 = "Observation2"
intermediate_steps = [(agent_action1, observation1), (agent_action2, observation2)]
# Act
result = format_xml(intermediate_steps)
# Assert
expected_result = """<tool>Tool1</tool><tool_input>Input1\
</tool_input><observation>Observation1</observation><tool>\
Tool2</tool><tool_input>Input2</tool_input><observation>\
Observation2</observation>"""
assert result == expected_result
def test_empty_list_agent_actions() -> None:
result = format_xml([])
assert result == ""
def test_xml_escaping_minimal() -> None:
"""Test that XML tags in tool names are escaped with minimal format."""
# Arrange
agent_action = AgentAction(
tool="search<tool>nested</tool>", tool_input="query<input>test</input>", log=""
)
observation = "Found <observation>result</observation>"
intermediate_steps = [(agent_action, observation)]
# Act
result = format_xml(intermediate_steps, escape_format="minimal")
# Assert - XML tags should be replaced with custom delimiters
expected_result = (
"<tool>search[[tool]]nested[[/tool]]</tool>"
"<tool_input>query<input>test</input></tool_input>"
"<observation>Found [[observation]]result[[/observation]]</observation>"
)
assert result == expected_result
def test_no_escaping() -> None:
"""Test that escaping can be disabled."""
# Arrange
agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="")
observation = "Observation1"
intermediate_steps = [(agent_action, observation)]
# Act
result = format_xml(intermediate_steps, escape_format=None)
# Assert
expected_result = (
result = format_xml(intermediate_steps, escape_xml=False)
expected = (
"<tool>Tool1</tool><tool_input>Input1</tool_input>"
"<observation>Observation1</observation>"
)
assert result == expected_result
assert result == expected
def test_format_xml_escaping() -> None:
    """Special characters in every field are replaced by XML entities."""
    step = (
        AgentAction(
            tool="<tool> & 'some_tool'", tool_input='<query> with "quotes"', log=""
        ),
        "Observed > 5 items",
    )

    formatted = format_xml([step], escape_xml=True)

    assert formatted == (
        "<tool>&lt;tool&gt; &amp; &apos;some_tool&apos;</tool>"
        "<tool_input>&lt;query&gt; with &quot;quotes&quot;</tool_input>"
        "<observation>Observed &gt; 5 items</observation>"
    )

View File

@ -1,69 +1,92 @@
import pytest
from langchain_core.agents import AgentAction, AgentFinish
from langchain.agents.format_scratchpad.xml import format_xml
from langchain.agents.output_parsers.xml import XMLAgentOutputParser
def test_tool_usage() -> None:
def test_parser_tool_usage() -> None:
    """A plain tool / tool_input pair parses into the matching AgentAction."""
    raw = "<tool>search</tool><tool_input>foo</tool_input>"
    parsed = XMLAgentOutputParser(unescape_xml=False).invoke(raw)
    assert parsed == AgentAction(tool="search", tool_input="foo", log=raw)
def test_parser_final_answer() -> None:
    """A plain final_answer block parses into the matching AgentFinish."""
    raw = "<final_answer>bar</final_answer>"
    parsed = XMLAgentOutputParser(unescape_xml=False).invoke(raw)
    assert parsed == AgentFinish(return_values={"output": "bar"}, log=raw)
def test_parser_with_escaped_content() -> None:
    """Standard XML entities in tool and tool_input are decoded on parse."""
    raw = (
        "<tool>&lt;tool&gt; &amp; &apos;some_tool&apos;</tool>"
        "<tool_input>&lt;query&gt; with &quot;quotes&quot;</tool_input>"
    )

    parsed = XMLAgentOutputParser(unescape_xml=True).invoke(raw)

    assert parsed == AgentAction(
        tool="<tool> & 'some_tool'",
        tool_input='<query> with "quotes"',
        log=raw,
    )
def test_parser_final_answer_escaped() -> None:
    """Entities inside a final_answer block are decoded on parse."""
    raw = "<final_answer>The answer is &gt; 42.</final_answer>"
    parsed = XMLAgentOutputParser(unescape_xml=True).invoke(raw)
    wanted = AgentFinish(return_values={"output": "The answer is > 42."}, log=raw)
    assert parsed == wanted
def test_parser_error_on_multiple_tool_tags() -> None:
"""Tests that the parser raises an error for multiple tool tags."""
parser = XMLAgentOutputParser()
# Test when final closing </tool_input> is included
_input = """<tool>search</tool><tool_input>foo</tool_input>"""
output = parser.invoke(_input)
expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
assert output == expected_output
# Test when final closing </tool_input> is NOT included
# This happens when it's used as a stop token
_input = """<tool>search</tool><tool_input>foo</tool_input>"""
output = parser.invoke(_input)
expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
assert output == expected_output
_input = "<tool>tool1</tool><tool>tool2</tool><tool_input>input</tool_input>"
with pytest.raises(ValueError, match="Found 2 <tool> blocks"):
parser.invoke(_input)
def test_finish() -> None:
def test_parser_error_on_missing_required_tag() -> None:
"""Tests that the parser raises an error if a required tag is missing."""
parser = XMLAgentOutputParser()
# Test when final closing <final_answer> is included
_input = """<final_answer>bar</final_answer>"""
output = parser.invoke(_input)
expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
assert output == expected_output
_input = "<tool_input>some_input</tool_input>"
# Test when final closing <final_answer> is NOT included
# This happens when it's used as a stop token
_input = """<final_answer>bar</final_answer>"""
output = parser.invoke(_input)
expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
assert output == expected_output
with pytest.raises(ValueError, match="Could not parse LLM output"):
parser.invoke(_input)
def test_malformed_xml_with_nested_tags() -> None:
"""Test handling of tool names with XML tags via format_xml minimal escaping."""
from langchain.agents.format_scratchpad.xml import format_xml
def test_integration_format_and_parse() -> None:
"""An integration test to ensure formatting and parsing work together."""
parser = XMLAgentOutputParser()
agent_action = AgentAction(tool="<special-tool>", tool_input="input with &", log="")
intermediate_steps = [(agent_action, "observation with >")]
# Create an AgentAction with XML tags in the tool name
action = AgentAction(tool="search<tool>nested</tool>", tool_input="query", log="")
# 1. Format the data, escaping the XML
formatted_xml = format_xml(intermediate_steps, escape_xml=True)
# The format_xml function should escape the XML tags using custom delimiters
formatted_xml = format_xml([(action, "observation")])
# Extract the tool call part for parsing
tool_part = formatted_xml.split("<observation>")[0]
# Extract just the tool part for parsing
tool_part = formatted_xml.split("<observation>")[0] # Remove observation part
# Now test that the parser can handle the escaped XML
parser = XMLAgentOutputParser(escape_format="minimal")
# 2. Parse the formatted data
output = parser.invoke(tool_part)
# The parser should unescape and extract the original tool name
expected_output = AgentAction(
tool="search<tool>nested</tool>", tool_input="query", log=tool_part
# 3. Assert that the original data is recovered
expected_action = AgentAction(
tool="<special-tool>", tool_input="input with &", log=tool_part
)
assert output == expected_output
def test_no_escaping() -> None:
"""Test parser with escaping disabled."""
parser = XMLAgentOutputParser(escape_format=None)
# Test with regular tool name (no XML tags)
_input = """<tool>search</tool><tool_input>foo</tool_input>"""
output = parser.invoke(_input)
expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
assert output == expected_output
assert output == expected_action