langchain[patch]: harden xml parser for xmloutput agent (#31859)

Harden the default implementation of the XML parser for the agent

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Eugene Yurtsev 2025-07-08 10:57:21 -04:00 committed by GitHub
parent 3f839d566a
commit 83d8be756a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 205 additions and 28 deletions

View File

@ -1,21 +1,52 @@
from typing import Literal, Optional
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
def _escape(xml: str) -> str:
"""Replace XML tags with custom safe delimiters."""
replacements = {
"<tool>": "[[tool]]",
"</tool>": "[[/tool]]",
"<tool_input>": "[[tool_input]]",
"</tool_input>": "[[/tool_input]]",
"<observation>": "[[observation]]",
"</observation>": "[[/observation]]",
}
for orig, repl in replacements.items():
xml = xml.replace(orig, repl)
return xml
def format_xml(
    intermediate_steps: list[tuple[AgentAction, str]],
    *,
    escape_format: Optional[Literal["minimal"]] = "minimal",
) -> str:
    """Format the intermediate steps as XML.

    Each step is rendered as a ``<tool>``, ``<tool_input>`` and
    ``<observation>`` element, and the rendered steps are concatenated with no
    separator between them.

    Args:
        intermediate_steps: The intermediate steps, as (action, observation)
            pairs.
        escape_format: The escaping format to use. Currently only 'minimal' is
            supported, which replaces XML tags with custom delimiters to prevent
            conflicts. Any other value (including None) disables escaping.

    Returns:
        The intermediate steps as XML; an empty string when there are no
        steps.
    """
    rendered_steps = []
    for action, observation in intermediate_steps:
        tool = action.tool
        tool_input = str(action.tool_input)
        observation_text = str(observation)
        if escape_format == "minimal":
            # Escape XML tags in tool names, inputs and observations using
            # custom delimiters so they cannot be mistaken for real tags.
            tool = _escape(tool)
            tool_input = _escape(tool_input)
            observation_text = _escape(observation_text)
        rendered_steps.append(
            f"<tool>{tool}</tool><tool_input>{tool_input}"
            f"</tool_input><observation>{observation_text}</observation>"
        )
    # Join once rather than growing a string with += on every iteration.
    return "".join(rendered_steps)

View File

@ -1,47 +1,119 @@
from typing import Union import re
from typing import Literal, Optional, Union
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from pydantic import Field
from langchain.agents import AgentOutputParser from langchain.agents import AgentOutputParser
def _unescape(text: str) -> str:
"""Convert custom tag delimiters back into XML tags."""
replacements = {
"[[tool]]": "<tool>",
"[[/tool]]": "</tool>",
"[[tool_input]]": "<tool_input>",
"[[/tool_input]]": "</tool_input>",
"[[observation]]": "<observation>",
"[[/observation]]": "</observation>",
}
for repl, orig in replacements.items():
text = text.replace(repl, orig)
return text
class XMLAgentOutputParser(AgentOutputParser):
    """Turn XML-tagged agent output into an ``AgentAction`` or ``AgentFinish``.

    The text is searched for a ``<tool>`` block (with an optional
    ``<tool_input>`` block) first and for a ``<final_answer>`` block second.
    When ``escape_format`` is 'minimal', the extracted values are passed through
    ``_unescape`` so that ``[[tag]]`` delimiters become XML tags again.

    Args:
        escape_format: The escaping format to use when parsing XML content.
            Supports 'minimal' which uses custom delimiters like [[tool]] to
            replace XML tags within content, preventing parsing conflicts.
            Use 'minimal' if using a corresponding encoding format that uses
            the _escape function when formatting the output (e.g., with
            format_xml).

    Expected formats:
        Tool invocation (returns AgentAction):
            <tool>search</tool>
            <tool_input>what is 2 + 2</tool_input>

        Final answer (returns AgentFinish):
            <final_answer>The answer is 4</final_answer>

    Note:
        Minimal escaping allows tool names containing XML tags to be safely
        represented. For example, a tool named "search<tool>nested</tool>"
        would be escaped as "search[[tool]]nested[[/tool]]" in the XML and
        automatically unescaped during parsing.

    Raises:
        ValueError: If the input doesn't match either expected XML format or
            contains malformed XML structure.
    """

    escape_format: Optional[Literal["minimal"]] = Field(default="minimal")
    """The format to use for escaping XML characters.

    minimal - uses custom delimiters to replace XML tags within content,
    preventing parsing conflicts. This is the only supported format currently.

    None - no escaping is applied, which may lead to parsing conflicts.
    """

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Extract a tool invocation or a final answer from ``text``.

        Args:
            text: The raw agent output.

        Returns:
            An ``AgentAction`` when a ``<tool>`` block is present, otherwise an
            ``AgentFinish`` built from the ``<final_answer>`` block. In both
            cases ``log`` is the unmodified input text.

        Raises:
            ValueError: If more than one ``<tool>`` or ``<tool_input>`` block is
                found, if there is not exactly one complete ``<final_answer>``
                block, or if neither format is present.
        """
        # A tool invocation takes precedence over a final answer.
        tool_blocks = re.findall(r"<tool>(.*?)</tool>", text, re.DOTALL)
        if len(tool_blocks) > 1:
            msg = (
                f"Malformed tool invocation: expected exactly one <tool> block, "
                f"but found {len(tool_blocks)}."
            )
            raise ValueError(msg)

        if tool_blocks:
            input_blocks = re.findall(
                r"<tool_input>(.*?)</tool_input>", text, re.DOTALL
            )
            if len(input_blocks) > 1:
                msg = (
                    f"Malformed tool invocation: expected at most one <tool_input> "
                    f"block, but found {len(input_blocks)}."
                )
                raise ValueError(msg)

            tool_name = tool_blocks[0]
            # The tool input is optional; a missing block means an empty input.
            tool_arg = input_blocks[0] if input_blocks else ""

            # Restore XML tags only when the minimal escape format is in use.
            if self.escape_format == "minimal":
                tool_name = _unescape(tool_name)
                tool_arg = _unescape(tool_arg)

            return AgentAction(tool=tool_name, tool_input=tool_arg, log=text)

        # No tool block: the text must carry both final-answer tags.
        if "<final_answer>" not in text or "</final_answer>" not in text:
            msg = (
                "Malformed output: expected either a tool invocation "
                "or a final answer in XML format."
            )
            raise ValueError(msg)

        answers = re.findall(r"<final_answer>(.*?)</final_answer>", text, re.DOTALL)
        if len(answers) != 1:
            msg = (
                "Malformed output: expected exactly one "
                "<final_answer>...</final_answer> block."
            )
            raise ValueError(msg)

        answer = answers[0]
        # Restore XML tags in the final answer for the minimal escape format.
        if self.escape_format == "minimal":
            answer = _unescape(answer)
        return AgentFinish(return_values={"output": answer}, log=text)

    def get_format_instructions(self) -> str:
        """Not implemented for this parser; always raises NotImplementedError."""
        raise NotImplementedError

View File

@ -39,3 +39,42 @@ Observation2</observation>"""
def test_empty_list_agent_actions() -> None:
    """Formatting an empty list of steps yields an empty string."""
    assert format_xml([]) == ""
def test_xml_escaping_minimal() -> None:
    """Check that 'minimal' escaping rewrites reserved tags with delimiters."""
    # Arrange
    steps = [
        (
            AgentAction(
                tool="search<tool>nested</tool>",
                tool_input="query<input>test</input>",
                log="",
            ),
            "Found <observation>result</observation>",
        )
    ]

    # Act
    formatted = format_xml(steps, escape_format="minimal")

    # Assert - the <tool> and <observation> tags inside the values are replaced
    # with custom delimiters, while the <input> tags are expected to pass
    # through unchanged.
    assert formatted == (
        "<tool>search[[tool]]nested[[/tool]]</tool>"
        "<tool_input>query<input>test</input></tool_input>"
        "<observation>Found [[observation]]result[[/observation]]</observation>"
    )
def test_no_escaping() -> None:
    """Check that passing ``escape_format=None`` formats steps without escaping."""
    # Arrange
    steps = [(AgentAction(tool="Tool1", tool_input="Input1", log=""), "Observation1")]

    # Act
    formatted = format_xml(steps, escape_format=None)

    # Assert
    assert formatted == (
        "<tool>Tool1</tool><tool_input>Input1</tool_input>"
        "<observation>Observation1</observation>"
    )

View File

@ -32,3 +32,38 @@ def test_finish() -> None:
output = parser.invoke(_input) output = parser.invoke(_input)
expected_output = AgentFinish(return_values={"output": "bar"}, log=_input) expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
assert output == expected_output assert output == expected_output
def test_malformed_xml_with_nested_tags() -> None:
    """Round-trip a tool name containing XML tags through format_xml and the parser."""
    from langchain.agents.format_scratchpad.xml import format_xml

    nested_tool_name = "search<tool>nested</tool>"
    step = (AgentAction(tool=nested_tool_name, tool_input="query", log=""), "observation")

    # format_xml escapes the tags in the tool name with custom delimiters
    # (its default escape format is 'minimal').
    scratchpad = format_xml([step])

    # Keep only the text before the observation element for parsing.
    invocation = scratchpad.partition("<observation>")[0]

    # The parser should unescape the delimiters and recover the original name.
    parsed = XMLAgentOutputParser(escape_format="minimal").invoke(invocation)

    assert parsed == AgentAction(
        tool=nested_tool_name, tool_input="query", log=invocation
    )
def test_no_escaping() -> None:
    """Check that the parser handles a plain tool call with escaping disabled."""
    plain_parser = XMLAgentOutputParser(escape_format=None)

    # A regular tool name with no XML tags inside it.
    raw_output = """<tool>search</tool><tool_input>foo</tool_input>"""

    parsed = plain_parser.invoke(raw_output)

    assert parsed == AgentAction(tool="search", tool_input="foo", log=raw_output)