refactor: streamline XML formatting and parsing functions for improved security and clarity

This commit is contained in:
Mason Daugherty 2025-07-06 23:08:17 -04:00
parent 87f00e84ad
commit 2a2c9cd5ba
No known key found for this signature in database
4 changed files with 193 additions and 233 deletions

View File

@ -1,52 +1,41 @@
from typing import Literal, Optional import xml.sax.saxutils
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
# Literal XML tag -> bracket delimiter. Built once at import time rather than
# on every call.
_TAG_TO_DELIMITER = {
    "<tool>": "[[tool]]",
    "</tool>": "[[/tool]]",
    "<tool_input>": "[[tool_input]]",
    "</tool_input>": "[[/tool_input]]",
    "<observation>": "[[observation]]",
    "</observation>": "[[/observation]]",
}


def _escape(xml: str) -> str:
    """Replace the scratchpad's XML tags with custom safe delimiters.

    Only the six exact tag spellings in ``_TAG_TO_DELIMITER`` are rewritten;
    any other markup (for example ``<input>``) is returned unchanged.

    Text that already contains a delimiter such as ``[[tool]]`` is also left
    unchanged, so after escaping it cannot be told apart from an escaped tag.

    Args:
        xml: The text to escape.

    Returns:
        The text with every listed tag replaced by its bracket delimiter.
    """
    for tag, delimiter in _TAG_TO_DELIMITER.items():
        xml = xml.replace(tag, delimiter)
    return xml
def format_xml(
    # Quoted forward reference: nothing is evaluated when the function is
    # defined, and the builtin generics need no ``typing`` import.
    intermediate_steps: "list[tuple[AgentAction, str]]",
    *,
    escape_xml: bool = True,
) -> str:
    """Format the intermediate steps as XML.

    With escaping enabled, all special XML characters in the content are
    escaped, preventing injection of malicious or malformed XML.

    Args:
        intermediate_steps: The intermediate steps, each a tuple of
            (AgentAction, observation).
        escape_xml: If True, all XML special characters in the tool name,
            tool input, and observation will be escaped (e.g., ``<`` becomes
            ``&lt;``). If False, the values are inserted as is.

    Returns:
        A string of concatenated XML blocks representing the intermediate
        steps; an empty string when there are no steps.
    """
    from xml.sax.saxutils import escape

    # ``escape`` only handles ``&``, ``<`` and ``>`` by default; quotes are
    # added explicitly.
    quote_entities = {"'": "&apos;", '"': "&quot;"}

    blocks = []
    for action, observation in intermediate_steps:
        values = [str(action.tool), str(action.tool_input), str(observation)]
        if escape_xml:
            values = [escape(value, quote_entities) for value in values]
        tool, tool_input, observation_str = values
        blocks.append(
            f"<tool>{tool}</tool><tool_input>{tool_input}"
            f"</tool_input><observation>{observation_str}</observation>"
        )
    # Join once instead of growing a string with ``+=`` inside the loop.
    return "".join(blocks)

View File

@ -1,5 +1,6 @@
import re import re
from typing import Literal, Optional, Union import xml.sax.saxutils
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from pydantic import Field from pydantic import Field
@ -7,35 +8,20 @@ from pydantic import Field
from langchain.agents import AgentOutputParser from langchain.agents import AgentOutputParser
# Bracket delimiter -> literal XML tag. Built once at import time rather than
# on every call.
_DELIMITER_TO_TAG = {
    "[[tool]]": "<tool>",
    "[[/tool]]": "</tool>",
    "[[tool_input]]": "<tool_input>",
    "[[/tool_input]]": "</tool_input>",
    "[[observation]]": "<observation>",
    "[[/observation]]": "</observation>",
}


def _unescape(text: str) -> str:
    """Convert custom tag delimiters back into XML tags.

    Only the six delimiters in ``_DELIMITER_TO_TAG`` are rewritten; any other
    double-bracket text (for example ``[[input]]``) is returned unchanged.

    Args:
        text: The text containing bracket delimiters.

    Returns:
        The text with every listed delimiter replaced by its XML tag.
    """
    for delimiter, tag in _DELIMITER_TO_TAG.items():
        text = text.replace(delimiter, tag)
    return text
class XMLAgentOutputParser(AgentOutputParser): class XMLAgentOutputParser(AgentOutputParser):
"""Parses tool invocations and final answers from XML-formatted agent output. """Parses tool invocations and final answers from XML-formatted agent output.
This parser extracts structured information from XML tags to determine whether This parser is hardened against XML injection by using standard XML entity
an agent should perform a tool action or provide a final answer. It includes decoding for content within tags. It is designed to work with the corresponding
built-in escaping support to safely handle tool names and inputs ``format_xml`` function.
containing XML special characters.
Args: Args:
escape_format: The escaping format to use when parsing XML content. unescape_xml: If True, the parser will unescape XML special characters in the
Supports 'minimal' which uses custom delimiters like [[tool]] to replace content of tags. This should be enabled if the agent's output was formatted
XML tags within content, preventing parsing conflicts. with XML escaping.
Use 'minimal' if using a corresponding encoding format that uses
the _escape function when formatting the output (e.g., with format_xml). If False, the parser will return the raw content as is, which may include
XML special characters like `<`, `>`, and `&`.
Expected formats: Expected formats:
Tool invocation (returns AgentAction): Tool invocation (returns AgentAction):
@ -45,75 +31,83 @@ class XMLAgentOutputParser(AgentOutputParser):
Final answer (returns AgentFinish): Final answer (returns AgentFinish):
<final_answer>The answer is 4</final_answer> <final_answer>The answer is 4</final_answer>
Note:
Minimal escaping allows tool names containing XML tags to be safely
represented. For example, a tool named "search<tool>nested</tool>" would be
escaped as "search[[tool]]nested[[/tool]]" in the XML and automatically
unescaped during parsing.
Raises: Raises:
ValueError: If the input doesn't match either expected XML format or ValueError: If the input doesn't match either expected XML format or
contains malformed XML structure. contains malformed XML structure.
""" """
escape_format: Optional[Literal["minimal"]] = Field(default="minimal") unescape_xml: bool = Field(default=True)
"""The format to use for escaping XML characters. """If True, the parser will unescape XML special characters in the content
of tags. This should be enabled if the agent's output was formatted
with XML escaping.
minimal - uses custom delimiters to replace XML tags within content, If False, the parser will return the raw content as is,
preventing parsing conflicts. This is the only supported format currently. which may include XML special characters like `<`, `>`, and `&`.
None - no escaping is applied, which may lead to parsing conflicts.
""" """
def _extract_tag_content(
    self, tag: str, text: str, *, required: bool
) -> Union[str, None]:
    """Extract content from a specified XML tag, ensuring it appears at most once.

    Args:
        tag: The name of the XML tag (e.g., ``'tool'``). It is matched
            literally.
        text: The text to parse.
        required: If True, a ValueError will be raised if the tag is not found.

    Returns:
        The content of the tag, XML-unescaped when ``self.unescape_xml`` is
        true and raw otherwise, or None if not found and not required.

    Raises:
        ValueError: If the tag appears more than once, or if it is required
            but not found.
    """
    from xml.sax.saxutils import unescape

    # Escape the name so a regex metacharacter in ``tag`` cannot widen the match.
    name = re.escape(tag)
    # DOTALL lets the content span lines; the lazy group stops at the first
    # closing tag, so repeated blocks are counted separately.
    matches = re.findall(f"<{name}>(.*?)</{name}>", text, re.DOTALL)

    if len(matches) > 1:
        raise ValueError(
            f"Malformed XML: Found {len(matches)} <{tag}> blocks. Expected 0 or 1."
        )
    if not matches:
        if required:
            raise ValueError(f"Malformed XML: Missing required <{tag}> block.")
        return None

    content = matches[0]
    if not self.unescape_xml:
        return content
    # ``unescape`` only handles ``&lt;``, ``&gt;`` and ``&amp;`` by default;
    # the quote entities are added explicitly.
    quote_entities = {"&apos;": "'", "&quot;": '"'}
    return unescape(content, quote_entities)
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
    """Parse the given text into an AgentAction or AgentFinish object.

    A tool invocation is checked first, so text containing both a ``<tool>``
    block and a ``<final_answer>`` block is treated as a tool invocation.

    Args:
        text: The text to parse.

    Returns:
        An AgentAction when both ``<tool>`` and ``</tool>`` occur in the text
        (the tool input defaults to ``""`` when no ``<tool_input>`` block is
        found), otherwise an AgentFinish when both ``<final_answer>`` and
        ``</final_answer>`` occur.

    Raises:
        ValueError: If neither pair of tags is present, or if tag extraction
            rejects the text.
    """
    # Tool invocation.
    if "<tool>" in text and "</tool>" in text:
        tool = self._extract_tag_content("tool", text, required=True)
        # Tool input is optional.
        tool_input = (
            self._extract_tag_content("tool_input", text, required=False) or ""
        )
        return AgentAction(tool=tool, tool_input=tool_input, log=text)

    # Final answer.
    if "<final_answer>" in text and "</final_answer>" in text:
        answer = self._extract_tag_content("final_answer", text, required=True)
        return AgentFinish(return_values={"output": answer}, log=text)

    raise ValueError(
        "Could not parse LLM output. Expected a tool invocation with <tool> "
        "and <tool_input> tags, or a final answer with <final_answer> tags."
    )
def get_format_instructions(self) -> str:
    """Format instructions are not provided by this parser.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

View File

@ -3,78 +3,32 @@ from langchain_core.agents import AgentAction
from langchain.agents.format_scratchpad.xml import format_xml from langchain.agents.format_scratchpad.xml import format_xml
def test_single_agent_action_observation() -> None: def test_format_single_action() -> None:
# Arrange """Tests formatting of a single agent action and observation."""
agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="Log1") agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
observation = "Observation1" intermediate_steps = [(agent_action, "Observation1")]
intermediate_steps = [(agent_action, observation)]
# Act result = format_xml(intermediate_steps, escape_xml=False)
result = format_xml(intermediate_steps) expected = (
expected_result = """<tool>Tool1</tool><tool_input>Input1\
</tool_input><observation>Observation1</observation>"""
# Assert
assert result == expected_result
def test_multiple_agent_actions_observations() -> None:
    """Two steps are rendered back to back, in order, with no separator."""
    intermediate_steps = [
        (AgentAction(tool="Tool1", tool_input="Input1", log="Log1"), "Observation1"),
        (AgentAction(tool="Tool2", tool_input="Input2", log="Log2"), "Observation2"),
    ]

    result = format_xml(intermediate_steps)

    # Adjacent literals instead of backslash continuations inside a
    # triple-quoted string, so re-indenting cannot change the expected value.
    expected_result = (
        "<tool>Tool1</tool><tool_input>Input1</tool_input>"
        "<observation>Observation1</observation>"
        "<tool>Tool2</tool><tool_input>Input2</tool_input>"
        "<observation>Observation2</observation>"
    )
    assert result == expected_result
def test_empty_list_agent_actions() -> None:
    """No intermediate steps produce an empty string."""
    assert format_xml([]) == ""
def test_xml_escaping_minimal() -> None:
    """Test that XML tags in tool names are escaped with minimal format.

    The expected value keeps ``<input>`` in the tool input as is, while the
    ``<tool>`` and ``<observation>`` tags become bracket delimiters.
    """
    agent_action = AgentAction(
        tool="search<tool>nested</tool>", tool_input="query<input>test</input>", log=""
    )
    intermediate_steps = [(agent_action, "Found <observation>result</observation>")]

    result = format_xml(intermediate_steps, escape_format="minimal")

    expected_result = (
        "<tool>search[[tool]]nested[[/tool]]</tool>"
        "<tool_input>query<input>test</input></tool_input>"
        "<observation>Found [[observation]]result[[/observation]]</observation>"
    )
    assert result == expected_result
def test_no_escaping() -> None:
"""Test that escaping can be disabled."""
# Arrange
agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="")
observation = "Observation1"
intermediate_steps = [(agent_action, observation)]
# Act
result = format_xml(intermediate_steps, escape_format=None)
# Assert
expected_result = (
"<tool>Tool1</tool><tool_input>Input1</tool_input>" "<tool>Tool1</tool><tool_input>Input1</tool_input>"
"<observation>Observation1</observation>" "<observation>Observation1</observation>"
) )
assert result == expected_result assert result == expected
def test_format_xml_escaping() -> None:
    """Tests that XML special characters in content are properly escaped.

    Covers ``<``, ``>``, ``&`` and both quote characters across the tool name,
    tool input and observation.
    """
    agent_action = AgentAction(
        tool="<tool> & 'some_tool'", tool_input='<query> with "quotes"', log=""
    )
    intermediate_steps = [(agent_action, "Observed > 5 items")]

    result = format_xml(intermediate_steps, escape_xml=True)

    expected = (
        "<tool>&lt;tool&gt; &amp; &apos;some_tool&apos;</tool>"
        "<tool_input>&lt;query&gt; with &quot;quotes&quot;</tool_input>"
        "<observation>Observed &gt; 5 items</observation>"
    )
    assert result == expected

View File

@ -1,69 +1,92 @@
import pytest
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from langchain.agents.format_scratchpad.xml import format_xml
from langchain.agents.output_parsers.xml import XMLAgentOutputParser from langchain.agents.output_parsers.xml import XMLAgentOutputParser
def test_tool_usage() -> None: def test_parser_tool_usage() -> None:
"""Tests parsing a standard tool invocation."""
parser = XMLAgentOutputParser(unescape_xml=False)
_input = "<tool>search</tool><tool_input>foo</tool_input>"
output = parser.invoke(_input)
expected = AgentAction(tool="search", tool_input="foo", log=_input)
assert output == expected
def test_parser_final_answer() -> None:
    """Tests parsing a standard final answer with unescaping disabled."""
    parser = XMLAgentOutputParser(unescape_xml=False)
    _input = "<final_answer>bar</final_answer>"

    output = parser.invoke(_input)

    assert output == AgentFinish(return_values={"output": "bar"}, log=_input)
def test_parser_with_escaped_content() -> None:
    """Tests that the parser correctly unescapes standard XML entities.

    The returned action carries decoded values, while ``log`` keeps the
    input text as given.
    """
    parser = XMLAgentOutputParser(unescape_xml=True)
    _input = (
        "<tool>&lt;tool&gt; &amp; &apos;some_tool&apos;</tool>"
        "<tool_input>&lt;query&gt; with &quot;quotes&quot;</tool_input>"
    )

    output = parser.invoke(_input)

    expected = AgentAction(
        tool="<tool> & 'some_tool'",
        tool_input='<query> with "quotes"',
        log=_input,
    )
    assert output == expected
def test_parser_final_answer_escaped() -> None:
    """Tests parsing a final answer with escaped content."""
    parser = XMLAgentOutputParser(unescape_xml=True)
    _input = "<final_answer>The answer is &gt; 42.</final_answer>"

    output = parser.invoke(_input)

    # ``&gt;`` is decoded in the returned value; ``log`` keeps the raw input.
    expected = AgentFinish(return_values={"output": "The answer is > 42."}, log=_input)
    assert output == expected
def test_parser_error_on_multiple_tool_tags() -> None:
"""Tests that the parser raises an error for multiple tool tags."""
parser = XMLAgentOutputParser() parser = XMLAgentOutputParser()
# Test when final closing </tool_input> is included _input = "<tool>tool1</tool><tool>tool2</tool><tool_input>input</tool_input>"
_input = """<tool>search</tool><tool_input>foo</tool_input>"""
output = parser.invoke(_input) with pytest.raises(ValueError, match="Found 2 <tool> blocks"):
expected_output = AgentAction(tool="search", tool_input="foo", log=_input) parser.invoke(_input)
assert output == expected_output
# Test when final closing </tool_input> is NOT included
# This happens when it's used as a stop token
_input = """<tool>search</tool><tool_input>foo</tool_input>"""
output = parser.invoke(_input)
expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
assert output == expected_output
def test_finish() -> None: def test_parser_error_on_missing_required_tag() -> None:
"""Tests that the parser raises an error if a required tag is missing."""
parser = XMLAgentOutputParser() parser = XMLAgentOutputParser()
# Test when final closing <final_answer> is included _input = "<tool_input>some_input</tool_input>"
_input = """<final_answer>bar</final_answer>"""
output = parser.invoke(_input)
expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
assert output == expected_output
# Test when final closing <final_answer> is NOT included with pytest.raises(ValueError, match="Could not parse LLM output"):
# This happens when it's used as a stop token parser.invoke(_input)
_input = """<final_answer>bar</final_answer>"""
output = parser.invoke(_input)
expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
assert output == expected_output
def test_malformed_xml_with_nested_tags() -> None: def test_integration_format_and_parse() -> None:
"""Test handling of tool names with XML tags via format_xml minimal escaping.""" """An integration test to ensure formatting and parsing work together."""
from langchain.agents.format_scratchpad.xml import format_xml parser = XMLAgentOutputParser()
agent_action = AgentAction(tool="<special-tool>", tool_input="input with &", log="")
intermediate_steps = [(agent_action, "observation with >")]
# Create an AgentAction with XML tags in the tool name # 1. Format the data, escaping the XML
action = AgentAction(tool="search<tool>nested</tool>", tool_input="query", log="") formatted_xml = format_xml(intermediate_steps, escape_xml=True)
# The format_xml function should escape the XML tags using custom delimiters # Extract the tool call part for parsing
formatted_xml = format_xml([(action, "observation")]) tool_part = formatted_xml.split("<observation>")[0]
# Extract just the tool part for parsing # 2. Parse the formatted data
tool_part = formatted_xml.split("<observation>")[0] # Remove observation part
# Now test that the parser can handle the escaped XML
parser = XMLAgentOutputParser(escape_format="minimal")
output = parser.invoke(tool_part) output = parser.invoke(tool_part)
# The parser should unescape and extract the original tool name # 3. Assert that the original data is recovered
expected_output = AgentAction( expected_action = AgentAction(
tool="search<tool>nested</tool>", tool_input="query", log=tool_part tool="<special-tool>", tool_input="input with &", log=tool_part
) )
assert output == expected_output assert output == expected_action
def test_no_escaping() -> None:
    """Test parser with escaping disabled."""
    parser = XMLAgentOutputParser(escape_format=None)

    # Regular tool name (no XML tags inside the content).
    _input = "<tool>search</tool><tool_input>foo</tool_input>"
    output = parser.invoke(_input)

    assert output == AgentAction(tool="search", tool_input="foo", log=_input)