mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 13:00:34 +00:00
refactor: streamline XML formatting and parsing functions for improved security and clarity
This commit is contained in:
parent
87f00e84ad
commit
2a2c9cd5ba
@ -1,52 +1,41 @@
|
||||
from typing import Literal, Optional
|
||||
import xml.sax.saxutils
|
||||
|
||||
from langchain_core.agents import AgentAction
|
||||
|
||||
|
||||
def _escape(xml: str) -> str:
|
||||
"""Replace XML tags with custom safe delimiters."""
|
||||
replacements = {
|
||||
"<tool>": "[[tool]]",
|
||||
"</tool>": "[[/tool]]",
|
||||
"<tool_input>": "[[tool_input]]",
|
||||
"</tool_input>": "[[/tool_input]]",
|
||||
"<observation>": "[[observation]]",
|
||||
"</observation>": "[[/observation]]",
|
||||
}
|
||||
for orig, repl in replacements.items():
|
||||
xml = xml.replace(orig, repl)
|
||||
return xml
|
||||
|
||||
|
||||
def format_xml(
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
*,
|
||||
escape_format: Optional[Literal["minimal"]] = "minimal",
|
||||
escape_xml: bool = True,
|
||||
) -> str:
|
||||
"""Format the intermediate steps as XML.
|
||||
|
||||
Escapes all special XML characters in the content, preventing injection of malicious
|
||||
or malformed XML.
|
||||
|
||||
Args:
|
||||
intermediate_steps: The intermediate steps.
|
||||
escape_format: The escaping format to use. Currently only 'minimal' is
|
||||
supported, which replaces XML tags with custom delimiters to prevent
|
||||
conflicts.
|
||||
intermediate_steps: The intermediate steps, each a tuple of
|
||||
(AgentAction, observation).
|
||||
escape_xml: If True, all XML special characters in the tool name,
|
||||
tool input, and observation will be escaped (e.g., ``<`` becomes ``<``).
|
||||
|
||||
Returns:
|
||||
The intermediate steps as XML.
|
||||
A string of concatenated XML blocks representing the intermediate steps.
|
||||
"""
|
||||
log = ""
|
||||
for action, observation in intermediate_steps:
|
||||
if escape_format == "minimal":
|
||||
# Escape XML tags in tool names and inputs using custom delimiters
|
||||
tool = _escape(action.tool)
|
||||
tool_input = _escape(str(action.tool_input))
|
||||
observation = _escape(str(observation))
|
||||
else:
|
||||
tool = action.tool
|
||||
tool_input = str(action.tool_input)
|
||||
observation = str(observation)
|
||||
tool = str(action.tool)
|
||||
tool_input = str(action.tool_input)
|
||||
observation_str = str(observation)
|
||||
|
||||
if escape_xml:
|
||||
entities = {"'": "'", '"': """}
|
||||
tool = xml.sax.saxutils.escape(tool, entities)
|
||||
tool_input = xml.sax.saxutils.escape(tool_input, entities)
|
||||
observation_str = xml.sax.saxutils.escape(observation_str, entities)
|
||||
|
||||
log += (
|
||||
f"<tool>{tool}</tool><tool_input>{tool_input}"
|
||||
f"</tool_input><observation>{observation}</observation>"
|
||||
f"</tool_input><observation>{observation_str}</observation>"
|
||||
)
|
||||
return log
|
||||
|
@ -1,5 +1,6 @@
|
||||
import re
|
||||
from typing import Literal, Optional, Union
|
||||
import xml.sax.saxutils
|
||||
from typing import Union
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from pydantic import Field
|
||||
@ -7,35 +8,20 @@ from pydantic import Field
|
||||
from langchain.agents import AgentOutputParser
|
||||
|
||||
|
||||
def _unescape(text: str) -> str:
|
||||
"""Convert custom tag delimiters back into XML tags."""
|
||||
replacements = {
|
||||
"[[tool]]": "<tool>",
|
||||
"[[/tool]]": "</tool>",
|
||||
"[[tool_input]]": "<tool_input>",
|
||||
"[[/tool_input]]": "</tool_input>",
|
||||
"[[observation]]": "<observation>",
|
||||
"[[/observation]]": "</observation>",
|
||||
}
|
||||
for repl, orig in replacements.items():
|
||||
text = text.replace(repl, orig)
|
||||
return text
|
||||
|
||||
|
||||
class XMLAgentOutputParser(AgentOutputParser):
|
||||
"""Parses tool invocations and final answers from XML-formatted agent output.
|
||||
|
||||
This parser extracts structured information from XML tags to determine whether
|
||||
an agent should perform a tool action or provide a final answer. It includes
|
||||
built-in escaping support to safely handle tool names and inputs
|
||||
containing XML special characters.
|
||||
This parser is hardened against XML injection by using standard XML entity
|
||||
decoding for content within tags. It is designed to work with the corresponding
|
||||
``format_xml`` function.
|
||||
|
||||
Args:
|
||||
escape_format: The escaping format to use when parsing XML content.
|
||||
Supports 'minimal' which uses custom delimiters like [[tool]] to replace
|
||||
XML tags within content, preventing parsing conflicts.
|
||||
Use 'minimal' if using a corresponding encoding format that uses
|
||||
the _escape function when formatting the output (e.g., with format_xml).
|
||||
unescape_xml: If True, the parser will unescape XML special characters in the
|
||||
content of tags. This should be enabled if the agent's output was formatted
|
||||
with XML escaping.
|
||||
|
||||
If False, the parser will return the raw content as is, which may include
|
||||
XML special characters like `<`, `>`, and `&`.
|
||||
|
||||
Expected formats:
|
||||
Tool invocation (returns AgentAction):
|
||||
@ -45,75 +31,83 @@ class XMLAgentOutputParser(AgentOutputParser):
|
||||
Final answer (returns AgentFinish):
|
||||
<final_answer>The answer is 4</final_answer>
|
||||
|
||||
Note:
|
||||
Minimal escaping allows tool names containing XML tags to be safely
|
||||
represented. For example, a tool named "search<tool>nested</tool>" would be
|
||||
escaped as "search[[tool]]nested[[/tool]]" in the XML and automatically
|
||||
unescaped during parsing.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input doesn't match either expected XML format or
|
||||
contains malformed XML structure.
|
||||
"""
|
||||
|
||||
escape_format: Optional[Literal["minimal"]] = Field(default="minimal")
|
||||
"""The format to use for escaping XML characters.
|
||||
|
||||
minimal - uses custom delimiters to replace XML tags within content,
|
||||
preventing parsing conflicts. This is the only supported format currently.
|
||||
|
||||
None - no escaping is applied, which may lead to parsing conflicts.
|
||||
unescape_xml: bool = Field(default=True)
|
||||
"""If True, the parser will unescape XML special characters in the content
|
||||
of tags. This should be enabled if the agent's output was formatted
|
||||
with XML escaping.
|
||||
|
||||
If False, the parser will return the raw content as is,
|
||||
which may include XML special characters like `<`, `>`, and `&`.
|
||||
"""
|
||||
|
||||
def _extract_tag_content(
|
||||
self, tag: str, text: str, *, required: bool
|
||||
) -> Union[str, None]:
|
||||
"""
|
||||
Extracts content from a specified XML tag, ensuring it appears at most once.
|
||||
|
||||
Args:
|
||||
tag: The name of the XML tag (e.g., ``'tool'``).
|
||||
text: The text to parse.
|
||||
required: If True, a ValueError will be raised if the tag is not found.
|
||||
|
||||
Returns:
|
||||
The unescaped content of the tag as a string, or None if not found
|
||||
and not required.
|
||||
|
||||
Raises:
|
||||
ValueError: If the tag appears more than once, or if it is required
|
||||
but not found.
|
||||
"""
|
||||
pattern = f"<{tag}>(.*?)</{tag}>"
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
|
||||
if len(matches) > 1:
|
||||
raise ValueError(
|
||||
f"Malformed XML: Found {len(matches)} <{tag}> blocks. Expected 0 or 1."
|
||||
)
|
||||
|
||||
if not matches:
|
||||
if required:
|
||||
raise ValueError(f"Malformed XML: Missing required <{tag}> block.")
|
||||
return None
|
||||
|
||||
content = matches[0]
|
||||
if self.unescape_xml:
|
||||
entities = {"'": "'", """: '"'}
|
||||
return xml.sax.saxutils.unescape(content, entities)
|
||||
return content
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
# Check for tool invocation first
|
||||
tool_matches = re.findall(r"<tool>(.*?)</tool>", text, re.DOTALL)
|
||||
if tool_matches:
|
||||
if len(tool_matches) != 1:
|
||||
raise ValueError(
|
||||
f"Malformed tool invocation: expected exactly one <tool> block, "
|
||||
f"but found {len(tool_matches)}."
|
||||
)
|
||||
_tool = tool_matches[0]
|
||||
|
||||
# Match optional tool input
|
||||
input_matches = re.findall(
|
||||
r"<tool_input>(.*?)</tool_input>", text, re.DOTALL
|
||||
"""
|
||||
Parses the given text into an AgentAction or AgentFinish object.
|
||||
"""
|
||||
# Check for a tool invocation
|
||||
if "<tool>" in text and "</tool>" in text:
|
||||
tool = self._extract_tag_content("tool", text, required=True)
|
||||
# Tool input is optional
|
||||
tool_input = (
|
||||
self._extract_tag_content("tool_input", text, required=False) or ""
|
||||
)
|
||||
if len(input_matches) > 1:
|
||||
raise ValueError(
|
||||
f"Malformed tool invocation: expected at most one <tool_input> "
|
||||
f"block, but found {len(input_matches)}."
|
||||
)
|
||||
_tool_input = input_matches[0] if input_matches else ""
|
||||
|
||||
# Unescape if minimal escape format is used
|
||||
if self.escape_format == "minimal":
|
||||
_tool = _unescape(_tool)
|
||||
_tool_input = _unescape(_tool_input)
|
||||
return AgentAction(tool=tool, tool_input=tool_input, log=text)
|
||||
|
||||
return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
|
||||
|
||||
# Check for final answer
|
||||
# Check for a final answer
|
||||
elif "<final_answer>" in text and "</final_answer>" in text:
|
||||
matches = re.findall(r"<final_answer>(.*?)</final_answer>", text, re.DOTALL)
|
||||
if len(matches) != 1:
|
||||
msg = (
|
||||
"Malformed output: expected exactly one "
|
||||
"<final_answer>...</final_answer> block."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
answer = matches[0]
|
||||
# Unescape custom delimiters in final answer
|
||||
if self.escape_format == "minimal":
|
||||
answer = _unescape(answer)
|
||||
answer = self._extract_tag_content("final_answer", text, required=True)
|
||||
return AgentFinish(return_values={"output": answer}, log=text)
|
||||
|
||||
# If neither format is found, raise an error
|
||||
else:
|
||||
msg = (
|
||||
"Malformed output: expected either a tool invocation "
|
||||
"or a final answer in XML format."
|
||||
raise ValueError(
|
||||
"Could not parse LLM output. Expected a tool invocation with <tool> "
|
||||
"and <tool_input> tags, or a final answer with <final_answer> tags."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
@ -3,78 +3,32 @@ from langchain_core.agents import AgentAction
|
||||
from langchain.agents.format_scratchpad.xml import format_xml
|
||||
|
||||
|
||||
def test_single_agent_action_observation() -> None:
|
||||
# Arrange
|
||||
def test_format_single_action() -> None:
|
||||
"""Tests formatting of a single agent action and observation."""
|
||||
agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
|
||||
observation = "Observation1"
|
||||
intermediate_steps = [(agent_action, observation)]
|
||||
intermediate_steps = [(agent_action, "Observation1")]
|
||||
|
||||
# Act
|
||||
result = format_xml(intermediate_steps)
|
||||
expected_result = """<tool>Tool1</tool><tool_input>Input1\
|
||||
</tool_input><observation>Observation1</observation>"""
|
||||
# Assert
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_multiple_agent_actions_observations() -> None:
|
||||
# Arrange
|
||||
agent_action1 = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
|
||||
agent_action2 = AgentAction(tool="Tool2", tool_input="Input2", log="Log2")
|
||||
observation1 = "Observation1"
|
||||
observation2 = "Observation2"
|
||||
intermediate_steps = [(agent_action1, observation1), (agent_action2, observation2)]
|
||||
|
||||
# Act
|
||||
result = format_xml(intermediate_steps)
|
||||
|
||||
# Assert
|
||||
expected_result = """<tool>Tool1</tool><tool_input>Input1\
|
||||
</tool_input><observation>Observation1</observation><tool>\
|
||||
Tool2</tool><tool_input>Input2</tool_input><observation>\
|
||||
Observation2</observation>"""
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_empty_list_agent_actions() -> None:
|
||||
result = format_xml([])
|
||||
assert result == ""
|
||||
|
||||
|
||||
def test_xml_escaping_minimal() -> None:
|
||||
"""Test that XML tags in tool names are escaped with minimal format."""
|
||||
# Arrange
|
||||
agent_action = AgentAction(
|
||||
tool="search<tool>nested</tool>", tool_input="query<input>test</input>", log=""
|
||||
)
|
||||
observation = "Found <observation>result</observation>"
|
||||
intermediate_steps = [(agent_action, observation)]
|
||||
|
||||
# Act
|
||||
result = format_xml(intermediate_steps, escape_format="minimal")
|
||||
|
||||
# Assert - XML tags should be replaced with custom delimiters
|
||||
expected_result = (
|
||||
"<tool>search[[tool]]nested[[/tool]]</tool>"
|
||||
"<tool_input>query<input>test</input></tool_input>"
|
||||
"<observation>Found [[observation]]result[[/observation]]</observation>"
|
||||
)
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_no_escaping() -> None:
|
||||
"""Test that escaping can be disabled."""
|
||||
# Arrange
|
||||
agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="")
|
||||
observation = "Observation1"
|
||||
intermediate_steps = [(agent_action, observation)]
|
||||
|
||||
# Act
|
||||
result = format_xml(intermediate_steps, escape_format=None)
|
||||
|
||||
# Assert
|
||||
expected_result = (
|
||||
result = format_xml(intermediate_steps, escape_xml=False)
|
||||
expected = (
|
||||
"<tool>Tool1</tool><tool_input>Input1</tool_input>"
|
||||
"<observation>Observation1</observation>"
|
||||
)
|
||||
assert result == expected_result
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_format_xml_escaping() -> None:
|
||||
"""Tests that XML special characters in content are properly escaped."""
|
||||
agent_action = AgentAction(
|
||||
tool="<tool> & 'some_tool'", tool_input='<query> with "quotes"', log=""
|
||||
)
|
||||
observation = "Observed > 5 items"
|
||||
intermediate_steps = [(agent_action, observation)]
|
||||
|
||||
result = format_xml(intermediate_steps, escape_xml=True)
|
||||
|
||||
expected = (
|
||||
"<tool><tool> & 'some_tool'</tool>"
|
||||
"<tool_input><query> with "quotes"</tool_input>"
|
||||
"<observation>Observed > 5 items</observation>"
|
||||
)
|
||||
assert result == expected
|
||||
|
@ -1,69 +1,92 @@
|
||||
import pytest
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
|
||||
from langchain.agents.format_scratchpad.xml import format_xml
|
||||
from langchain.agents.output_parsers.xml import XMLAgentOutputParser
|
||||
|
||||
|
||||
def test_tool_usage() -> None:
|
||||
def test_parser_tool_usage() -> None:
|
||||
"""Tests parsing a standard tool invocation."""
|
||||
parser = XMLAgentOutputParser(unescape_xml=False)
|
||||
_input = "<tool>search</tool><tool_input>foo</tool_input>"
|
||||
output = parser.invoke(_input)
|
||||
expected = AgentAction(tool="search", tool_input="foo", log=_input)
|
||||
assert output == expected
|
||||
|
||||
|
||||
def test_parser_final_answer() -> None:
|
||||
"""Tests parsing a standard final answer."""
|
||||
parser = XMLAgentOutputParser(unescape_xml=False)
|
||||
_input = "<final_answer>bar</final_answer>"
|
||||
output = parser.invoke(_input)
|
||||
expected = AgentFinish(return_values={"output": "bar"}, log=_input)
|
||||
assert output == expected
|
||||
|
||||
|
||||
def test_parser_with_escaped_content() -> None:
|
||||
"""Tests that the parser correctly unescapes standard XML entities."""
|
||||
parser = XMLAgentOutputParser(unescape_xml=True)
|
||||
_input = (
|
||||
"<tool><tool> & 'some_tool'</tool>"
|
||||
"<tool_input><query> with "quotes"</tool_input>"
|
||||
)
|
||||
|
||||
output = parser.invoke(_input)
|
||||
|
||||
expected = AgentAction(
|
||||
tool="<tool> & 'some_tool'",
|
||||
tool_input='<query> with "quotes"',
|
||||
log=_input,
|
||||
)
|
||||
assert output == expected
|
||||
|
||||
|
||||
def test_parser_final_answer_escaped() -> None:
|
||||
"""Tests parsing a final answer with escaped content."""
|
||||
parser = XMLAgentOutputParser(unescape_xml=True)
|
||||
_input = "<final_answer>The answer is > 42.</final_answer>"
|
||||
|
||||
output = parser.invoke(_input)
|
||||
|
||||
expected = AgentFinish(return_values={"output": "The answer is > 42."}, log=_input)
|
||||
assert output == expected
|
||||
|
||||
|
||||
def test_parser_error_on_multiple_tool_tags() -> None:
|
||||
"""Tests that the parser raises an error for multiple tool tags."""
|
||||
parser = XMLAgentOutputParser()
|
||||
# Test when final closing </tool_input> is included
|
||||
_input = """<tool>search</tool><tool_input>foo</tool_input>"""
|
||||
output = parser.invoke(_input)
|
||||
expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
|
||||
assert output == expected_output
|
||||
# Test when final closing </tool_input> is NOT included
|
||||
# This happens when it's used as a stop token
|
||||
_input = """<tool>search</tool><tool_input>foo</tool_input>"""
|
||||
output = parser.invoke(_input)
|
||||
expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
|
||||
assert output == expected_output
|
||||
_input = "<tool>tool1</tool><tool>tool2</tool><tool_input>input</tool_input>"
|
||||
|
||||
with pytest.raises(ValueError, match="Found 2 <tool> blocks"):
|
||||
parser.invoke(_input)
|
||||
|
||||
|
||||
def test_finish() -> None:
|
||||
def test_parser_error_on_missing_required_tag() -> None:
|
||||
"""Tests that the parser raises an error if a required tag is missing."""
|
||||
parser = XMLAgentOutputParser()
|
||||
# Test when final closing <final_answer> is included
|
||||
_input = """<final_answer>bar</final_answer>"""
|
||||
output = parser.invoke(_input)
|
||||
expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
|
||||
assert output == expected_output
|
||||
_input = "<tool_input>some_input</tool_input>"
|
||||
|
||||
# Test when final closing <final_answer> is NOT included
|
||||
# This happens when it's used as a stop token
|
||||
_input = """<final_answer>bar</final_answer>"""
|
||||
output = parser.invoke(_input)
|
||||
expected_output = AgentFinish(return_values={"output": "bar"}, log=_input)
|
||||
assert output == expected_output
|
||||
with pytest.raises(ValueError, match="Could not parse LLM output"):
|
||||
parser.invoke(_input)
|
||||
|
||||
|
||||
def test_malformed_xml_with_nested_tags() -> None:
|
||||
"""Test handling of tool names with XML tags via format_xml minimal escaping."""
|
||||
from langchain.agents.format_scratchpad.xml import format_xml
|
||||
def test_integration_format_and_parse() -> None:
|
||||
"""An integration test to ensure formatting and parsing work together."""
|
||||
parser = XMLAgentOutputParser()
|
||||
agent_action = AgentAction(tool="<special-tool>", tool_input="input with &", log="")
|
||||
intermediate_steps = [(agent_action, "observation with >")]
|
||||
|
||||
# Create an AgentAction with XML tags in the tool name
|
||||
action = AgentAction(tool="search<tool>nested</tool>", tool_input="query", log="")
|
||||
# 1. Format the data, escaping the XML
|
||||
formatted_xml = format_xml(intermediate_steps, escape_xml=True)
|
||||
|
||||
# The format_xml function should escape the XML tags using custom delimiters
|
||||
formatted_xml = format_xml([(action, "observation")])
|
||||
# Extract the tool call part for parsing
|
||||
tool_part = formatted_xml.split("<observation>")[0]
|
||||
|
||||
# Extract just the tool part for parsing
|
||||
tool_part = formatted_xml.split("<observation>")[0] # Remove observation part
|
||||
|
||||
# Now test that the parser can handle the escaped XML
|
||||
parser = XMLAgentOutputParser(escape_format="minimal")
|
||||
# 2. Parse the formatted data
|
||||
output = parser.invoke(tool_part)
|
||||
|
||||
# The parser should unescape and extract the original tool name
|
||||
expected_output = AgentAction(
|
||||
tool="search<tool>nested</tool>", tool_input="query", log=tool_part
|
||||
# 3. Assert that the original data is recovered
|
||||
expected_action = AgentAction(
|
||||
tool="<special-tool>", tool_input="input with &", log=tool_part
|
||||
)
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_no_escaping() -> None:
|
||||
"""Test parser with escaping disabled."""
|
||||
parser = XMLAgentOutputParser(escape_format=None)
|
||||
|
||||
# Test with regular tool name (no XML tags)
|
||||
_input = """<tool>search</tool><tool_input>foo</tool_input>"""
|
||||
output = parser.invoke(_input)
|
||||
expected_output = AgentAction(tool="search", tool_input="foo", log=_input)
|
||||
assert output == expected_output
|
||||
assert output == expected_action
|
||||
|
Loading…
Reference in New Issue
Block a user