diff --git a/libs/langchain/langchain_classic/agents/mrkl/output_parser.py b/libs/langchain/langchain_classic/agents/mrkl/output_parser.py index 1c08b68338f..4323dc32d25 100644 --- a/libs/langchain/langchain_classic/agents/mrkl/output_parser.py +++ b/libs/langchain/langchain_classic/agents/mrkl/output_parser.py @@ -41,9 +41,7 @@ class MRKLOutputParser(AgentOutputParser): OutputParserException: If the output could not be parsed. """ includes_answer = FINAL_ANSWER_ACTION in text - regex = ( - r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)" - ) + regex = r"Action\s*\d*\s*:[\s]*(.*?)Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)" action_match = re.search(regex, text, re.DOTALL) if action_match and includes_answer: if text.find(FINAL_ANSWER_ACTION) < text.find(action_match.group(0)): diff --git a/libs/langchain/langchain_classic/agents/output_parsers/react_single_input.py b/libs/langchain/langchain_classic/agents/output_parsers/react_single_input.py index ae7634f9636..a74cc65a164 100644 --- a/libs/langchain/langchain_classic/agents/output_parsers/react_single_input.py +++ b/libs/langchain/langchain_classic/agents/output_parsers/react_single_input.py @@ -52,9 +52,7 @@ class ReActSingleInputOutputParser(AgentOutputParser): @override def parse(self, text: str) -> AgentAction | AgentFinish: includes_answer = FINAL_ANSWER_ACTION in text - regex = ( - r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)" - ) + regex = r"Action\s*\d*\s*:[\s]*(.*?)Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)" action_match = re.search(regex, text, re.DOTALL) if action_match: if includes_answer: diff --git a/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_single_input.py b/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_single_input.py index b8f3b612274..eed59a6cfa4 100644 --- a/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_single_input.py +++ b/libs/langchain/tests/unit_tests/agents/output_parsers/test_react_single_input.py @@ -1,3 +1,6 @@ +import signal +import sys + import pytest from langchain_core.agents import AgentAction, AgentFinish from langchain_core.exceptions import OutputParserException @@ -43,3 +46,32 @@ Action: search Final Answer: Action Input: what is the temperature in SF?""" with pytest.raises(OutputParserException): parser.invoke(_input) + + +def _timeout_handler(_signum: int, _frame: object) -> None: + msg = "ReDoS: regex took too long" + raise TimeoutError(msg) + + +@pytest.mark.skipif( + sys.platform == "win32", reason="SIGALRM is not available on Windows" +) +def test_react_single_input_no_redos() -> None: + """Regression test for ReDoS caused by catastrophic backtracking.""" + parser = ReActSingleInputOutputParser() + malicious = "Action: " + " \t" * 1000 + "Action " + old = signal.signal(signal.SIGALRM, _timeout_handler) + signal.alarm(2) + try: + try: + parser.parse(malicious) + except OutputParserException: + pass + except TimeoutError: + pytest.fail( + "ReDoS detected: ReActSingleInputOutputParser.parse() " + "hung on crafted input" + ) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old) diff --git a/libs/langchain/tests/unit_tests/agents/test_mrkl_output_parser.py b/libs/langchain/tests/unit_tests/agents/test_mrkl_output_parser.py index 9112a997d84..58731439dc8 100644 --- a/libs/langchain/tests/unit_tests/agents/test_mrkl_output_parser.py +++ b/libs/langchain/tests/unit_tests/agents/test_mrkl_output_parser.py @@ -1,3 +1,6 @@ +import signal +import sys + import pytest from langchain_core.agents import AgentAction, AgentFinish from langchain_core.exceptions import OutputParserException @@ -79,3 +82,30 @@ def test_final_answer_after_parsable_action() -> None: "Parsing LLM output produced both a final answer and a parse-able action" in exception_info.value.args[0] ) + + +def _timeout_handler(_signum: int, _frame: object) -> None: + msg = "ReDoS: regex took too long" + raise TimeoutError(msg) + + +@pytest.mark.skipif( + sys.platform == "win32", reason="SIGALRM is not available on Windows" +) +def test_mrkl_output_parser_no_redos() -> None: + """Regression test for ReDoS caused by catastrophic backtracking.""" + malicious = "Action: " + " \t" * 1000 + "Action " + old = signal.signal(signal.SIGALRM, _timeout_handler) + signal.alarm(2) + try: + try: + mrkl_output_parser.parse(malicious) + except OutputParserException: + pass + except TimeoutError: + pytest.fail( + "ReDoS detected: MRKLOutputParser.parse() hung on crafted input" + ) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old)