This commit is contained in:
Eugene Yurtsev 2025-07-03 18:28:43 -04:00
parent 6a5073b227
commit 5d6f03cc34
4 changed files with 177 additions and 22 deletions

View File

@ -1,21 +1,52 @@
from typing import Literal
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
def _escape(xml: str) -> str:
"""Replace XML tags with custom safe delimiters."""
replacements = {
"<tool>": "[[tool]]",
"</tool>": "[[/tool]]",
"<tool_input>": "[[tool_input]]",
"</tool_input>": "[[/tool_input]]",
"<observation>": "[[observation]]",
"</observation>": "[[/observation]]",
}
for orig, repl in replacements.items():
xml = xml.replace(orig, repl)
return xml
def format_xml(
    intermediate_steps: list[tuple[AgentAction, str]],
    *,
    escape_format: Literal["minimal"] | None = "minimal",
) -> str:
    """Format the intermediate steps as XML.

    Each step is rendered as
    ``<tool>…</tool><tool_input>…</tool_input><observation>…</observation>``
    and the rendered steps are concatenated with no separator.

    Args:
        intermediate_steps: The intermediate steps, as ``(action, observation)``
            pairs.
        escape_format: The escaping format to use. Currently only 'minimal' is
            supported, which replaces XML tags with custom delimiters to prevent
            conflicts. Any other value (including ``None``) leaves the tool
            name, tool input and observation unescaped.

    Returns:
        The intermediate steps as XML. An empty list yields an empty string.
    """
    escape = escape_format == "minimal"
    parts = []
    for action, observation in intermediate_steps:
        tool = action.tool
        tool_input = str(action.tool_input)
        observation_text = str(observation)
        if escape:
            # Escape XML tags in tool names, inputs and observations using
            # custom delimiters so they cannot be mistaken for the wrapper tags.
            tool = _escape(tool)
            tool_input = _escape(tool_input)
            observation_text = _escape(observation_text)
        parts.append(
            f"<tool>{tool}</tool><tool_input>{tool_input}"
            f"</tool_input><observation>{observation_text}</observation>"
        )
    return "".join(parts)

View File

@ -1,31 +1,68 @@
from typing import Union import re
from typing import Literal, Union
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from pydantic import Field
from langchain.agents import AgentOutputParser from langchain.agents import AgentOutputParser
def _unescape(text: str) -> str:
"""Convert custom tag delimiters back into XML tags."""
replacements = {
"[[tool]]": "<tool>",
"[[/tool]]": "</tool>",
"[[tool_input]]": "<tool_input>",
"[[/tool_input]]": "</tool_input>",
"[[observation]]": "<observation>",
"[[/observation]]": "</observation>",
}
for repl, orig in replacements.items():
text = text.replace(repl, orig)
return text
class XMLAgentOutputParser(AgentOutputParser): class XMLAgentOutputParser(AgentOutputParser):
"""Parses tool invocations and final answers in XML format. """Parses tool invocations and final answers from XML-formatted agent output.
Expects output to be in one of two formats. This parser extracts structured information from XML tags to determine whether
an agent should perform a tool action or provide a final answer. It includes
built-in escaping support to safely handle tool names and inputs
containing XML special characters.
If the output signals that an action should be taken, Args:
should be in the below format. This will result in an AgentAction escape_format: The escaping format to use when parsing XML content.
being returned. Supports 'minimal' which uses custom delimiters like [[tool]] to replace
XML tags within content, preventing parsing conflicts.
Use 'minimal' if using a corresponding encoding format that uses
the _escape function when formatting the output (e.g., with format_xml).
``` Expected formats:
Tool invocation (returns AgentAction):
<tool>search</tool> <tool>search</tool>
<tool_input>what is 2 + 2</tool_input> <tool_input>what is 2 + 2</tool_input>
```
If the output signals that a final answer should be given, Final answer (returns AgentFinish):
should be in the below format. This will result in an AgentFinish <final_answer>The answer is 4</final_answer>
being returned.
``` Note:
<final_answer>Foo</final_answer> Minimal escaping allows tool names containing XML tags to be safely
``` represented. For example, a tool named "search<tool>nested</tool>" would be
escaped as "search[[tool]]nested[[/tool]]" in the XML and automatically
unescaped during parsing.
Raises:
ValueError: If the input doesn't match either expected XML format or
contains malformed XML structure.
"""
escape_format: Literal["minimal"] | None = Field(default="minimal")
"""The format to use for escaping XML characters.
minimal - uses custom delimiters to replace XML tags within content,
preventing parsing conflicts. This is the only supported format currently.
None - no escaping is applied, which may lead to parsing conflicts.
""" """
def parse(self, text: str) -> Union[AgentAction, AgentFinish]: def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
@ -35,14 +72,27 @@ class XMLAgentOutputParser(AgentOutputParser):
_tool_input = tool_input.split("<tool_input>")[1] _tool_input = tool_input.split("<tool_input>")[1]
if "</tool_input>" in _tool_input: if "</tool_input>" in _tool_input:
_tool_input = _tool_input.split("</tool_input>")[0] _tool_input = _tool_input.split("</tool_input>")[0]
# Unescape custom delimiters in tool name and input
if self.escape_format == "minimal":
_tool = _unescape(_tool)
_tool_input = _unescape(_tool_input)
return AgentAction(tool=_tool, tool_input=_tool_input, log=text) return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
elif "<final_answer>" in text: elif "<final_answer>" in text and "</final_answer>" in text:
_, answer = text.split("<final_answer>") matches = re.findall(r"<final_answer>(.*?)</final_answer>", text, re.DOTALL)
if "</final_answer>" in answer: if len(matches) != 1:
answer = answer.split("</final_answer>")[0] msg = (
"Malformed output: expected exactly one "
"<final_answer>...</final_answer> block."
)
raise ValueError(msg)
answer = matches[0]
return AgentFinish(return_values={"output": answer}, log=text) return AgentFinish(return_values={"output": answer}, log=text)
else: else:
raise ValueError msg = (
"Malformed output: expected either a tool invocation "
"or a final answer in XML format."
)
raise ValueError(msg)
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
raise NotImplementedError raise NotImplementedError

View File

@ -39,3 +39,42 @@ Observation2</observation>"""
def test_empty_list_agent_actions() -> None:
    """Test that formatting an empty list of steps yields an empty string."""
    result = format_xml([])
    assert result == ""
def test_xml_escaping_minimal() -> None:
    """Test that reserved XML tags are escaped with the minimal format.

    The tool name and the observation contain reserved tags and must be
    rewritten with custom delimiters; the tool input contains only an
    unreserved ``<input>`` tag and must pass through unchanged.
    """
    # Arrange
    agent_action = AgentAction(
        tool="search<tool>nested</tool>",
        tool_input="query<input>test</input>",
        log="",
    )
    observation = "Found <observation>result</observation>"
    intermediate_steps = [(agent_action, observation)]

    # Act
    result = format_xml(intermediate_steps, escape_format="minimal")

    # Assert - XML tags should be replaced with custom delimiters
    expected_result = (
        "<tool>search[[tool]]nested[[/tool]]</tool>"
        "<tool_input>query<input>test</input></tool_input>"
        "<observation>Found [[observation]]result[[/observation]]</observation>"
    )
    assert result == expected_result
def test_no_escaping() -> None:
    """Test that escaping can be disabled by passing ``escape_format=None``."""
    # Arrange
    agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="")
    observation = "Observation1"
    intermediate_steps = [(agent_action, observation)]

    # Act
    result = format_xml(intermediate_steps, escape_format=None)

    # Assert
    expected_result = (
        "<tool>Tool1</tool><tool_input>Input1</tool_input>"
        "<observation>Observation1</observation>"
    )
    assert result == expected_result

View File

@ -32,3 +32,38 @@ def test_finish() -> None:
output = parser.invoke(_input) output = parser.invoke(_input)
expected_output = AgentFinish(return_values={"output": "bar"}, log=_input) expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
assert output == expected_output assert output == expected_output
def test_malformed_xml_with_nested_tags() -> None:
    """Test that a tool name containing XML tags survives a format/parse round trip.

    ``format_xml`` escapes the tags with minimal escaping and the parser,
    configured with the same escape format, is expected to restore them.
    """
    from langchain.agents.format_scratchpad.xml import format_xml

    # Create an AgentAction with XML tags in the tool name
    action = AgentAction(tool="search<tool>nested</tool>", tool_input="query", log="")

    # The format_xml function should escape the XML tags using custom delimiters
    formatted_xml = format_xml([(action, "observation")])

    # Extract just the tool part for parsing
    tool_part = formatted_xml.split("<observation>")[0]  # Remove observation part

    # Now test that the parser can handle the escaped XML
    parser = XMLAgentOutputParser(escape_format="minimal")
    output = parser.invoke(tool_part)

    # The parser should unescape and extract the original tool name
    expected_output = AgentAction(
        tool="search<tool>nested</tool>", tool_input="query", log=tool_part
    )
    assert output == expected_output
def test_no_escaping() -> None:
    """Test parser with escaping disabled."""
    parser = XMLAgentOutputParser(escape_format=None)

    # Test with regular tool name (no XML tags)
    _input = """<tool>search</tool><tool_input>foo</tool_input>"""
    output = parser.invoke(_input)

    expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
    assert output == expected_output