mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 04:50:37 +00:00
refactor: streamline XML formatting and parsing functions for improved security and clarity
This commit is contained in:
parent
87f00e84ad
commit
2a2c9cd5ba
@ -1,52 +1,41 @@
|
|||||||
from typing import Literal, Optional
|
import xml.sax.saxutils
|
||||||
|
|
||||||
from langchain_core.agents import AgentAction
|
from langchain_core.agents import AgentAction
|
||||||
|
|
||||||
|
|
||||||
def _escape(xml: str) -> str:
|
|
||||||
"""Replace XML tags with custom safe delimiters."""
|
|
||||||
replacements = {
|
|
||||||
"<tool>": "[[tool]]",
|
|
||||||
"</tool>": "[[/tool]]",
|
|
||||||
"<tool_input>": "[[tool_input]]",
|
|
||||||
"</tool_input>": "[[/tool_input]]",
|
|
||||||
"<observation>": "[[observation]]",
|
|
||||||
"</observation>": "[[/observation]]",
|
|
||||||
}
|
|
||||||
for orig, repl in replacements.items():
|
|
||||||
xml = xml.replace(orig, repl)
|
|
||||||
return xml
|
|
||||||
|
|
||||||
|
|
||||||
def format_xml(
|
def format_xml(
|
||||||
intermediate_steps: list[tuple[AgentAction, str]],
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
*,
|
*,
|
||||||
escape_format: Optional[Literal["minimal"]] = "minimal",
|
escape_xml: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Format the intermediate steps as XML.
|
"""Format the intermediate steps as XML.
|
||||||
|
|
||||||
|
Escapes all special XML characters in the content, preventing injection of malicious
|
||||||
|
or malformed XML.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
intermediate_steps: The intermediate steps.
|
intermediate_steps: The intermediate steps, each a tuple of
|
||||||
escape_format: The escaping format to use. Currently only 'minimal' is
|
(AgentAction, observation).
|
||||||
supported, which replaces XML tags with custom delimiters to prevent
|
escape_xml: If True, all XML special characters in the tool name,
|
||||||
conflicts.
|
tool input, and observation will be escaped (e.g., ``<`` becomes ``&lt;``).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The intermediate steps as XML.
|
A string of concatenated XML blocks representing the intermediate steps.
|
||||||
"""
|
"""
|
||||||
log = ""
|
log = ""
|
||||||
for action, observation in intermediate_steps:
|
for action, observation in intermediate_steps:
|
||||||
if escape_format == "minimal":
|
tool = str(action.tool)
|
||||||
# Escape XML tags in tool names and inputs using custom delimiters
|
tool_input = str(action.tool_input)
|
||||||
tool = _escape(action.tool)
|
observation_str = str(observation)
|
||||||
tool_input = _escape(str(action.tool_input))
|
|
||||||
observation = _escape(str(observation))
|
if escape_xml:
|
||||||
else:
|
entities = {"'": "&apos;", '"': "&quot;"}
|
||||||
tool = action.tool
|
tool = xml.sax.saxutils.escape(tool, entities)
|
||||||
tool_input = str(action.tool_input)
|
tool_input = xml.sax.saxutils.escape(tool_input, entities)
|
||||||
observation = str(observation)
|
observation_str = xml.sax.saxutils.escape(observation_str, entities)
|
||||||
|
|
||||||
log += (
|
log += (
|
||||||
f"<tool>{tool}</tool><tool_input>{tool_input}"
|
f"<tool>{tool}</tool><tool_input>{tool_input}"
|
||||||
f"</tool_input><observation>{observation}</observation>"
|
f"</tool_input><observation>{observation_str}</observation>"
|
||||||
)
|
)
|
||||||
return log
|
return log
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Literal, Optional, Union
|
import xml.sax.saxutils
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from langchain_core.agents import AgentAction, AgentFinish
|
from langchain_core.agents import AgentAction, AgentFinish
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -7,35 +8,20 @@ from pydantic import Field
|
|||||||
from langchain.agents import AgentOutputParser
|
from langchain.agents import AgentOutputParser
|
||||||
|
|
||||||
|
|
||||||
def _unescape(text: str) -> str:
|
|
||||||
"""Convert custom tag delimiters back into XML tags."""
|
|
||||||
replacements = {
|
|
||||||
"[[tool]]": "<tool>",
|
|
||||||
"[[/tool]]": "</tool>",
|
|
||||||
"[[tool_input]]": "<tool_input>",
|
|
||||||
"[[/tool_input]]": "</tool_input>",
|
|
||||||
"[[observation]]": "<observation>",
|
|
||||||
"[[/observation]]": "</observation>",
|
|
||||||
}
|
|
||||||
for repl, orig in replacements.items():
|
|
||||||
text = text.replace(repl, orig)
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
class XMLAgentOutputParser(AgentOutputParser):
|
class XMLAgentOutputParser(AgentOutputParser):
|
||||||
"""Parses tool invocations and final answers from XML-formatted agent output.
|
"""Parses tool invocations and final answers from XML-formatted agent output.
|
||||||
|
|
||||||
This parser extracts structured information from XML tags to determine whether
|
This parser is hardened against XML injection by using standard XML entity
|
||||||
an agent should perform a tool action or provide a final answer. It includes
|
decoding for content within tags. It is designed to work with the corresponding
|
||||||
built-in escaping support to safely handle tool names and inputs
|
``format_xml`` function.
|
||||||
containing XML special characters.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
escape_format: The escaping format to use when parsing XML content.
|
unescape_xml: If True, the parser will unescape XML special characters in the
|
||||||
Supports 'minimal' which uses custom delimiters like [[tool]] to replace
|
content of tags. This should be enabled if the agent's output was formatted
|
||||||
XML tags within content, preventing parsing conflicts.
|
with XML escaping.
|
||||||
Use 'minimal' if using a corresponding encoding format that uses
|
|
||||||
the _escape function when formatting the output (e.g., with format_xml).
|
If False, the parser will return the raw content as is, which may include
|
||||||
|
XML special characters like `<`, `>`, and `&`.
|
||||||
|
|
||||||
Expected formats:
|
Expected formats:
|
||||||
Tool invocation (returns AgentAction):
|
Tool invocation (returns AgentAction):
|
||||||
@ -45,75 +31,83 @@ class XMLAgentOutputParser(AgentOutputParser):
|
|||||||
Final answer (returns AgentFinish):
|
Final answer (returns AgentFinish):
|
||||||
<final_answer>The answer is 4</final_answer>
|
<final_answer>The answer is 4</final_answer>
|
||||||
|
|
||||||
Note:
|
|
||||||
Minimal escaping allows tool names containing XML tags to be safely
|
|
||||||
represented. For example, a tool named "search<tool>nested</tool>" would be
|
|
||||||
escaped as "search[[tool]]nested[[/tool]]" in the XML and automatically
|
|
||||||
unescaped during parsing.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the input doesn't match either expected XML format or
|
ValueError: If the input doesn't match either expected XML format or
|
||||||
contains malformed XML structure.
|
contains malformed XML structure.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
escape_format: Optional[Literal["minimal"]] = Field(default="minimal")
|
unescape_xml: bool = Field(default=True)
|
||||||
"""The format to use for escaping XML characters.
|
"""If True, the parser will unescape XML special characters in the content
|
||||||
|
of tags. This should be enabled if the agent's output was formatted
|
||||||
minimal - uses custom delimiters to replace XML tags within content,
|
with XML escaping.
|
||||||
preventing parsing conflicts. This is the only supported format currently.
|
|
||||||
|
If False, the parser will return the raw content as is,
|
||||||
None - no escaping is applied, which may lead to parsing conflicts.
|
which may include XML special characters like `<`, `>`, and `&`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _extract_tag_content(
|
||||||
|
self, tag: str, text: str, *, required: bool
|
||||||
|
) -> Union[str, None]:
|
||||||
|
"""
|
||||||
|
Extracts content from a specified XML tag, ensuring it appears at most once.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag: The name of the XML tag (e.g., ``'tool'``).
|
||||||
|
text: The text to parse.
|
||||||
|
required: If True, a ValueError will be raised if the tag is not found.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The unescaped content of the tag as a string, or None if not found
|
||||||
|
and not required.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the tag appears more than once, or if it is required
|
||||||
|
but not found.
|
||||||
|
"""
|
||||||
|
pattern = f"<{tag}>(.*?)</{tag}>"
|
||||||
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
|
|
||||||
|
if len(matches) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Malformed XML: Found {len(matches)} <{tag}> blocks. Expected 0 or 1."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not matches:
|
||||||
|
if required:
|
||||||
|
raise ValueError(f"Malformed XML: Missing required <{tag}> block.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = matches[0]
|
||||||
|
if self.unescape_xml:
|
||||||
|
entities = {"'": "'", """: '"'}
|
||||||
|
return xml.sax.saxutils.unescape(content, entities)
|
||||||
|
return content
|
||||||
|
|
||||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||||
# Check for tool invocation first
|
"""
|
||||||
tool_matches = re.findall(r"<tool>(.*?)</tool>", text, re.DOTALL)
|
Parses the given text into an AgentAction or AgentFinish object.
|
||||||
if tool_matches:
|
"""
|
||||||
if len(tool_matches) != 1:
|
# Check for a tool invocation
|
||||||
raise ValueError(
|
if "<tool>" in text and "</tool>" in text:
|
||||||
f"Malformed tool invocation: expected exactly one <tool> block, "
|
tool = self._extract_tag_content("tool", text, required=True)
|
||||||
f"but found {len(tool_matches)}."
|
# Tool input is optional
|
||||||
)
|
tool_input = (
|
||||||
_tool = tool_matches[0]
|
self._extract_tag_content("tool_input", text, required=False) or ""
|
||||||
|
|
||||||
# Match optional tool input
|
|
||||||
input_matches = re.findall(
|
|
||||||
r"<tool_input>(.*?)</tool_input>", text, re.DOTALL
|
|
||||||
)
|
)
|
||||||
if len(input_matches) > 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"Malformed tool invocation: expected at most one <tool_input> "
|
|
||||||
f"block, but found {len(input_matches)}."
|
|
||||||
)
|
|
||||||
_tool_input = input_matches[0] if input_matches else ""
|
|
||||||
|
|
||||||
# Unescape if minimal escape format is used
|
return AgentAction(tool=tool, tool_input=tool_input, log=text)
|
||||||
if self.escape_format == "minimal":
|
|
||||||
_tool = _unescape(_tool)
|
|
||||||
_tool_input = _unescape(_tool_input)
|
|
||||||
|
|
||||||
return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
|
# Check for a final answer
|
||||||
|
|
||||||
# Check for final answer
|
|
||||||
elif "<final_answer>" in text and "</final_answer>" in text:
|
elif "<final_answer>" in text and "</final_answer>" in text:
|
||||||
matches = re.findall(r"<final_answer>(.*?)</final_answer>", text, re.DOTALL)
|
answer = self._extract_tag_content("final_answer", text, required=True)
|
||||||
if len(matches) != 1:
|
|
||||||
msg = (
|
|
||||||
"Malformed output: expected exactly one "
|
|
||||||
"<final_answer>...</final_answer> block."
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
answer = matches[0]
|
|
||||||
# Unescape custom delimiters in final answer
|
|
||||||
if self.escape_format == "minimal":
|
|
||||||
answer = _unescape(answer)
|
|
||||||
return AgentFinish(return_values={"output": answer}, log=text)
|
return AgentFinish(return_values={"output": answer}, log=text)
|
||||||
|
|
||||||
|
# If neither format is found, raise an error
|
||||||
else:
|
else:
|
||||||
msg = (
|
raise ValueError(
|
||||||
"Malformed output: expected either a tool invocation "
|
"Could not parse LLM output. Expected a tool invocation with <tool> "
|
||||||
"or a final answer in XML format."
|
"and <tool_input> tags, or a final answer with <final_answer> tags."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
def get_format_instructions(self) -> str:
|
def get_format_instructions(self) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -3,78 +3,32 @@ from langchain_core.agents import AgentAction
|
|||||||
from langchain.agents.format_scratchpad.xml import format_xml
|
from langchain.agents.format_scratchpad.xml import format_xml
|
||||||
|
|
||||||
|
|
||||||
def test_single_agent_action_observation() -> None:
|
def test_format_single_action() -> None:
|
||||||
# Arrange
|
"""Tests formatting of a single agent action and observation."""
|
||||||
agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
|
agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
|
||||||
observation = "Observation1"
|
intermediate_steps = [(agent_action, "Observation1")]
|
||||||
intermediate_steps = [(agent_action, observation)]
|
|
||||||
|
|
||||||
# Act
|
result = format_xml(intermediate_steps, escape_xml=False)
|
||||||
result = format_xml(intermediate_steps)
|
expected = (
|
||||||
expected_result = """<tool>Tool1</tool><tool_input>Input1\
|
|
||||||
</tool_input><observation>Observation1</observation>"""
|
|
||||||
# Assert
|
|
||||||
assert result == expected_result
|
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_agent_actions_observations() -> None:
|
|
||||||
# Arrange
|
|
||||||
agent_action1 = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
|
|
||||||
agent_action2 = AgentAction(tool="Tool2", tool_input="Input2", log="Log2")
|
|
||||||
observation1 = "Observation1"
|
|
||||||
observation2 = "Observation2"
|
|
||||||
intermediate_steps = [(agent_action1, observation1), (agent_action2, observation2)]
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = format_xml(intermediate_steps)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
expected_result = """<tool>Tool1</tool><tool_input>Input1\
|
|
||||||
</tool_input><observation>Observation1</observation><tool>\
|
|
||||||
Tool2</tool><tool_input>Input2</tool_input><observation>\
|
|
||||||
Observation2</observation>"""
|
|
||||||
assert result == expected_result
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_list_agent_actions() -> None:
|
|
||||||
result = format_xml([])
|
|
||||||
assert result == ""
|
|
||||||
|
|
||||||
|
|
||||||
def test_xml_escaping_minimal() -> None:
|
|
||||||
"""Test that XML tags in tool names are escaped with minimal format."""
|
|
||||||
# Arrange
|
|
||||||
agent_action = AgentAction(
|
|
||||||
tool="search<tool>nested</tool>", tool_input="query<input>test</input>", log=""
|
|
||||||
)
|
|
||||||
observation = "Found <observation>result</observation>"
|
|
||||||
intermediate_steps = [(agent_action, observation)]
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = format_xml(intermediate_steps, escape_format="minimal")
|
|
||||||
|
|
||||||
# Assert - XML tags should be replaced with custom delimiters
|
|
||||||
expected_result = (
|
|
||||||
"<tool>search[[tool]]nested[[/tool]]</tool>"
|
|
||||||
"<tool_input>query<input>test</input></tool_input>"
|
|
||||||
"<observation>Found [[observation]]result[[/observation]]</observation>"
|
|
||||||
)
|
|
||||||
assert result == expected_result
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_escaping() -> None:
|
|
||||||
"""Test that escaping can be disabled."""
|
|
||||||
# Arrange
|
|
||||||
agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="")
|
|
||||||
observation = "Observation1"
|
|
||||||
intermediate_steps = [(agent_action, observation)]
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = format_xml(intermediate_steps, escape_format=None)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
expected_result = (
|
|
||||||
"<tool>Tool1</tool><tool_input>Input1</tool_input>"
|
"<tool>Tool1</tool><tool_input>Input1</tool_input>"
|
||||||
"<observation>Observation1</observation>"
|
"<observation>Observation1</observation>"
|
||||||
)
|
)
|
||||||
assert result == expected_result
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_xml_escaping() -> None:
|
||||||
|
"""Tests that XML special characters in content are properly escaped."""
|
||||||
|
agent_action = AgentAction(
|
||||||
|
tool="<tool> & 'some_tool'", tool_input='<query> with "quotes"', log=""
|
||||||
|
)
|
||||||
|
observation = "Observed > 5 items"
|
||||||
|
intermediate_steps = [(agent_action, observation)]
|
||||||
|
|
||||||
|
result = format_xml(intermediate_steps, escape_xml=True)
|
||||||
|
|
||||||
|
expected = (
|
||||||
|
"<tool>&lt;tool&gt; &amp; &apos;some_tool&apos;</tool>"
|
||||||
|
"<tool_input>&lt;query&gt; with &quot;quotes&quot;</tool_input>"
|
||||||
|
"<observation>Observed &gt; 5 items</observation>"
|
||||||
|
)
|
||||||
|
assert result == expected
|
||||||
|
@ -1,69 +1,92 @@
|
|||||||
|
import pytest
|
||||||
from langchain_core.agents import AgentAction, AgentFinish
|
from langchain_core.agents import AgentAction, AgentFinish
|
||||||
|
|
||||||
|
from langchain.agents.format_scratchpad.xml import format_xml
|
||||||
from langchain.agents.output_parsers.xml import XMLAgentOutputParser
|
from langchain.agents.output_parsers.xml import XMLAgentOutputParser
|
||||||
|
|
||||||
|
|
||||||
def test_tool_usage() -> None:
|
def test_parser_tool_usage() -> None:
|
||||||
|
"""Tests parsing a standard tool invocation."""
|
||||||
|
parser = XMLAgentOutputParser(unescape_xml=False)
|
||||||
|
_input = "<tool>search</tool><tool_input>foo</tool_input>"
|
||||||
|
output = parser.invoke(_input)
|
||||||
|
expected = AgentAction(tool="search", tool_input="foo", log=_input)
|
||||||
|
assert output == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_parser_final_answer() -> None:
|
||||||
|
"""Tests parsing a standard final answer."""
|
||||||
|
parser = XMLAgentOutputParser(unescape_xml=False)
|
||||||
|
_input = "<final_answer>bar</final_answer>"
|
||||||
|
output = parser.invoke(_input)
|
||||||
|
expected = AgentFinish(return_values={"output": "bar"}, log=_input)
|
||||||
|
assert output == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_parser_with_escaped_content() -> None:
|
||||||
|
"""Tests that the parser correctly unescapes standard XML entities."""
|
||||||
|
parser = XMLAgentOutputParser(unescape_xml=True)
|
||||||
|
_input = (
|
||||||
|
"<tool>&lt;tool&gt; &amp; &apos;some_tool&apos;</tool>"
|
||||||
|
"<tool_input>&lt;query&gt; with &quot;quotes&quot;</tool_input>"
|
||||||
|
)
|
||||||
|
|
||||||
|
output = parser.invoke(_input)
|
||||||
|
|
||||||
|
expected = AgentAction(
|
||||||
|
tool="<tool> & 'some_tool'",
|
||||||
|
tool_input='<query> with "quotes"',
|
||||||
|
log=_input,
|
||||||
|
)
|
||||||
|
assert output == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_parser_final_answer_escaped() -> None:
|
||||||
|
"""Tests parsing a final answer with escaped content."""
|
||||||
|
parser = XMLAgentOutputParser(unescape_xml=True)
|
||||||
|
_input = "<final_answer>The answer is &gt; 42.</final_answer>"
|
||||||
|
|
||||||
|
output = parser.invoke(_input)
|
||||||
|
|
||||||
|
expected = AgentFinish(return_values={"output": "The answer is > 42."}, log=_input)
|
||||||
|
assert output == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_parser_error_on_multiple_tool_tags() -> None:
|
||||||
|
"""Tests that the parser raises an error for multiple tool tags."""
|
||||||
parser = XMLAgentOutputParser()
|
parser = XMLAgentOutputParser()
|
||||||
# Test when final closing </tool_input> is included
|
_input = "<tool>tool1</tool><tool>tool2</tool><tool_input>input</tool_input>"
|
||||||
_input = """<tool>search</tool><tool_input>foo</tool_input>"""
|
|
||||||
output = parser.invoke(_input)
|
with pytest.raises(ValueError, match="Found 2 <tool> blocks"):
|
||||||
expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
|
parser.invoke(_input)
|
||||||
assert output == expected_output
|
|
||||||
# Test when final closing </tool_input> is NOT included
|
|
||||||
# This happens when it's used as a stop token
|
|
||||||
_input = """<tool>search</tool><tool_input>foo</tool_input>"""
|
|
||||||
output = parser.invoke(_input)
|
|
||||||
expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
|
|
||||||
assert output == expected_output
|
|
||||||
|
|
||||||
|
|
||||||
def test_finish() -> None:
|
def test_parser_error_on_missing_required_tag() -> None:
|
||||||
|
"""Tests that the parser raises an error if a required tag is missing."""
|
||||||
parser = XMLAgentOutputParser()
|
parser = XMLAgentOutputParser()
|
||||||
# Test when final closing <final_answer> is included
|
_input = "<tool_input>some_input</tool_input>"
|
||||||
_input = """<final_answer>bar</final_answer>"""
|
|
||||||
output = parser.invoke(_input)
|
|
||||||
expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
|
|
||||||
assert output == expected_output
|
|
||||||
|
|
||||||
# Test when final closing <final_answer> is NOT included
|
with pytest.raises(ValueError, match="Could not parse LLM output"):
|
||||||
# This happens when it's used as a stop token
|
parser.invoke(_input)
|
||||||
_input = """<final_answer>bar</final_answer>"""
|
|
||||||
output = parser.invoke(_input)
|
|
||||||
expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
|
|
||||||
assert output == expected_output
|
|
||||||
|
|
||||||
|
|
||||||
def test_malformed_xml_with_nested_tags() -> None:
|
def test_integration_format_and_parse() -> None:
|
||||||
"""Test handling of tool names with XML tags via format_xml minimal escaping."""
|
"""An integration test to ensure formatting and parsing work together."""
|
||||||
from langchain.agents.format_scratchpad.xml import format_xml
|
parser = XMLAgentOutputParser()
|
||||||
|
agent_action = AgentAction(tool="<special-tool>", tool_input="input with &", log="")
|
||||||
|
intermediate_steps = [(agent_action, "observation with >")]
|
||||||
|
|
||||||
# Create an AgentAction with XML tags in the tool name
|
# 1. Format the data, escaping the XML
|
||||||
action = AgentAction(tool="search<tool>nested</tool>", tool_input="query", log="")
|
formatted_xml = format_xml(intermediate_steps, escape_xml=True)
|
||||||
|
|
||||||
# The format_xml function should escape the XML tags using custom delimiters
|
# Extract the tool call part for parsing
|
||||||
formatted_xml = format_xml([(action, "observation")])
|
tool_part = formatted_xml.split("<observation>")[0]
|
||||||
|
|
||||||
# Extract just the tool part for parsing
|
# 2. Parse the formatted data
|
||||||
tool_part = formatted_xml.split("<observation>")[0] # Remove observation part
|
|
||||||
|
|
||||||
# Now test that the parser can handle the escaped XML
|
|
||||||
parser = XMLAgentOutputParser(escape_format="minimal")
|
|
||||||
output = parser.invoke(tool_part)
|
output = parser.invoke(tool_part)
|
||||||
|
|
||||||
# The parser should unescape and extract the original tool name
|
# 3. Assert that the original data is recovered
|
||||||
expected_output = AgentAction(
|
expected_action = AgentAction(
|
||||||
tool="search<tool>nested</tool>", tool_input="query", log=tool_part
|
tool="<special-tool>", tool_input="input with &", log=tool_part
|
||||||
)
|
)
|
||||||
assert output == expected_output
|
assert output == expected_action
|
||||||
|
|
||||||
|
|
||||||
def test_no_escaping() -> None:
|
|
||||||
"""Test parser with escaping disabled."""
|
|
||||||
parser = XMLAgentOutputParser(escape_format=None)
|
|
||||||
|
|
||||||
# Test with regular tool name (no XML tags)
|
|
||||||
_input = """<tool>search</tool><tool_input>foo</tool_input>"""
|
|
||||||
output = parser.invoke(_input)
|
|
||||||
expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
|
|
||||||
assert output == expected_output
|
|
||||||
|
Loading…
Reference in New Issue
Block a user