diff --git a/libs/langchain/langchain/agents/format_scratchpad/xml.py b/libs/langchain/langchain/agents/format_scratchpad/xml.py index e1e94509ef3..3fb73637790 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/xml.py +++ b/libs/langchain/langchain/agents/format_scratchpad/xml.py @@ -1,21 +1,52 @@ +from typing import Literal, Optional + from langchain_core.agents import AgentAction +def _escape(xml: str) -> str: + """Replace XML tags with custom safe delimiters.""" + replacements = { + "": "[[tool]]", + "": "[[/tool]]", + "": "[[tool_input]]", + "": "[[/tool_input]]", + "": "[[observation]]", + "": "[[/observation]]", + } + for orig, repl in replacements.items(): + xml = xml.replace(orig, repl) + return xml + + def format_xml( intermediate_steps: list[tuple[AgentAction, str]], + *, + escape_format: Optional[Literal["minimal"]] = "minimal", ) -> str: """Format the intermediate steps as XML. Args: intermediate_steps: The intermediate steps. + escape_format: The escaping format to use. Currently only 'minimal' is + supported, which replaces XML tags with custom delimiters to prevent + conflicts. Returns: The intermediate steps as XML. """ log = "" for action, observation in intermediate_steps: + if escape_format == "minimal": + # Escape XML tags in tool names and inputs using custom delimiters + tool = _escape(action.tool) + tool_input = _escape(str(action.tool_input)) + observation = _escape(str(observation)) + else: + tool = action.tool + tool_input = str(action.tool_input) + observation = str(observation) log += ( - f"{action.tool}{action.tool_input}" + f"{tool}{tool_input}" f"{observation}" ) return log diff --git a/libs/langchain/langchain/agents/output_parsers/xml.py b/libs/langchain/langchain/agents/output_parsers/xml.py index e8a7587b1cd..7d3e73c7ba9 100644 --- a/libs/langchain/langchain/agents/output_parsers/xml.py +++ b/libs/langchain/langchain/agents/output_parsers/xml.py @@ -1,47 +1,119 @@ -from typing import Union +import re +from typing import Literal, Optional, Union from langchain_core.agents import AgentAction, AgentFinish +from pydantic import Field from langchain.agents import AgentOutputParser +def _unescape(text: str) -> str: + """Convert custom tag delimiters back into XML tags.""" + replacements = { + "[[tool]]": "", + "[[/tool]]": "", + "[[tool_input]]": "", + "[[/tool_input]]": "", + "[[observation]]": "", + "[[/observation]]": "", + } + for repl, orig in replacements.items(): + text = text.replace(repl, orig) + return text + + class XMLAgentOutputParser(AgentOutputParser): - """Parses tool invocations and final answers in XML format. + """Parses tool invocations and final answers from XML-formatted agent output. - Expects output to be in one of two formats. + This parser extracts structured information from XML tags to determine whether + an agent should perform a tool action or provide a final answer. It includes + built-in escaping support to safely handle tool names and inputs + containing XML special characters. - If the output signals that an action should be taken, - should be in the below format. This will result in an AgentAction - being returned. + Args: + escape_format: The escaping format to use when parsing XML content. + Supports 'minimal' which uses custom delimiters like [[tool]] to replace + XML tags within content, preventing parsing conflicts. + Use 'minimal' if using a corresponding encoding format that uses + the _escape function when formatting the output (e.g., with format_xml). - ``` - search - what is 2 + 2 - ``` + Expected formats: + Tool invocation (returns AgentAction): + search + what is 2 + 2 - If the output signals that a final answer should be given, - should be in the below format. This will result in an AgentFinish - being returned. + Final answer (returns AgentFinish): + The answer is 4 - ``` - Foo - ``` + Note: + Minimal escaping allows tool names containing XML tags to be safely + represented. For example, a tool named "searchnested" would be + escaped as "search[[tool]]nested[[/tool]]" in the XML and automatically + unescaped during parsing. + + Raises: + ValueError: If the input doesn't match either expected XML format or + contains malformed XML structure. + """ + + escape_format: Optional[Literal["minimal"]] = Field(default="minimal") + """The format to use for escaping XML characters. + + minimal - uses custom delimiters to replace XML tags within content, + preventing parsing conflicts. This is the only supported format currently. + + None - no escaping is applied, which may lead to parsing conflicts. """ def parse(self, text: str) -> Union[AgentAction, AgentFinish]: - if "" in text: - tool, tool_input = text.split("") - _tool = tool.split("")[1] - _tool_input = tool_input.split("")[1] - if "" in _tool_input: - _tool_input = _tool_input.split("")[0] + # Check for tool invocation first + tool_matches = re.findall(r"(.*?)", text, re.DOTALL) + if tool_matches: + if len(tool_matches) != 1: + msg = ( + f"Malformed tool invocation: expected exactly one block, " + f"but found {len(tool_matches)}." + ) + raise ValueError(msg) + _tool = tool_matches[0] + + # Match optional tool input + input_matches = re.findall( + r"(.*?)", text, re.DOTALL + ) + if len(input_matches) > 1: + msg = ( + f"Malformed tool invocation: expected at most one " + f"block, but found {len(input_matches)}." + ) + raise ValueError(msg) + _tool_input = input_matches[0] if input_matches else "" + + # Unescape if minimal escape format is used + if self.escape_format == "minimal": + _tool = _unescape(_tool) + _tool_input = _unescape(_tool_input) + return AgentAction(tool=_tool, tool_input=_tool_input, log=text) - if "" in text: - _, answer = text.split("") - if "" in answer: - answer = answer.split("")[0] + # Check for final answer + if "" in text and "" in text: + matches = re.findall(r"(.*?)", text, re.DOTALL) + if len(matches) != 1: + msg = ( + "Malformed output: expected exactly one " + "... block." + ) + raise ValueError(msg) + answer = matches[0] + # Unescape custom delimiters in final answer + if self.escape_format == "minimal": + answer = _unescape(answer) return AgentFinish(return_values={"output": answer}, log=text) - raise ValueError + msg = ( + "Malformed output: expected either a tool invocation " + "or a final answer in XML format." + ) + raise ValueError(msg) def get_format_instructions(self) -> str: raise NotImplementedError diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py index 322b6874bef..3fdaa791a9b 100644 --- a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_xml.py @@ -39,3 +39,42 @@ Observation2""" def test_empty_list_agent_actions() -> None: result = format_xml([]) assert result == "" + + +def test_xml_escaping_minimal() -> None: + """Test that XML tags in tool names are escaped with minimal format.""" + # Arrange + agent_action = AgentAction( + tool="searchnested", tool_input="querytest", log="" + ) + observation = "Found result" + intermediate_steps = [(agent_action, observation)] + + # Act + result = format_xml(intermediate_steps, escape_format="minimal") + + # Assert - XML tags should be replaced with custom delimiters + expected_result = ( + "search[[tool]]nested[[/tool]]" + "querytest" + "Found [[observation]]result[[/observation]]" + ) + assert result == expected_result + + +def test_no_escaping() -> None: + """Test that escaping can be disabled.""" + # Arrange + agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="") + observation = "Observation1" + intermediate_steps = [(agent_action, observation)] + + # Act + result = format_xml(intermediate_steps, escape_format=None) + + # Assert + expected_result = ( + "Tool1Input1" + "Observation1" + ) + assert result == expected_result diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py index 1b271e24700..74fff2fa869 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_xml.py @@ -32,3 +32,38 @@ def test_finish() -> None: output = parser.invoke(_input) expected_output = AgentFinish(return_values={"output": "bar"}, log=_input) assert output == expected_output + + +def test_malformed_xml_with_nested_tags() -> None: + """Test handling of tool names with XML tags via format_xml minimal escaping.""" + from langchain.agents.format_scratchpad.xml import format_xml + + # Create an AgentAction with XML tags in the tool name + action = AgentAction(tool="searchnested", tool_input="query", log="") + + # The format_xml function should escape the XML tags using custom delimiters + formatted_xml = format_xml([(action, "observation")]) + + # Extract just the tool part for parsing + tool_part = formatted_xml.split("")[0] # Remove observation part + + # Now test that the parser can handle the escaped XML + parser = XMLAgentOutputParser(escape_format="minimal") + output = parser.invoke(tool_part) + + # The parser should unescape and extract the original tool name + expected_output = AgentAction( + tool="searchnested", tool_input="query", log=tool_part + ) + assert output == expected_output + + +def test_no_escaping() -> None: + """Test parser with escaping disabled.""" + parser = XMLAgentOutputParser(escape_format=None) + + # Test with regular tool name (no XML tags) + _input = """searchfoo""" + output = parser.invoke(_input) + expected_output = AgentAction(tool="search", tool_input="foo", log=_input) + assert output == expected_output