diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 64f0ed71e66..2667936bbb6 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -40,7 +40,8 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]: """ if FINAL_ANSWER_ACTION in llm_output: return "Final Answer", llm_output.split(FINAL_ANSWER_ACTION)[-1].strip() - regex = r"Action: (.*?)[\n]*Action Input: (.*)" + # \s matches against tab/newline/whitespace + regex = r"Action: (.*?)[\n]*Action Input:[\s]*(.*)" match = re.search(regex, llm_output, re.DOTALL) if not match: raise ValueError(f"Could not parse LLM output: `{llm_output}`") diff --git a/tests/unit_tests/agents/test_mrkl.py b/tests/unit_tests/agents/test_mrkl.py index 6c87f91453c..1507923879c 100644 --- a/tests/unit_tests/agents/test_mrkl.py +++ b/tests/unit_tests/agents/test_mrkl.py @@ -27,6 +27,17 @@ def test_get_action_and_input_whitespace() -> None: assert action_input == "NBA" +def test_get_action_and_input_newline() -> None: + """Test getting an action from text where Action Input is a code snippet.""" + llm_output = ( + "Now I need to write a unittest for the function.\n\n" + "Action: Python\nAction Input:\n```\nimport unittest\n\nunittest.main()\n```" + ) + action, action_input = get_action_and_input(llm_output) + assert action == "Python" + assert action_input == "```\nimport unittest\n\nunittest.main()\n```" + + def test_get_final_answer() -> None: """Test getting final answer.""" llm_output = (