diff --git a/libs/langchain/langchain/agents/format_scratchpad/xml.py b/libs/langchain/langchain/agents/format_scratchpad/xml.py
index 3fb73637790..77dcc648fec 100644
--- a/libs/langchain/langchain/agents/format_scratchpad/xml.py
+++ b/libs/langchain/langchain/agents/format_scratchpad/xml.py
@@ -1,52 +1,42 @@
-from typing import Literal, Optional
+import xml.sax.saxutils
+from typing import List, Tuple
from langchain_core.agents import AgentAction
-def _escape(xml: str) -> str:
- """Replace XML tags with custom safe delimiters."""
- replacements = {
- "<tool>": "[[tool]]",
- "</tool>": "[[/tool]]",
- "<tool_input>": "[[tool_input]]",
- "</tool_input>": "[[/tool_input]]",
- "<observation>": "[[observation]]",
- "</observation>": "[[/observation]]",
- }
- for orig, repl in replacements.items():
- xml = xml.replace(orig, repl)
- return xml
-
-
def format_xml(
- intermediate_steps: list[tuple[AgentAction, str]],
+ intermediate_steps: List[Tuple[AgentAction, str]],
*,
- escape_format: Optional[Literal["minimal"]] = "minimal",
+ escape_xml: bool = True,
) -> str:
"""Format the intermediate steps as XML.
+ When ``escape_xml`` is True, escapes all special XML characters in the content,
+ preventing injection of malicious or malformed XML.
+
Args:
- intermediate_steps: The intermediate steps.
- escape_format: The escaping format to use. Currently only 'minimal' is
- supported, which replaces XML tags with custom delimiters to prevent
- conflicts.
+ intermediate_steps: The intermediate steps, each a tuple of
+ (AgentAction, observation).
+ escape_xml: If True, all XML special characters in the tool name,
+ tool input, and observation will be escaped (e.g., ``<`` becomes ``&lt;``).
Returns:
- The intermediate steps as XML.
+ A string of concatenated XML blocks representing the intermediate steps.
"""
log = ""
for action, observation in intermediate_steps:
- if escape_format == "minimal":
- # Escape XML tags in tool names and inputs using custom delimiters
- tool = _escape(action.tool)
- tool_input = _escape(str(action.tool_input))
- observation = _escape(str(observation))
- else:
- tool = action.tool
- tool_input = str(action.tool_input)
- observation = str(observation)
+ tool = str(action.tool)
+ tool_input = str(action.tool_input)
+ observation_str = str(observation)
+
+ if escape_xml:
+ entities = {"'": "&apos;", '"': "&quot;"}
+ tool = xml.sax.saxutils.escape(tool, entities)
+ tool_input = xml.sax.saxutils.escape(tool_input, entities)
+ observation_str = xml.sax.saxutils.escape(observation_str, entities)
+
log += (
f"<tool>{tool}</tool><tool_input>{tool_input}</tool_input>"
- f"<observation>{observation}</observation>"
+ f"<observation>{observation_str}</observation>"
)
return log
diff --git a/libs/langchain/langchain/agents/output_parsers/xml.py b/libs/langchain/langchain/agents/output_parsers/xml.py
index bfc90f4589a..2f52ca6f182 100644
--- a/libs/langchain/langchain/agents/output_parsers/xml.py
+++ b/libs/langchain/langchain/agents/output_parsers/xml.py
@@ -1,5 +1,6 @@
import re
-from typing import Literal, Optional, Union
+import xml.sax.saxutils
+from typing import Union
from langchain_core.agents import AgentAction, AgentFinish
from pydantic import Field
@@ -7,35 +8,20 @@ from pydantic import Field
from langchain.agents import AgentOutputParser
-def _unescape(text: str) -> str:
- """Convert custom tag delimiters back into XML tags."""
- replacements = {
- "[[tool]]": "<tool>",
- "[[/tool]]": "</tool>",
- "[[tool_input]]": "<tool_input>",
- "[[/tool_input]]": "</tool_input>",
- "[[observation]]": "<observation>",
- "[[/observation]]": "</observation>",
- }
- for repl, orig in replacements.items():
- text = text.replace(repl, orig)
- return text
-
-
class XMLAgentOutputParser(AgentOutputParser):
"""Parses tool invocations and final answers from XML-formatted agent output.
- This parser extracts structured information from XML tags to determine whether
- an agent should perform a tool action or provide a final answer. It includes
- built-in escaping support to safely handle tool names and inputs
- containing XML special characters.
+ This parser is hardened against XML injection by using standard XML entity
+ decoding for content within tags. It is designed to work with the corresponding
+ ``format_xml`` function.
Args:
- escape_format: The escaping format to use when parsing XML content.
- Supports 'minimal' which uses custom delimiters like [[tool]] to replace
- XML tags within content, preventing parsing conflicts.
- Use 'minimal' if using a corresponding encoding format that uses
- the _escape function when formatting the output (e.g., with format_xml).
+ unescape_xml: If True, the parser will unescape XML special characters in the
+ content of tags. This should be enabled if the agent's output was formatted
+ with XML escaping.
+
+ If False, the parser will return the raw content as is, which may include
+ XML special characters like `<`, `>`, and `&`.
Expected formats:
Tool invocation (returns AgentAction):
@@ -45,75 +31,83 @@ class XMLAgentOutputParser(AgentOutputParser):
Final answer (returns AgentFinish):
<final_answer>The answer is 4</final_answer>
- Note:
- Minimal escaping allows tool names containing XML tags to be safely
- represented. For example, a tool named "search<tool>nested</tool>" would be
- escaped as "search[[tool]]nested[[/tool]]" in the XML and automatically
- unescaped during parsing.
-
Raises:
ValueError: If the input doesn't match either expected XML format or
contains malformed XML structure.
"""
- escape_format: Optional[Literal["minimal"]] = Field(default="minimal")
- """The format to use for escaping XML characters.
-
- minimal - uses custom delimiters to replace XML tags within content,
- preventing parsing conflicts. This is the only supported format currently.
-
- None - no escaping is applied, which may lead to parsing conflicts.
+ unescape_xml: bool = Field(default=True)
+ """If True, the parser will unescape XML special characters in the content
+ of tags. This should be enabled if the agent's output was formatted
+ with XML escaping.
+
+ If False, the parser will return the raw content as is,
+ which may include XML special characters like `<`, `>`, and `&`.
"""
+ def _extract_tag_content(
+ self, tag: str, text: str, *, required: bool
+ ) -> Union[str, None]:
+ """
+ Extracts content from a specified XML tag, ensuring it appears at most once.
+
+ Args:
+ tag: The name of the XML tag (e.g., ``'tool'``).
+ text: The text to parse.
+ required: If True, a ValueError will be raised if the tag is not found.
+
+ Returns:
+ The tag content as a string (entity-unescaped when ``unescape_xml`` is
+ True), or None if the tag is not found and not required.
+
+ Raises:
+ ValueError: If the tag appears more than once, or if it is required
+ but not found.
+ """
+ pattern = f"<{tag}>(.*?)</{tag}>"
+ matches = re.findall(pattern, text, re.DOTALL)
+
+ if len(matches) > 1:
+ raise ValueError(
+ f"Malformed XML: Found {len(matches)} <{tag}> blocks. Expected 0 or 1."
+ )
+
+ if not matches:
+ if required:
+ raise ValueError(f"Malformed XML: Missing required <{tag}> block.")
+ return None
+
+ content = matches[0]
+ if self.unescape_xml:
+ entities = {"&apos;": "'", "&quot;": '"'}
+ return xml.sax.saxutils.unescape(content, entities)
+ return content
+
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
- # Check for tool invocation first
- tool_matches = re.findall(r"(.*?)", text, re.DOTALL)
- if tool_matches:
- if len(tool_matches) != 1:
- raise ValueError(
- f"Malformed tool invocation: expected exactly one block, "
- f"but found {len(tool_matches)}."
- )
- _tool = tool_matches[0]
-
- # Match optional tool input
- input_matches = re.findall(
- r"(.*?)", text, re.DOTALL
+ """
+ Parses the given text into an AgentAction or AgentFinish object.
+ """
+ # Check for a tool invocation
+ if "<tool>" in text and "</tool>" in text:
+ tool = self._extract_tag_content("tool", text, required=True)
+ # Tool input is optional
+ tool_input = (
+ self._extract_tag_content("tool_input", text, required=False) or ""
)
- if len(input_matches) > 1:
- raise ValueError(
- f"Malformed tool invocation: expected at most one "
- f"block, but found {len(input_matches)}."
- )
- _tool_input = input_matches[0] if input_matches else ""
- # Unescape if minimal escape format is used
- if self.escape_format == "minimal":
- _tool = _unescape(_tool)
- _tool_input = _unescape(_tool_input)
+ return AgentAction(tool=tool, tool_input=tool_input, log=text)
- return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
-
- # Check for final answer
+ # Check for a final answer
elif "<final_answer>" in text and "</final_answer>" in text:
- matches = re.findall(r"(.*?)", text, re.DOTALL)
- if len(matches) != 1:
- msg = (
- "Malformed output: expected exactly one "
- "... block."
- )
- raise ValueError(msg)
- answer = matches[0]
- # Unescape custom delimiters in final answer
- if self.escape_format == "minimal":
- answer = _unescape(answer)
+ answer = self._extract_tag_content("final_answer", text, required=True)
return AgentFinish(return_values={"output": answer}, log=text)
+
+ # If neither format is found, raise an error
else:
- msg = (
- "Malformed output: expected either a tool invocation "
- "or a final answer in XML format."
+ raise ValueError(
+ "Could not parse LLM output. Expected a tool invocation with <tool> "
+ "and <tool_input> tags, or a final answer with <final_answer> tags."
)
- raise ValueError(msg)
def get_format_instructions(self) -> str:
raise NotImplementedError
diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py
index 3fdaa791a9b..72ee21e20f0 100644
--- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py
+++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py
@@ -3,78 +3,32 @@ from langchain_core.agents import AgentAction
from langchain.agents.format_scratchpad.xml import format_xml
-def test_single_agent_action_observation() -> None:
- # Arrange
+def test_format_single_action() -> None:
+ """Tests formatting of a single agent action and observation."""
agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
- observation = "Observation1"
- intermediate_steps = [(agent_action, observation)]
+ intermediate_steps = [(agent_action, "Observation1")]
- # Act
- result = format_xml(intermediate_steps)
- expected_result = """Tool1Input1\
-Observation1"""
- # Assert
- assert result == expected_result
-
-
-def test_multiple_agent_actions_observations() -> None:
- # Arrange
- agent_action1 = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
- agent_action2 = AgentAction(tool="Tool2", tool_input="Input2", log="Log2")
- observation1 = "Observation1"
- observation2 = "Observation2"
- intermediate_steps = [(agent_action1, observation1), (agent_action2, observation2)]
-
- # Act
- result = format_xml(intermediate_steps)
-
- # Assert
- expected_result = """Tool1Input1\
-Observation1\
-Tool2Input2\
-Observation2"""
- assert result == expected_result
-
-
-def test_empty_list_agent_actions() -> None:
- result = format_xml([])
- assert result == ""
-
-
-def test_xml_escaping_minimal() -> None:
- """Test that XML tags in tool names are escaped with minimal format."""
- # Arrange
- agent_action = AgentAction(
- tool="searchnested", tool_input="querytest", log=""
- )
- observation = "Found result"
- intermediate_steps = [(agent_action, observation)]
-
- # Act
- result = format_xml(intermediate_steps, escape_format="minimal")
-
- # Assert - XML tags should be replaced with custom delimiters
- expected_result = (
- "search[[tool]]nested[[/tool]]"
- "querytest"
- "Found [[observation]]result[[/observation]]"
- )
- assert result == expected_result
-
-
-def test_no_escaping() -> None:
- """Test that escaping can be disabled."""
- # Arrange
- agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="")
- observation = "Observation1"
- intermediate_steps = [(agent_action, observation)]
-
- # Act
- result = format_xml(intermediate_steps, escape_format=None)
-
- # Assert
- expected_result = (
+ result = format_xml(intermediate_steps, escape_xml=False)
+ expected = (
"<tool>Tool1</tool><tool_input>Input1</tool_input>"
"<observation>Observation1</observation>"
)
- assert result == expected_result
+ assert result == expected
+
+
+def test_format_xml_escaping() -> None:
+ """Tests that XML special characters in content are properly escaped."""
+ agent_action = AgentAction(
+ tool="<tool> & 'some_tool'", tool_input='<query> with "quotes"', log=""
+ )
+ observation = "Observed > 5 items"
+ intermediate_steps = [(agent_action, observation)]
+
+ result = format_xml(intermediate_steps, escape_xml=True)
+
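+ expected = (
+ "<tool>&lt;tool&gt; &amp; &apos;some_tool&apos;</tool>"
+ "<tool_input>&lt;query&gt; with &quot;quotes&quot;</tool_input>"
+ "<observation>Observed &gt; 5 items</observation>"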
+ )
+ assert result == expected
diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py
index 74fff2fa869..7d551d01d5f 100644
--- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py
+++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py
@@ -1,69 +1,92 @@
+import pytest
from langchain_core.agents import AgentAction, AgentFinish
+from langchain.agents.format_scratchpad.xml import format_xml
from langchain.agents.output_parsers.xml import XMLAgentOutputParser
-def test_tool_usage() -> None:
+def test_parser_tool_usage() -> None:
+ """Tests parsing a standard tool invocation."""
+ parser = XMLAgentOutputParser(unescape_xml=False)
+ _input = "<tool>search</tool><tool_input>foo</tool_input>"
+ output = parser.invoke(_input)
+ expected = AgentAction(tool="search", tool_input="foo", log=_input)
+ assert output == expected
+
+
+def test_parser_final_answer() -> None:
+ """Tests parsing a standard final answer."""
+ parser = XMLAgentOutputParser(unescape_xml=False)
+ _input = "<final_answer>bar</final_answer>"
+ output = parser.invoke(_input)
+ expected = AgentFinish(return_values={"output": "bar"}, log=_input)
+ assert output == expected
+
+
+def test_parser_with_escaped_content() -> None:
+ """Tests that the parser correctly unescapes standard XML entities."""
+ parser = XMLAgentOutputParser(unescape_xml=True)
+ _input = (
+ "<tool>&lt;tool&gt; &amp; &apos;some_tool&apos;</tool>"
+ "<tool_input>&lt;query&gt; with &quot;quotes&quot;</tool_input>"
+ )
+
+ output = parser.invoke(_input)
+
+ expected = AgentAction(
+ tool="<tool> & 'some_tool'",
+ tool_input='<query> with "quotes"',
+ log=_input,
+ )
+ assert output == expected
+
+
+def test_parser_final_answer_escaped() -> None:
+ """Tests parsing a final answer with escaped content."""
+ parser = XMLAgentOutputParser(unescape_xml=True)
+ _input = "<final_answer>The answer is &gt; 42.</final_answer>"
+
+ output = parser.invoke(_input)
+
+ expected = AgentFinish(return_values={"output": "The answer is > 42."}, log=_input)
+ assert output == expected
+
+
+def test_parser_error_on_multiple_tool_tags() -> None:
+ """Tests that the parser raises an error for multiple tool tags."""
parser = XMLAgentOutputParser()
- # Test when final closing is included
- _input = """searchfoo"""
- output = parser.invoke(_input)
- expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
- assert output == expected_output
- # Test when final closing is NOT included
- # This happens when it's used as a stop token
- _input = """searchfoo"""
- output = parser.invoke(_input)
- expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
- assert output == expected_output
+ _input = "<tool>tool1</tool><tool>tool2</tool><tool_input>input</tool_input>"
+
+ with pytest.raises(ValueError, match="Found 2 <tool> blocks"):
+ parser.invoke(_input)
-def test_finish() -> None:
+def test_parser_error_on_missing_required_tag() -> None:
+ """Tests that the parser raises an error if a required tag is missing."""
parser = XMLAgentOutputParser()
- # Test when final closing is included
- _input = """bar"""
- output = parser.invoke(_input)
- expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
- assert output == expected_output
+ _input = "some_input"
- # Test when final closing is NOT included
- # This happens when it's used as a stop token
- _input = """bar"""
- output = parser.invoke(_input)
- expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
- assert output == expected_output
+ with pytest.raises(ValueError, match="Could not parse LLM output"):
+ parser.invoke(_input)
-def test_malformed_xml_with_nested_tags() -> None:
- """Test handling of tool names with XML tags via format_xml minimal escaping."""
- from langchain.agents.format_scratchpad.xml import format_xml
+def test_integration_format_and_parse() -> None:
+ """An integration test to ensure formatting and parsing work together."""
+ parser = XMLAgentOutputParser()
+ agent_action = AgentAction(tool="", tool_input="input with &", log="")
+ intermediate_steps = [(agent_action, "observation with >")]
- # Create an AgentAction with XML tags in the tool name
- action = AgentAction(tool="searchnested", tool_input="query", log="")
+ # 1. Format the data, escaping the XML
+ formatted_xml = format_xml(intermediate_steps, escape_xml=True)
- # The format_xml function should escape the XML tags using custom delimiters
- formatted_xml = format_xml([(action, "observation")])
+ # Extract the tool call part for parsing
+ tool_part = formatted_xml.split("<observation>")[0]
- # Extract just the tool part for parsing
- tool_part = formatted_xml.split("")[0] # Remove observation part
-
- # Now test that the parser can handle the escaped XML
- parser = XMLAgentOutputParser(escape_format="minimal")
+ # 2. Parse the formatted data
output = parser.invoke(tool_part)
- # The parser should unescape and extract the original tool name
- expected_output = AgentAction(
- tool="searchnested", tool_input="query", log=tool_part
+ # 3. Assert that the original data is recovered
+ expected_action = AgentAction(
+ tool="", tool_input="input with &", log=tool_part
)
- assert output == expected_output
-
-
-def test_no_escaping() -> None:
- """Test parser with escaping disabled."""
- parser = XMLAgentOutputParser(escape_format=None)
-
- # Test with regular tool name (no XML tags)
- _input = """searchfoo"""
- output = parser.invoke(_input)
- expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
- assert output == expected_output
+ assert output == expected_action