Compare commits

...

1 Commits

Author SHA1 Message Date
Harrison Chase
fc580fbca3 mrkl parser 2023-03-30 23:28:57 -07:00

View File

@@ -2,15 +2,16 @@
from __future__ import annotations
import re
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Union
from langchain.agents.agent import Agent, AgentExecutor
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate
from langchain.schema import AgentAction, AgentFinish
from langchain.tools.base import BaseTool
FINAL_ANSWER_ACTION = "Final Answer:"
@@ -30,6 +31,24 @@ class ChainConfig(NamedTuple):
action_description: str
class ReActOutputParser(AgentOutputParser):
def parse(self, text: str) -> Union[AgentFinish, AgentAction]:
if FINAL_ANSWER_ACTION in text:
return AgentFinish(
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, log=text
)
# \s matches against tab/newline/whitespace
regex = r"Action: (.*?)[\n]*Action Input:[\s]*(.*)"
match = re.search(regex, text, re.DOTALL)
if not match:
raise ValueError(f"Could not parse LLM output: `{text}`")
action = match.group(1).strip()
action_input = match.group(2)
return AgentAction(
tool=action, tool_input=action_input.strip(" ").strip('"'), log=text
)
def get_action_and_input(llm_output: str) -> Tuple[str, str]:
"""Parse out the action and input from the LLM output.
@@ -38,16 +57,11 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]:
The string starting with "Action:" and the following string starting
with "Action Input:" should be separated by a newline.
"""
if FINAL_ANSWER_ACTION in llm_output:
return "Final Answer", llm_output.split(FINAL_ANSWER_ACTION)[-1].strip()
# \s matches against tab/newline/whitespace
regex = r"Action: (.*?)[\n]*Action Input:[\s]*(.*)"
match = re.search(regex, llm_output, re.DOTALL)
if not match:
raise ValueError(f"Could not parse LLM output: `{llm_output}`")
action = match.group(1).strip()
action_input = match.group(2)
return action, action_input.strip(" ").strip('"')
result = ReActOutputParser().parse(llm_output)
if isinstance(result, AgentFinish):
return "Final Answer", result.return_values["output"]
else:
return result.tool, result.tool_input
class ZeroShotAgent(Agent):