diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 2667936bbb6..e15307dfaa9 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -2,15 +2,16 @@ from __future__ import annotations import re -from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple +from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Union -from langchain.agents.agent import Agent, AgentExecutor +from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.tools import Tool from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain from langchain.llms.base import BaseLLM from langchain.prompts import PromptTemplate +from langchain.schema import AgentAction, AgentFinish from langchain.tools.base import BaseTool FINAL_ANSWER_ACTION = "Final Answer:" @@ -30,6 +31,24 @@ class ChainConfig(NamedTuple): action_description: str +class ReActOutputParser(AgentOutputParser): + def parse(self, text: str) -> Union[AgentFinish, AgentAction]: + if FINAL_ANSWER_ACTION in text: + return AgentFinish( + {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, log=text + ) + # \s matches against tab/newline/whitespace + regex = r"Action: (.*?)[\n]*Action Input:[\s]*(.*)" + match = re.search(regex, text, re.DOTALL) + if not match: + raise ValueError(f"Could not parse LLM output: `{text}`") + action = match.group(1).strip() + action_input = match.group(2) + return AgentAction( + tool=action, tool_input=action_input.strip(" ").strip('"'), log=text + ) + + def get_action_and_input(llm_output: str) -> Tuple[str, str]: """Parse out the action and input from the LLM output. @@ -38,16 +57,11 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]: The string starting with "Action:" and the following string starting with "Action Input:" should be separated by a newline. """ - if FINAL_ANSWER_ACTION in llm_output: - return "Final Answer", llm_output.split(FINAL_ANSWER_ACTION)[-1].strip() - # \s matches against tab/newline/whitespace - regex = r"Action: (.*?)[\n]*Action Input:[\s]*(.*)" - match = re.search(regex, llm_output, re.DOTALL) - if not match: - raise ValueError(f"Could not parse LLM output: `{llm_output}`") - action = match.group(1).strip() - action_input = match.group(2) - return action, action_input.strip(" ").strip('"') + result = ReActOutputParser().parse(llm_output) + if isinstance(result, AgentFinish): + return "Final Answer", result.return_values["output"] + else: + return result.tool, result.tool_input class ZeroShotAgent(Agent):