From ac208f85c81febdc62894c6883823546556e600f Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 17 Dec 2022 20:29:12 -0800 Subject: [PATCH] agent refactor --- docs/examples/agents/custom_agent.ipynb | 2 +- langchain/agents/agent.py | 164 ++++-------------- langchain/agents/loading.py | 17 +- langchain/agents/mrkl/base.py | 12 +- langchain/agents/mrkl/prompt.py | 2 +- langchain/agents/react/base.py | 30 ++-- langchain/agents/self_ask_with_search/base.py | 21 ++- .../agents/self_ask_with_search/prompt.py | 2 +- langchain/schema.py | 1 + tests/unit_tests/agents/test_react.py | 15 +- 10 files changed, 95 insertions(+), 171 deletions(-) diff --git a/docs/examples/agents/custom_agent.ipynb b/docs/examples/agents/custom_agent.ipynb index cc6f135bec5..fa9c3c97546 100644 --- a/docs/examples/agents/custom_agent.ipynb +++ b/docs/examples/agents/custom_agent.ipynb @@ -224,7 +224,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.10.8" } }, "nbformat": 4, diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 99137fd4b60..538793b8127 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -2,10 +2,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, List, Optional, Tuple +from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple, Union from pydantic import BaseModel, root_validator +import langchain from langchain.agents.input import ChainedInput from langchain.agents.tools import Tool from langchain.chains.base import Chain @@ -14,11 +15,6 @@ from langchain.input import get_color_mapping from langchain.llms.base import LLM from langchain.prompts.base import BasePromptTemplate from langchain.schema import AgentAction, AgentFinish -import langchain - - -from typing import NamedTuple - class Planner(BaseModel): @@ -28,8 +24,9 @@ class Planner(BaseModel): a variable called "agent_scratchpad" where the agent can put its intermediary work. """ + llm_chain: LLMChain - return_values: List[str] + return_values: List[str] = ["output"] @abstractmethod def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]: @@ -43,7 +40,9 @@ class Planner(BaseModel): def _stop(self) -> List[str]: return [f"\n{self.observation_prefix}"] - def plan(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> AgentAction: + def plan( + self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + ) -> Union[AgentFinish, AgentAction]: """Given input, decided what to do. Args: @@ -68,11 +67,18 @@ class Planner(BaseModel): full_output += output parsed_output = self._extract_tool_and_input(full_output) tool, tool_input = parsed_output + if tool == self.finish_tool_name: + return AgentFinish({"output": tool_input}, full_output) return AgentAction(tool, tool_input, full_output) - def prepare_for_new_call(self): + def prepare_for_new_call(self) -> None: pass + @property + def finish_tool_name(self) -> str: + """Name of the tool to use to finish the chain.""" + return "Final Answer" + @property def input_keys(self) -> List[str]: """Return the input keys. @@ -101,8 +107,25 @@ class Planner(BaseModel): def llm_prefix(self) -> str: """Prefix to append the LLM call with.""" + @abstractmethod + @classmethod + def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate: + """Create a prompt for this class.""" -class NewAgent(Chain, BaseModel): + @classmethod + def _validate_tools(cls, tools: List[Tool]) -> None: + """Validate that appropriate tools are passed in.""" + pass + + @classmethod + def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> Planner: + """Construct an agent from an LLM and tools.""" + cls._validate_tools(tools) + llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools)) + return cls(llm_chain=llm_chain) + + +class Agent(Chain, BaseModel): planner: Planner tools: List[Tool] @@ -138,7 +161,7 @@ class NewAgent(Chain, BaseModel): [tool.name for tool in self.tools], excluded_colors=["green"] ) planner_inputs = inputs.copy() - intermediate_steps = [] + intermediate_steps: List[Tuple[AgentAction, str]] = [] # We now enter the agent loop (until it returns something). while True: # Call the LLM to see what to do. @@ -165,122 +188,3 @@ class NewAgent(Chain, BaseModel): if self.verbose: langchain.logger.log_agent_observation(observation, color=color) intermediate_steps.append((output, observation)) - - -class Agent(Chain, BaseModel, ABC): - """Agent that uses an LLM.""" - - prompt: ClassVar[BasePromptTemplate] - llm_chain: LLMChain - tools: List[Tool] - return_intermediate_steps: bool = False - input_key: str = "input" #: :meta private: - output_key: str = "output" #: :meta private: - - @property - def output_keys(self) -> List[str]: - """Return the singular output key. - - :meta private: - """ - if self.return_intermediate_steps: - return [self.output_key, "intermediate_steps"] - else: - return [self.output_key] - - @root_validator() - def validate_prompt(cls, values: Dict) -> Dict: - """Validate that prompt matches format.""" - prompt = values["llm_chain"].prompt - if "agent_scratchpad" not in prompt.input_variables: - raise ValueError( - "`agent_scratchpad` should be a variable in prompt.input_variables" - ) - return values - - @property - @abstractmethod - def observation_prefix(self) -> str: - """Prefix to append the observation with.""" - - @property - @abstractmethod - def llm_prefix(self) -> str: - """Prefix to append the LLM call with.""" - - @property - def finish_tool_name(self) -> str: - """Name of the tool to use to finish the chain.""" - return "Final Answer" - - @property - def starter_string(self) -> str: - """Put this string after user input but before first LLM call.""" - return "\n" - - - - - - @classmethod - def _validate_tools(cls, tools: List[Tool]) -> None: - """Validate that appropriate tools are passed in.""" - pass - - @classmethod - def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate: - """Create a prompt for this class.""" - return cls.prompt - - def _prepare_for_new_call(self) -> None: - pass - - @classmethod - def from_llm_and_tools(cls, llm: LLM, tools: List[Tool], **kwargs: Any) -> Agent: - """Construct an agent from an LLM and tools.""" - cls._validate_tools(tools) - llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools)) - return cls(llm_chain=llm_chain, tools=tools, **kwargs) - - def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: - """Run text through and get agent response.""" - # Do any preparation necessary when receiving a new input. - self._prepare_for_new_call() - # Construct a mapping of tool name to tool for easy lookup - name_to_tool_map = {tool.name: tool.func for tool in self.tools} - # We use the ChainedInput class to iteratively add to the input over time. - chained_input = ChainedInput(self.llm_prefix, verbose=self.verbose) - # We construct a mapping from each tool to a color, used for logging. - color_mapping = get_color_mapping( - [tool.name for tool in self.tools], excluded_colors=["green"] - ) - # We now enter the agent loop (until it returns something). - while True: - # Call the LLM to see what to do. - output = self.get_action(chained_input.input, inputs) - # If the tool chosen is the finishing tool, then we end and return. - if output.tool == self.finish_tool_name: - final_output: dict = {self.output_key: output.tool_input} - if self.return_intermediate_steps: - final_output[ - "intermediate_steps" - ] = chained_input.intermediate_steps - return final_output - # Other we add the log to the Chained Input. - chained_input.add_action(output, color="green") - # And then we lookup the tool - if output.tool in name_to_tool_map: - chain = name_to_tool_map[output.tool] - # We then call the tool on the tool input to get an observation - observation = chain(output.tool_input) - color = color_mapping[output.tool] - else: - observation = f"{output.tool} is not a valid tool, try another one." - color = None - # We then log the observation - chained_input.add_observation( - observation, - self.observation_prefix, - self.llm_prefix, - color=color, - ) diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index 75d24d1fbdc..3d1f9205236 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -1,17 +1,17 @@ """Load agent.""" from typing import Any, List -from langchain.agents.agent import Agent -from langchain.agents.mrkl.base import ZeroShotAgent -from langchain.agents.react.base import ReActDocstoreAgent -from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent +from langchain.agents.agent import Agent, Planner +from langchain.agents.mrkl.base import ZeroShotPlanner +from langchain.agents.react.base import ReActDocstorePlanner +from langchain.agents.self_ask_with_search.base import SelfAskWithSearchPlanner from langchain.agents.tools import Tool from langchain.llms.base import LLM AGENT_TO_CLASS = { - "zero-shot-react-description": ZeroShotAgent, - "react-docstore": ReActDocstoreAgent, - "self-ask-with-search": SelfAskWithSearchAgent, + "zero-shot-react-description": ZeroShotPlanner, + "react-docstore": ReActDocstorePlanner, + "self-ask-with-search": SelfAskWithSearchPlanner, } @@ -39,4 +39,5 @@ def initialize_agent( f"Valid types are: {AGENT_TO_CLASS.keys()}." ) agent_cls = AGENT_TO_CLASS[agent] - return agent_cls.from_llm_and_tools(llm, tools, **kwargs) + planner = agent_cls.from_llm_and_tools(llm, tools, **kwargs) + return Agent(planner=planner, tools=tools) diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 958ad53fd0a..701aaf2dc3c 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, Callable, List, NamedTuple, Optional, Tuple -from langchain.agents.agent import Agent +from langchain.agents.agent import Agent, Planner from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.tools import Tool from langchain.llms.base import LLM @@ -47,7 +47,7 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]: return action, action_input.strip(" ").strip('"') -class ZeroShotAgent(Agent): +class ZeroShotPlanner(Planner): """Agent for the MRKL chain.""" @property @@ -101,7 +101,10 @@ class ZeroShotAgent(Agent): return get_action_and_input(text) -class MRKLChain(ZeroShotAgent): +ZeroShotAgent = ZeroShotPlanner + + +class MRKLChain(Agent): """Chain that implements the MRKL system. Example: @@ -156,4 +159,5 @@ class MRKLChain(ZeroShotAgent): Tool(name=c.action_name, func=c.action, description=c.action_description) for c in chains ] - return cls.from_llm_and_tools(llm, tools, **kwargs) + planner = ZeroShotPlanner.from_llm_and_tools(llm, tools) + return cls(planner=planner, tools=tools, **kwargs) diff --git a/langchain/agents/mrkl/prompt.py b/langchain/agents/mrkl/prompt.py index d3e911e0226..db6827b5ec7 100644 --- a/langchain/agents/mrkl/prompt.py +++ b/langchain/agents/mrkl/prompt.py @@ -13,4 +13,4 @@ Final Answer: the final answer to the original input question""" SUFFIX = """Begin! Question: {input} -{agent_scratchpad}""" +Thought:{agent_scratchpad}""" diff --git a/langchain/agents/react/base.py b/langchain/agents/react/base.py index ca380e1a2aa..d30cc4fc791 100644 --- a/langchain/agents/react/base.py +++ b/langchain/agents/react/base.py @@ -4,7 +4,7 @@ from typing import Any, ClassVar, List, Optional, Tuple from pydantic import BaseModel -from langchain.agents.agent import Agent +from langchain.agents.agent import Agent, Planner from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT from langchain.agents.react.wiki_prompt import WIKI_PROMPT from langchain.agents.tools import Tool @@ -15,10 +15,13 @@ from langchain.llms.base import LLM from langchain.prompts.base import BasePromptTemplate -class ReActDocstoreAgent(Agent, BaseModel): +class ReActDocstorePlanner(Planner, BaseModel): """Agent for the ReAct chin.""" - prompt: ClassVar[BasePromptTemplate] = WIKI_PROMPT + @classmethod + def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate: + """Return default prompt.""" + return WIKI_PROMPT i: int = 1 @@ -72,6 +75,9 @@ class ReActDocstoreAgent(Agent, BaseModel): return f"Thought {self.i}:" +ReActDocstoreAgent = ReActDocstorePlanner + + class DocstoreExplorer: """Class to assist with exploration of a document store.""" @@ -97,12 +103,13 @@ class DocstoreExplorer: return self.document.lookup(term) -class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel): +class ReActTextWorldPlanner(ReActDocstorePlanner, BaseModel): """Agent for the ReAct TextWorld chain.""" - prompt: ClassVar[BasePromptTemplate] = TEXTWORLD_PROMPT - - i: int = 1 + @classmethod + def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate: + """Return default prompt.""" + return TEXTWORLD_PROMPT @classmethod def _validate_tools(cls, tools: List[Tool]) -> None: @@ -113,7 +120,10 @@ class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel): raise ValueError(f"Tool name should be Play, got {tool_names}") -class ReActChain(ReActDocstoreAgent): +ReActTextWorldAgent = ReActTextWorldPlanner + + +class ReActChain(Agent): """Chain that implements the ReAct paper. Example: @@ -130,5 +140,5 @@ class ReActChain(ReActDocstoreAgent): Tool(name="Search", func=docstore_explorer.search), Tool(name="Lookup", func=docstore_explorer.lookup), ] - llm_chain = LLMChain(llm=llm, prompt=WIKI_PROMPT) - super().__init__(llm_chain=llm_chain, tools=tools, **kwargs) + planner = ReActDocstorePlanner.from_llm_and_tools(llm, tools) + super().__init__(planner=planner, tools=tools, **kwargs) diff --git a/langchain/agents/self_ask_with_search/base.py b/langchain/agents/self_ask_with_search/base.py index e69898aa745..5ccbfad839f 100644 --- a/langchain/agents/self_ask_with_search/base.py +++ b/langchain/agents/self_ask_with_search/base.py @@ -1,19 +1,21 @@ """Chain that does self ask with search.""" -from typing import Any, ClassVar, List, Optional, Tuple +from typing import Any, List, Optional, Tuple -from langchain.agents.agent import Agent +from langchain.agents.agent import Agent, Planner from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.tools import Tool -from langchain.chains.llm import LLMChain from langchain.llms.base import LLM from langchain.prompts.base import BasePromptTemplate from langchain.serpapi import SerpAPIWrapper -class SelfAskWithSearchAgent(Agent): +class SelfAskWithSearchPlanner(Planner): """Agent for the self-ask-with-search paper.""" - prompt: ClassVar[BasePromptTemplate] = PROMPT + @classmethod + def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate: + """Prompt does not depend on tools.""" + return PROMPT @classmethod def _validate_tools(cls, tools: List[Tool]) -> None: @@ -61,7 +63,10 @@ class SelfAskWithSearchAgent(Agent): return "Are follow up questions needed here:" -class SelfAskWithSearchChain(SelfAskWithSearchAgent): +SelfAskWithSearchAgent = SelfAskWithSearchPlanner + + +class SelfAskWithSearchChain(Agent): """Chain that does self ask with search. Example: @@ -75,5 +80,5 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent): def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any): """Initialize with just an LLM and a search chain.""" search_tool = Tool(name="Intermediate Answer", func=search_chain.run) - llm_chain = LLMChain(llm=llm, prompt=PROMPT) - super().__init__(llm_chain=llm_chain, tools=[search_tool], **kwargs) + planner = SelfAskWithSearchPlanner.from_llm_and_tools(llm, [search_tool]) + super().__init__(planner=planner, tools=[search_tool], **kwargs) diff --git a/langchain/agents/self_ask_with_search/prompt.py b/langchain/agents/self_ask_with_search/prompt.py index e511a64b9da..c82de28dfbe 100644 --- a/langchain/agents/self_ask_with_search/prompt.py +++ b/langchain/agents/self_ask_with_search/prompt.py @@ -38,7 +38,7 @@ Intermediate answer: New Zealand. So the final answer is: No Question: {input} -{agent_scratchpad}""" +Are followup questions needed here:{agent_scratchpad}""" PROMPT = PromptTemplate( input_variables=["input", "agent_scratchpad"], template=_DEFAULT_TEMPLATE ) diff --git a/langchain/schema.py b/langchain/schema.py index b620dc17313..31cf1cdc82d 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -13,6 +13,7 @@ class AgentAction(NamedTuple): class AgentFinish(NamedTuple): """Agent's return value.""" + return_values: dict log: str diff --git a/tests/unit_tests/agents/test_react.py b/tests/unit_tests/agents/test_react.py index cd937db846d..05c6298a2cd 100644 --- a/tests/unit_tests/agents/test_react.py +++ b/tests/unit_tests/agents/test_react.py @@ -10,6 +10,7 @@ from langchain.docstore.base import Docstore from langchain.docstore.document import Document from langchain.llms.base import LLM from langchain.prompts.prompt import PromptTemplate +from langchain.schema import AgentAction _PAGE_CONTENT = """This is a page about LangChain. @@ -61,10 +62,9 @@ def test_predict_until_observation_normal() -> None: Tool("Lookup", lambda x: x), ] agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools) - output = agent.get_action("", {"input": ""}) - assert output.log == outputs[0] - assert output.tool == "Search" - assert output.tool_input == "foo" + output = agent.plan([], input="") + expected_output = AgentAction("Search", "foo", outputs[0]) + assert output == expected_output def test_predict_until_observation_repeat() -> None: @@ -76,10 +76,9 @@ def test_predict_until_observation_repeat() -> None: Tool("Lookup", lambda x: x), ] agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools) - output = agent.get_action("", {"input": ""}) - assert output.log == "foo\nAction 1: Search[foo]" - assert output.tool == "Search" - assert output.tool_input == "foo" + output = agent.plan([], input="") + expected_output = AgentAction("Search", "foo", "foo\nAction 1: Search[foo]") + assert output == expected_output def test_react_chain() -> None: