mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-12 15:59:56 +00:00
langchain[patch]: harden xml parser for xmloutput agent (#31859)
Harden the default implementation of the XML parser for the agent --------- Co-authored-by: Mason Daugherty <github@mdrxy.com> Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
parent
3f839d566a
commit
83d8be756a
@ -1,21 +1,52 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from langchain_core.agents import AgentAction
|
||||
|
||||
|
||||
def _escape(xml: str) -> str:
|
||||
"""Replace XML tags with custom safe delimiters."""
|
||||
replacements = {
|
||||
"<tool>": "[[tool]]",
|
||||
"</tool>": "[[/tool]]",
|
||||
"<tool_input>": "[[tool_input]]",
|
||||
"</tool_input>": "[[/tool_input]]",
|
||||
"<observation>": "[[observation]]",
|
||||
"</observation>": "[[/observation]]",
|
||||
}
|
||||
for orig, repl in replacements.items():
|
||||
xml = xml.replace(orig, repl)
|
||||
return xml
|
||||
|
||||
|
||||
def format_xml(
    intermediate_steps: list[tuple[AgentAction, str]],
    *,
    escape_format: Optional[Literal["minimal"]] = "minimal",
) -> str:
    """Format the intermediate steps as XML.

    Args:
        intermediate_steps: The intermediate steps, as (action, observation)
            pairs.
        escape_format: The escaping format to use. Currently only 'minimal' is
            supported, which replaces XML tags with custom delimiters to prevent
            conflicts. Any other value (including None) disables escaping.

    Returns:
        The intermediate steps as XML, one
        ``<tool>..</tool><tool_input>..</tool_input><observation>..</observation>``
        group per step, concatenated without separators. An empty list yields
        an empty string.
    """
    parts = []
    for action, observation in intermediate_steps:
        if escape_format == "minimal":
            # Escape XML tags in tool names and inputs using custom delimiters
            tool = _escape(action.tool)
            tool_input = _escape(str(action.tool_input))
            observation = _escape(str(observation))
        else:
            tool = action.tool
            tool_input = str(action.tool_input)
            observation = str(observation)
        # Interpolate only the (possibly escaped) locals, never the raw
        # action fields, so the escaping above cannot be bypassed.
        parts.append(
            f"<tool>{tool}</tool><tool_input>{tool_input}"
            f"</tool_input><observation>{observation}</observation>"
        )
    return "".join(parts)
|
||||
|
@ -1,47 +1,119 @@
|
||||
from typing import Union
|
||||
import re
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.agents import AgentOutputParser
|
||||
|
||||
|
||||
def _unescape(text: str) -> str:
|
||||
"""Convert custom tag delimiters back into XML tags."""
|
||||
replacements = {
|
||||
"[[tool]]": "<tool>",
|
||||
"[[/tool]]": "</tool>",
|
||||
"[[tool_input]]": "<tool_input>",
|
||||
"[[/tool_input]]": "</tool_input>",
|
||||
"[[observation]]": "<observation>",
|
||||
"[[/observation]]": "</observation>",
|
||||
}
|
||||
for repl, orig in replacements.items():
|
||||
text = text.replace(repl, orig)
|
||||
return text
|
||||
|
||||
|
||||
class XMLAgentOutputParser(AgentOutputParser):
    """Parses tool invocations and final answers from XML-formatted agent output.

    This parser extracts structured information from XML tags to determine whether
    an agent should perform a tool action or provide a final answer. It includes
    built-in escaping support to safely handle tool names and inputs
    containing the reserved XML tags.

    Args:
        escape_format: The escaping format to use when parsing XML content.
            Supports 'minimal' which uses custom delimiters like [[tool]] to replace
            XML tags within content, preventing parsing conflicts.
            Use 'minimal' if using a corresponding encoding format that uses
            the _escape function when formatting the output (e.g., with format_xml).

    Expected formats:
        Tool invocation (returns AgentAction):
            <tool>search</tool>
            <tool_input>what is 2 + 2</tool_input>

        Final answer (returns AgentFinish):
            <final_answer>The answer is 4</final_answer>

    Note:
        Minimal escaping allows tool names containing XML tags to be safely
        represented. For example, a tool named "search<tool>nested</tool>" would be
        escaped as "search[[tool]]nested[[/tool]]" in the XML and automatically
        unescaped during parsing.

    Raises:
        ValueError: If the input doesn't match either expected XML format or
            contains malformed XML structure.
    """

    escape_format: Optional[Literal["minimal"]] = Field(default="minimal")
    """The format to use for escaping XML characters.

    minimal - uses custom delimiters to replace XML tags within content,
    preventing parsing conflicts. This is the only supported format currently.

    None - no escaping is applied, which may lead to parsing conflicts.
    """

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Parse raw agent output into an action or a final answer.

        A ``<tool>`` block takes precedence over ``<final_answer>``: if any
        tool block is present the text is treated as a tool invocation.

        Args:
            text: The raw text produced by the agent.

        Returns:
            An AgentAction when exactly one ``<tool>`` block is found (with an
            empty tool input if no ``<tool_input>`` block is present), or an
            AgentFinish when exactly one ``<final_answer>`` block is found.

        Raises:
            ValueError: If more than one ``<tool>``, ``<tool_input>`` or
                ``<final_answer>`` block is present, or if neither a tool
                invocation nor a final answer can be found.
        """
        # Check for tool invocation first
        tool_matches = re.findall(r"<tool>(.*?)</tool>", text, re.DOTALL)
        if tool_matches:
            if len(tool_matches) != 1:
                msg = (
                    f"Malformed tool invocation: expected exactly one <tool> block, "
                    f"but found {len(tool_matches)}."
                )
                raise ValueError(msg)
            _tool = tool_matches[0]

            # Match optional tool input
            input_matches = re.findall(
                r"<tool_input>(.*?)</tool_input>", text, re.DOTALL
            )
            if len(input_matches) > 1:
                msg = (
                    f"Malformed tool invocation: expected at most one <tool_input> "
                    f"block, but found {len(input_matches)}."
                )
                raise ValueError(msg)
            _tool_input = input_matches[0] if input_matches else ""

            # Unescape if minimal escape format is used
            if self.escape_format == "minimal":
                _tool = _unescape(_tool)
                _tool_input = _unescape(_tool_input)

            return AgentAction(tool=_tool, tool_input=_tool_input, log=text)

        # Check for final answer
        if "<final_answer>" in text and "</final_answer>" in text:
            matches = re.findall(r"<final_answer>(.*?)</final_answer>", text, re.DOTALL)
            # Both tags may be present yet unmatched (e.g. closing tag first),
            # in which case there are zero matches and the output is rejected.
            if len(matches) != 1:
                msg = (
                    "Malformed output: expected exactly one "
                    "<final_answer>...</final_answer> block."
                )
                raise ValueError(msg)
            answer = matches[0]
            # Unescape custom delimiters in final answer
            if self.escape_format == "minimal":
                answer = _unescape(answer)
            return AgentFinish(return_values={"output": answer}, log=text)

        msg = (
            "Malformed output: expected either a tool invocation "
            "or a final answer in XML format."
        )
        raise ValueError(msg)

    def get_format_instructions(self) -> str:
        """Not implemented for this parser; always raises NotImplementedError."""
        raise NotImplementedError
|
||||
|
@ -39,3 +39,42 @@ Observation2</observation>"""
|
||||
def test_empty_list_agent_actions() -> None:
    """An empty list of intermediate steps formats to an empty string."""
    result = format_xml([])
    assert result == ""
|
||||
|
||||
|
||||
def test_xml_escaping_minimal() -> None:
    """Test that XML tags in tool names are escaped with minimal format."""
    # Arrange
    agent_action = AgentAction(
        tool="search<tool>nested</tool>", tool_input="query<input>test</input>", log=""
    )
    observation = "Found <observation>result</observation>"
    intermediate_steps = [(agent_action, observation)]

    # Act
    result = format_xml(intermediate_steps, escape_format="minimal")

    # Assert - XML tags should be replaced with custom delimiters.
    # <input> is not one of the reserved tags, so it passes through unchanged.
    expected_result = (
        "<tool>search[[tool]]nested[[/tool]]</tool>"
        "<tool_input>query<input>test</input></tool_input>"
        "<observation>Found [[observation]]result[[/observation]]</observation>"
    )
    assert result == expected_result
|
||||
|
||||
|
||||
def test_no_escaping() -> None:
    """Test that escaping can be disabled."""
    # Arrange
    agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="")
    observation = "Observation1"
    intermediate_steps = [(agent_action, observation)]

    # Act
    result = format_xml(intermediate_steps, escape_format=None)

    # Assert
    expected_result = (
        "<tool>Tool1</tool><tool_input>Input1</tool_input>"
        "<observation>Observation1</observation>"
    )
    assert result == expected_result
|
||||
|
@ -32,3 +32,38 @@ def test_finish() -> None:
|
||||
output = parser.invoke(_input)
|
||||
expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_malformed_xml_with_nested_tags() -> None:
    """Test handling of tool names with XML tags via format_xml minimal escaping."""
    from langchain.agents.format_scratchpad.xml import format_xml

    # Create an AgentAction with XML tags in the tool name
    action = AgentAction(tool="search<tool>nested</tool>", tool_input="query", log="")

    # The format_xml function should escape the XML tags using custom delimiters
    formatted_xml = format_xml([(action, "observation")])

    # Extract just the tool part for parsing
    tool_part = formatted_xml.split("<observation>")[0]  # Remove observation part

    # Now test that the parser can handle the escaped XML
    parser = XMLAgentOutputParser(escape_format="minimal")
    output = parser.invoke(tool_part)

    # The parser should unescape and extract the original tool name
    expected_output = AgentAction(
        tool="search<tool>nested</tool>", tool_input="query", log=tool_part
    )
    assert output == expected_output
|
||||
|
||||
|
||||
def test_no_escaping() -> None:
    """Test parser with escaping disabled."""
    parser = XMLAgentOutputParser(escape_format=None)

    # Test with regular tool name (no XML tags)
    _input = """<tool>search</tool><tool_input>foo</tool_input>"""
    output = parser.invoke(_input)
    expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
    assert output == expected_output
|
||||
|
Loading…
Reference in New Issue
Block a user