diff --git a/libs/langchain/langchain/agents/format_scratchpad/xml.py b/libs/langchain/langchain/agents/format_scratchpad/xml.py index 3fb73637790..77dcc648fec 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/xml.py +++ b/libs/langchain/langchain/agents/format_scratchpad/xml.py @@ -1,52 +1,41 @@ -from typing import Literal, Optional +import xml.sax.saxutils from langchain_core.agents import AgentAction -def _escape(xml: str) -> str: - """Replace XML tags with custom safe delimiters.""" - replacements = { - "": "[[tool]]", - "": "[[/tool]]", - "": "[[tool_input]]", - "": "[[/tool_input]]", - "": "[[observation]]", - "": "[[/observation]]", - } - for orig, repl in replacements.items(): - xml = xml.replace(orig, repl) - return xml - - def format_xml( - intermediate_steps: list[tuple[AgentAction, str]], + intermediate_steps: List[Tuple[AgentAction, str]], *, - escape_format: Optional[Literal["minimal"]] = "minimal", + escape_xml: bool = True, ) -> str: """Format the intermediate steps as XML. + Escapes all special XML characters in the content, preventing injection of malicious + or malformed XML. + Args: - intermediate_steps: The intermediate steps. - escape_format: The escaping format to use. Currently only 'minimal' is - supported, which replaces XML tags with custom delimiters to prevent - conflicts. + intermediate_steps: The intermediate steps, each a tuple of + (AgentAction, observation). + escape_xml: If True, all XML special characters in the tool name, + tool input, and observation will be escaped (e.g., ``<`` becomes ``<``). Returns: - The intermediate steps as XML. + A string of concatenated XML blocks representing the intermediate steps. """ log = "" for action, observation in intermediate_steps: - if escape_format == "minimal": - # Escape XML tags in tool names and inputs using custom delimiters - tool = _escape(action.tool) - tool_input = _escape(str(action.tool_input)) - observation = _escape(str(observation)) - else: - tool = action.tool - tool_input = str(action.tool_input) - observation = str(observation) + tool = str(action.tool) + tool_input = str(action.tool_input) + observation_str = str(observation) + + if escape_xml: + entities = {"'": "'", '"': """} + tool = xml.sax.saxutils.escape(tool, entities) + tool_input = xml.sax.saxutils.escape(tool_input, entities) + observation_str = xml.sax.saxutils.escape(observation_str, entities) + log += ( f"{tool}{tool_input}" - f"{observation}" + f"{observation_str}" ) return log diff --git a/libs/langchain/langchain/agents/output_parsers/xml.py b/libs/langchain/langchain/agents/output_parsers/xml.py index bfc90f4589a..2f52ca6f182 100644 --- a/libs/langchain/langchain/agents/output_parsers/xml.py +++ b/libs/langchain/langchain/agents/output_parsers/xml.py @@ -1,5 +1,6 @@ import re -from typing import Literal, Optional, Union +import xml.sax.saxutils +from typing import Union from langchain_core.agents import AgentAction, AgentFinish from pydantic import Field @@ -7,35 +8,20 @@ from pydantic import Field from langchain.agents import AgentOutputParser -def _unescape(text: str) -> str: - """Convert custom tag delimiters back into XML tags.""" - replacements = { - "[[tool]]": "", - "[[/tool]]": "", - "[[tool_input]]": "", - "[[/tool_input]]": "", - "[[observation]]": "", - "[[/observation]]": "", - } - for repl, orig in replacements.items(): - text = text.replace(repl, orig) - return text - - class XMLAgentOutputParser(AgentOutputParser): """Parses tool invocations and final answers from XML-formatted agent output. - This parser extracts structured information from XML tags to determine whether - an agent should perform a tool action or provide a final answer. It includes - built-in escaping support to safely handle tool names and inputs - containing XML special characters. + This parser is hardened against XML injection by using standard XML entity + decoding for content within tags. It is designed to work with the corresponding + ``format_xml`` function. Args: - escape_format: The escaping format to use when parsing XML content. - Supports 'minimal' which uses custom delimiters like [[tool]] to replace - XML tags within content, preventing parsing conflicts. - Use 'minimal' if using a corresponding encoding format that uses - the _escape function when formatting the output (e.g., with format_xml). + unescape_xml: If True, the parser will unescape XML special characters in the + content of tags. This should be enabled if the agent's output was formatted + with XML escaping. + + If False, the parser will return the raw content as is, which may include + XML special characters like `<`, `>`, and `&`. Expected formats: Tool invocation (returns AgentAction): @@ -45,75 +31,83 @@ class XMLAgentOutputParser(AgentOutputParser): Final answer (returns AgentFinish): The answer is 4 - Note: - Minimal escaping allows tool names containing XML tags to be safely - represented. For example, a tool named "searchnested" would be - escaped as "search[[tool]]nested[[/tool]]" in the XML and automatically - unescaped during parsing. - Raises: ValueError: If the input doesn't match either expected XML format or contains malformed XML structure. """ - escape_format: Optional[Literal["minimal"]] = Field(default="minimal") - """The format to use for escaping XML characters. - - minimal - uses custom delimiters to replace XML tags within content, - preventing parsing conflicts. This is the only supported format currently. - - None - no escaping is applied, which may lead to parsing conflicts. + unescape_xml: bool = Field(default=True) + """If True, the parser will unescape XML special characters in the content + of tags. This should be enabled if the agent's output was formatted + with XML escaping. + + If False, the parser will return the raw content as is, + which may include XML special characters like `<`, `>`, and `&`. """ + def _extract_tag_content( + self, tag: str, text: str, *, required: bool + ) -> Union[str, None]: + """ + Extracts content from a specified XML tag, ensuring it appears at most once. + + Args: + tag: The name of the XML tag (e.g., ``'tool'``). + text: The text to parse. + required: If True, a ValueError will be raised if the tag is not found. + + Returns: + The unescaped content of the tag as a string, or None if not found + and not required. + + Raises: + ValueError: If the tag appears more than once, or if it is required + but not found. + """ + pattern = f"<{tag}>(.*?)" + matches = re.findall(pattern, text, re.DOTALL) + + if len(matches) > 1: + raise ValueError( + f"Malformed XML: Found {len(matches)} <{tag}> blocks. Expected 0 or 1." + ) + + if not matches: + if required: + raise ValueError(f"Malformed XML: Missing required <{tag}> block.") + return None + + content = matches[0] + if self.unescape_xml: + entities = {"'": "'", """: '"'} + return xml.sax.saxutils.unescape(content, entities) + return content + def parse(self, text: str) -> Union[AgentAction, AgentFinish]: - # Check for tool invocation first - tool_matches = re.findall(r"(.*?)", text, re.DOTALL) - if tool_matches: - if len(tool_matches) != 1: - raise ValueError( - f"Malformed tool invocation: expected exactly one block, " - f"but found {len(tool_matches)}." - ) - _tool = tool_matches[0] - - # Match optional tool input - input_matches = re.findall( - r"(.*?)", text, re.DOTALL + """ + Parses the given text into an AgentAction or AgentFinish object. + """ + # Check for a tool invocation + if "" in text and "" in text: + tool = self._extract_tag_content("tool", text, required=True) + # Tool input is optional + tool_input = ( + self._extract_tag_content("tool_input", text, required=False) or "" ) - if len(input_matches) > 1: - raise ValueError( - f"Malformed tool invocation: expected at most one " - f"block, but found {len(input_matches)}." - ) - _tool_input = input_matches[0] if input_matches else "" - # Unescape if minimal escape format is used - if self.escape_format == "minimal": - _tool = _unescape(_tool) - _tool_input = _unescape(_tool_input) + return AgentAction(tool=tool, tool_input=tool_input, log=text) - return AgentAction(tool=_tool, tool_input=_tool_input, log=text) - - # Check for final answer + # Check for a final answer elif "" in text and "" in text: - matches = re.findall(r"(.*?)", text, re.DOTALL) - if len(matches) != 1: - msg = ( - "Malformed output: expected exactly one " - "... block." - ) - raise ValueError(msg) - answer = matches[0] - # Unescape custom delimiters in final answer - if self.escape_format == "minimal": - answer = _unescape(answer) + answer = self._extract_tag_content("final_answer", text, required=True) return AgentFinish(return_values={"output": answer}, log=text) + + # If neither format is found, raise an error else: - msg = ( - "Malformed output: expected either a tool invocation " - "or a final answer in XML format." + raise ValueError( + "Could not parse LLM output. Expected a tool invocation with " + "and tags, or a final answer with tags." ) - raise ValueError(msg) def get_format_instructions(self) -> str: raise NotImplementedError diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py index 3fdaa791a9b..72ee21e20f0 100644 --- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py @@ -3,78 +3,32 @@ from langchain_core.agents import AgentAction from langchain.agents.format_scratchpad.xml import format_xml -def test_single_agent_action_observation() -> None: - # Arrange +def test_format_single_action() -> None: + """Tests formatting of a single agent action and observation.""" agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="Log1") - observation = "Observation1" - intermediate_steps = [(agent_action, observation)] + intermediate_steps = [(agent_action, "Observation1")] - # Act - result = format_xml(intermediate_steps) - expected_result = """Tool1Input1\ -Observation1""" - # Assert - assert result == expected_result - - -def test_multiple_agent_actions_observations() -> None: - # Arrange - agent_action1 = AgentAction(tool="Tool1", tool_input="Input1", log="Log1") - agent_action2 = AgentAction(tool="Tool2", tool_input="Input2", log="Log2") - observation1 = "Observation1" - observation2 = "Observation2" - intermediate_steps = [(agent_action1, observation1), (agent_action2, observation2)] - - # Act - result = format_xml(intermediate_steps) - - # Assert - expected_result = """Tool1Input1\ -Observation1\ -Tool2Input2\ -Observation2""" - assert result == expected_result - - -def test_empty_list_agent_actions() -> None: - result = format_xml([]) - assert result == "" - - -def test_xml_escaping_minimal() -> None: - """Test that XML tags in tool names are escaped with minimal format.""" - # Arrange - agent_action = AgentAction( - tool="searchnested", tool_input="querytest", log="" - ) - observation = "Found result" - intermediate_steps = [(agent_action, observation)] - - # Act - result = format_xml(intermediate_steps, escape_format="minimal") - - # Assert - XML tags should be replaced with custom delimiters - expected_result = ( - "search[[tool]]nested[[/tool]]" - "querytest" - "Found [[observation]]result[[/observation]]" - ) - assert result == expected_result - - -def test_no_escaping() -> None: - """Test that escaping can be disabled.""" - # Arrange - agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="") - observation = "Observation1" - intermediate_steps = [(agent_action, observation)] - - # Act - result = format_xml(intermediate_steps, escape_format=None) - - # Assert - expected_result = ( + result = format_xml(intermediate_steps, escape_xml=False) + expected = ( "Tool1Input1" "Observation1" ) - assert result == expected_result + assert result == expected + + +def test_format_xml_escaping() -> None: + """Tests that XML special characters in content are properly escaped.""" + agent_action = AgentAction( + tool=" & 'some_tool'", tool_input=' with "quotes"', log="" + ) + observation = "Observed > 5 items" + intermediate_steps = [(agent_action, observation)] + + result = format_xml(intermediate_steps, escape_xml=True) + + expected = ( + "<tool> & 'some_tool'" + "<query> with "quotes"" + "Observed > 5 items" + ) + assert result == expected diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py index 74fff2fa869..7d551d01d5f 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py @@ -1,69 +1,92 @@ +import pytest from langchain_core.agents import AgentAction, AgentFinish +from langchain.agents.format_scratchpad.xml import format_xml from langchain.agents.output_parsers.xml import XMLAgentOutputParser -def test_tool_usage() -> None: +def test_parser_tool_usage() -> None: + """Tests parsing a standard tool invocation.""" + parser = XMLAgentOutputParser(unescape_xml=False) + _input = "searchfoo" + output = parser.invoke(_input) + expected = AgentAction(tool="search", tool_input="foo", log=_input) + assert output == expected + + +def test_parser_final_answer() -> None: + """Tests parsing a standard final answer.""" + parser = XMLAgentOutputParser(unescape_xml=False) + _input = "bar" + output = parser.invoke(_input) + expected = AgentFinish(return_values={"output": "bar"}, log=_input) + assert output == expected + + +def test_parser_with_escaped_content() -> None: + """Tests that the parser correctly unescapes standard XML entities.""" + parser = XMLAgentOutputParser(unescape_xml=True) + _input = ( + "<tool> & 'some_tool'" + "<query> with "quotes"" + ) + + output = parser.invoke(_input) + + expected = AgentAction( + tool=" & 'some_tool'", + tool_input=' with "quotes"', + log=_input, + ) + assert output == expected + + +def test_parser_final_answer_escaped() -> None: + """Tests parsing a final answer with escaped content.""" + parser = XMLAgentOutputParser(unescape_xml=True) + _input = "The answer is > 42." + + output = parser.invoke(_input) + + expected = AgentFinish(return_values={"output": "The answer is > 42."}, log=_input) + assert output == expected + + +def test_parser_error_on_multiple_tool_tags() -> None: + """Tests that the parser raises an error for multiple tool tags.""" parser = XMLAgentOutputParser() - # Test when final closing is included - _input = """searchfoo""" - output = parser.invoke(_input) - expected_output = AgentAction(tool="search", tool_input="foo", log=_input) - assert output == expected_output - # Test when final closing is NOT included - # This happens when it's used as a stop token - _input = """searchfoo""" - output = parser.invoke(_input) - expected_output = AgentAction(tool="search", tool_input="foo", log=_input) - assert output == expected_output + _input = "tool1tool2input" + + with pytest.raises(ValueError, match="Found 2 blocks"): + parser.invoke(_input) -def test_finish() -> None: +def test_parser_error_on_missing_required_tag() -> None: + """Tests that the parser raises an error if a required tag is missing.""" parser = XMLAgentOutputParser() - # Test when final closing is included - _input = """bar""" - output = parser.invoke(_input) - expected_output = AgentFinish(return_values={"output": "bar"}, log=_input) - assert output == expected_output + _input = "some_input" - # Test when final closing is NOT included - # This happens when it's used as a stop token - _input = """bar""" - output = parser.invoke(_input) - expected_output = AgentFinish(return_values={"output": "bar"}, log=_input) - assert output == expected_output + with pytest.raises(ValueError, match="Could not parse LLM output"): + parser.invoke(_input) -def test_malformed_xml_with_nested_tags() -> None: - """Test handling of tool names with XML tags via format_xml minimal escaping.""" - from langchain.agents.format_scratchpad.xml import format_xml +def test_integration_format_and_parse() -> None: + """An integration test to ensure formatting and parsing work together.""" + parser = XMLAgentOutputParser() + agent_action = AgentAction(tool="", tool_input="input with &", log="") + intermediate_steps = [(agent_action, "observation with >")] - # Create an AgentAction with XML tags in the tool name - action = AgentAction(tool="searchnested", tool_input="query", log="") + # 1. Format the data, escaping the XML + formatted_xml = format_xml(intermediate_steps, escape_xml=True) - # The format_xml function should escape the XML tags using custom delimiters - formatted_xml = format_xml([(action, "observation")]) + # Extract the tool call part for parsing + tool_part = formatted_xml.split("")[0] - # Extract just the tool part for parsing - tool_part = formatted_xml.split("")[0] # Remove observation part - - # Now test that the parser can handle the escaped XML - parser = XMLAgentOutputParser(escape_format="minimal") + # 2. Parse the formatted data output = parser.invoke(tool_part) - # The parser should unescape and extract the original tool name - expected_output = AgentAction( - tool="searchnested", tool_input="query", log=tool_part + # 3. Assert that the original data is recovered + expected_action = AgentAction( + tool="", tool_input="input with &", log=tool_part ) - assert output == expected_output - - -def test_no_escaping() -> None: - """Test parser with escaping disabled.""" - parser = XMLAgentOutputParser(escape_format=None) - - # Test with regular tool name (no XML tags) - _input = """searchfoo""" - output = parser.invoke(_input) - expected_output = AgentAction(tool="search", tool_input="foo", log=_input) - assert output == expected_output + assert output == expected_action