repo: https://github.com/hwchase17/langchain.git

commit ac208f85c8
parent 85e7c5fd6c

    agent refactor
@@ -224,7 +224,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.6"
+   "version": "3.10.8"
   }
  },
  "nbformat": 4,
@@ -2,10 +2,11 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Any, ClassVar, Dict, List, Optional, Tuple
+from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple, Union
 
 from pydantic import BaseModel, root_validator
 
+import langchain
 from langchain.agents.input import ChainedInput
 from langchain.agents.tools import Tool
 from langchain.chains.base import Chain
@@ -14,11 +15,6 @@ from langchain.input import get_color_mapping
 from langchain.llms.base import LLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.schema import AgentAction, AgentFinish
-import langchain
-
-
-from typing import NamedTuple
-
 
 
 class Planner(BaseModel):
@@ -28,8 +24,9 @@ class Planner(BaseModel):
     a variable called "agent_scratchpad" where the agent can put its
     intermediary work.
     """
 
     llm_chain: LLMChain
-    return_values: List[str]
+    return_values: List[str] = ["output"]
 
     @abstractmethod
     def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
@@ -43,7 +40,9 @@ class Planner(BaseModel):
     def _stop(self) -> List[str]:
         return [f"\n{self.observation_prefix}"]
 
-    def plan(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> AgentAction:
+    def plan(
+        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+    ) -> Union[AgentFinish, AgentAction]:
         """Given input, decide what to do.
 
         Args:
@@ -68,11 +67,18 @@ class Planner(BaseModel):
             full_output += output
         parsed_output = self._extract_tool_and_input(full_output)
         tool, tool_input = parsed_output
+        if tool == self.finish_tool_name:
+            return AgentFinish({"output": tool_input}, full_output)
         return AgentAction(tool, tool_input, full_output)
 
-    def prepare_for_new_call(self):
+    def prepare_for_new_call(self) -> None:
         pass
 
+    @property
+    def finish_tool_name(self) -> str:
+        """Name of the tool to use to finish the chain."""
+        return "Final Answer"
+
     @property
     def input_keys(self) -> List[str]:
         """Return the input keys.
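
The two added lines above are the pivot of the refactor: `plan` now parses the LLM output itself and signals completion by returning an `AgentFinish` instead of another `AgentAction`. A minimal sketch of the driver loop this enables, assuming stand-in `planner` and `tools` objects; only `plan(intermediate_steps, input=...)` and the two schema types are taken from this diff:

    from langchain.schema import AgentFinish

    def run_agent(planner, tools, user_input: str) -> dict:
        # Hypothetical driver, not library code: loop until the planner
        # returns an AgentFinish, feeding each observation back in.
        name_to_func = {tool.name: tool.func for tool in tools}
        intermediate_steps = []
        while True:
            decision = planner.plan(intermediate_steps, input=user_input)
            if isinstance(decision, AgentFinish):
                return decision.return_values
            observation = name_to_func[decision.tool](decision.tool_input)
            intermediate_steps.append((decision, observation))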
@@ -101,8 +107,25 @@ class Planner(BaseModel):
     def llm_prefix(self) -> str:
         """Prefix to append the LLM call with."""
 
+    @abstractmethod
+    @classmethod
+    def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
+        """Create a prompt for this class."""
 
-class NewAgent(Chain, BaseModel):
+    @classmethod
+    def _validate_tools(cls, tools: List[Tool]) -> None:
+        """Validate that appropriate tools are passed in."""
+        pass
+
+    @classmethod
+    def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> Planner:
+        """Construct an agent from an LLM and tools."""
+        cls._validate_tools(tools)
+        llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
+        return cls(llm_chain=llm_chain)
+
+
+class Agent(Chain, BaseModel):
 
     planner: Planner
     tools: List[Tool]
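
With `create_prompt` and `from_llm_and_tools` as classmethods, each Planner subclass becomes its own factory: derive a prompt from the tools, wrap it in an LLMChain, and hand that to the constructor. A self-contained mimic of that wiring with toy classes (these are not the langchain types):

    from dataclasses import dataclass
    from typing import Callable, List

    @dataclass
    class MiniTool:
        name: str
        func: Callable[[str], str]

    @dataclass
    class MiniLLMChain:
        llm: object
        prompt: str

    class MiniPlanner:
        def __init__(self, llm_chain: MiniLLMChain):
            self.llm_chain = llm_chain

        @classmethod
        def create_prompt(cls, tools: List[MiniTool]) -> str:
            # Real planners build a BasePromptTemplate; a string stands in here.
            return "Tools available: " + ", ".join(t.name for t in tools)

        @classmethod
        def from_llm_and_tools(cls, llm: object, tools: List[MiniTool]) -> "MiniPlanner":
            return cls(MiniLLMChain(llm=llm, prompt=cls.create_prompt(tools)))

    planner = MiniPlanner.from_llm_and_tools(object(), [MiniTool("Search", str)])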
@@ -138,7 +161,7 @@ class NewAgent(Chain, BaseModel):
             [tool.name for tool in self.tools], excluded_colors=["green"]
         )
         planner_inputs = inputs.copy()
-        intermediate_steps = []
+        intermediate_steps: List[Tuple[AgentAction, str]] = []
         # We now enter the agent loop (until it returns something).
         while True:
             # Call the LLM to see what to do.
@@ -165,122 +188,3 @@ class NewAgent(Chain, BaseModel):
             if self.verbose:
                 langchain.logger.log_agent_observation(observation, color=color)
             intermediate_steps.append((output, observation))
-
-
-class Agent(Chain, BaseModel, ABC):
-    """Agent that uses an LLM."""
-
-    prompt: ClassVar[BasePromptTemplate]
-    llm_chain: LLMChain
-    tools: List[Tool]
-    return_intermediate_steps: bool = False
-    input_key: str = "input"  #: :meta private:
-    output_key: str = "output"  #: :meta private:
-
-    @property
-    def output_keys(self) -> List[str]:
-        """Return the singular output key.
-
-        :meta private:
-        """
-        if self.return_intermediate_steps:
-            return [self.output_key, "intermediate_steps"]
-        else:
-            return [self.output_key]
-
-    @root_validator()
-    def validate_prompt(cls, values: Dict) -> Dict:
-        """Validate that prompt matches format."""
-        prompt = values["llm_chain"].prompt
-        if "agent_scratchpad" not in prompt.input_variables:
-            raise ValueError(
-                "`agent_scratchpad` should be a variable in prompt.input_variables"
-            )
-        return values
-
-    @property
-    @abstractmethod
-    def observation_prefix(self) -> str:
-        """Prefix to append the observation with."""
-
-    @property
-    @abstractmethod
-    def llm_prefix(self) -> str:
-        """Prefix to append the LLM call with."""
-
-    @property
-    def finish_tool_name(self) -> str:
-        """Name of the tool to use to finish the chain."""
-        return "Final Answer"
-
-    @property
-    def starter_string(self) -> str:
-        """Put this string after user input but before first LLM call."""
-        return "\n"
-
-
-
-
-
-    @classmethod
-    def _validate_tools(cls, tools: List[Tool]) -> None:
-        """Validate that appropriate tools are passed in."""
-        pass
-
-    @classmethod
-    def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
-        """Create a prompt for this class."""
-        return cls.prompt
-
-    def _prepare_for_new_call(self) -> None:
-        pass
-
-    @classmethod
-    def from_llm_and_tools(cls, llm: LLM, tools: List[Tool], **kwargs: Any) -> Agent:
-        """Construct an agent from an LLM and tools."""
-        cls._validate_tools(tools)
-        llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
-        return cls(llm_chain=llm_chain, tools=tools, **kwargs)
-
-    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
-        """Run text through and get agent response."""
-        # Do any preparation necessary when receiving a new input.
-        self._prepare_for_new_call()
-        # Construct a mapping of tool name to tool for easy lookup
-        name_to_tool_map = {tool.name: tool.func for tool in self.tools}
-        # We use the ChainedInput class to iteratively add to the input over time.
-        chained_input = ChainedInput(self.llm_prefix, verbose=self.verbose)
-        # We construct a mapping from each tool to a color, used for logging.
-        color_mapping = get_color_mapping(
-            [tool.name for tool in self.tools], excluded_colors=["green"]
-        )
-        # We now enter the agent loop (until it returns something).
-        while True:
-            # Call the LLM to see what to do.
-            output = self.get_action(chained_input.input, inputs)
-            # If the tool chosen is the finishing tool, then we end and return.
-            if output.tool == self.finish_tool_name:
-                final_output: dict = {self.output_key: output.tool_input}
-                if self.return_intermediate_steps:
-                    final_output[
-                        "intermediate_steps"
-                    ] = chained_input.intermediate_steps
-                return final_output
-            # Other we add the log to the Chained Input.
-            chained_input.add_action(output, color="green")
-            # And then we lookup the tool
-            if output.tool in name_to_tool_map:
-                chain = name_to_tool_map[output.tool]
-                # We then call the tool on the tool input to get an observation
-                observation = chain(output.tool_input)
-                color = color_mapping[output.tool]
-            else:
-                observation = f"{output.tool} is not a valid tool, try another one."
-                color = None
-            # We then log the observation
-            chained_input.add_observation(
-                observation,
-                self.observation_prefix,
-                self.llm_prefix,
-                color=color,
-            )
@@ -1,17 +1,17 @@
 """Load agent."""
 from typing import Any, List
 
-from langchain.agents.agent import Agent
-from langchain.agents.mrkl.base import ZeroShotAgent
-from langchain.agents.react.base import ReActDocstoreAgent
-from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
+from langchain.agents.agent import Agent, Planner
+from langchain.agents.mrkl.base import ZeroShotPlanner
+from langchain.agents.react.base import ReActDocstorePlanner
+from langchain.agents.self_ask_with_search.base import SelfAskWithSearchPlanner
 from langchain.agents.tools import Tool
 from langchain.llms.base import LLM
 
 AGENT_TO_CLASS = {
-    "zero-shot-react-description": ZeroShotAgent,
-    "react-docstore": ReActDocstoreAgent,
-    "self-ask-with-search": SelfAskWithSearchAgent,
+    "zero-shot-react-description": ZeroShotPlanner,
+    "react-docstore": ReActDocstorePlanner,
+    "self-ask-with-search": SelfAskWithSearchPlanner,
 }
 
 
@@ -39,4 +39,5 @@ def initialize_agent(
             f"Valid types are: {AGENT_TO_CLASS.keys()}."
         )
     agent_cls = AGENT_TO_CLASS[agent]
-    return agent_cls.from_llm_and_tools(llm, tools, **kwargs)
+    planner = agent_cls.from_llm_and_tools(llm, tools, **kwargs)
+    return Agent(planner=planner, tools=tools)
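
End to end, `initialize_agent` now returns an executor rather than a bare planner: the string key selects a Planner class, and the caller gets an `Agent` wrapping it. A hedged usage sketch; it assumes `llm` and `tools` exist and that the keyword names match the locals in this hunk (the full signature is not shown in the diff):

    # Hypothetical usage; `llm` and `tools` must be constructed elsewhere,
    # and the parameter names are inferred, not confirmed by this diff.
    agent = initialize_agent(tools=tools, llm=llm, agent="zero-shot-react-description")
    # The result is an Agent (a Chain) wrapping a ZeroShotPlanner, so it is
    # invoked like any other chain.
    result = agent({"input": "What is 2 + 2?"})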
@@ -3,7 +3,7 @@ from __future__ import annotations
 
 from typing import Any, Callable, List, NamedTuple, Optional, Tuple
 
-from langchain.agents.agent import Agent
+from langchain.agents.agent import Agent, Planner
 from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
 from langchain.agents.tools import Tool
 from langchain.llms.base import LLM
@@ -47,7 +47,7 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]:
     return action, action_input.strip(" ").strip('"')
 
 
-class ZeroShotAgent(Agent):
+class ZeroShotPlanner(Planner):
     """Agent for the MRKL chain."""
 
     @property
@@ -101,7 +101,10 @@ class ZeroShotAgent(Agent):
         return get_action_and_input(text)
 
 
-class MRKLChain(ZeroShotAgent):
+ZeroShotAgent = ZeroShotPlanner
+
+
+class MRKLChain(Agent):
     """Chain that implements the MRKL system.
 
     Example:
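
`ZeroShotAgent = ZeroShotPlanner` is a plain module-level alias, so existing imports keep resolving after the rename; both names bind the same class object, not a subclass. A generic illustration of the idiom:

    class ZeroShotPlanner:
        pass

    ZeroShotAgent = ZeroShotPlanner  # alias: one class, two names
    assert ZeroShotAgent is ZeroShotPlanner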
@@ -156,4 +159,5 @@ class MRKLChain(ZeroShotAgent):
             Tool(name=c.action_name, func=c.action, description=c.action_description)
             for c in chains
         ]
-        return cls.from_llm_and_tools(llm, tools, **kwargs)
+        planner = ZeroShotPlanner.from_llm_and_tools(llm, tools)
+        return cls(planner=planner, tools=tools, **kwargs)
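
MRKLChain here, and ReActChain and SelfAskWithSearchChain in the hunks below, all adopt the same composition recipe: build the matching planner, then initialize the Agent base with `planner=` and `tools=`. A compact runnable mimic of the recipe with stand-in classes:

    class MiniPlanner:
        @classmethod
        def from_llm_and_tools(cls, llm, tools):
            return cls()

    class MiniAgent:
        def __init__(self, planner, tools, **kwargs):
            self.planner, self.tools = planner, tools

    class MiniMRKLChain(MiniAgent):
        @classmethod
        def from_llm_and_tools(cls, llm, tools, **kwargs):
            # Compose: planner construction first, then the executor.
            planner = MiniPlanner.from_llm_and_tools(llm, tools)
            return cls(planner=planner, tools=tools, **kwargs)

    chain = MiniMRKLChain.from_llm_and_tools(llm=None, tools=[])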
@@ -13,4 +13,4 @@ Final Answer: the final answer to the original input question"""
 SUFFIX = """Begin!
 
 Question: {input}
-{agent_scratchpad}"""
+Thought:{agent_scratchpad}"""
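
The one-word suffix change matters for generation: ending the template with `Thought:` primes the model to continue directly with its reasoning instead of opening an unprefixed line. How the amended SUFFIX renders with an empty scratchpad (plain Python, no langchain needed):

    suffix = "Begin!\n\nQuestion: {input}\nThought:{agent_scratchpad}"
    print(suffix.format(input="What is 2 + 2?", agent_scratchpad=""))
    # Begin!
    #
    # Question: What is 2 + 2?
    # Thought: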
@@ -4,7 +4,7 @@ from typing import Any, ClassVar, List, Optional, Tuple
 
 from pydantic import BaseModel
 
-from langchain.agents.agent import Agent
+from langchain.agents.agent import Agent, Planner
 from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
 from langchain.agents.react.wiki_prompt import WIKI_PROMPT
 from langchain.agents.tools import Tool
@@ -15,10 +15,13 @@ from langchain.llms.base import LLM
 from langchain.prompts.base import BasePromptTemplate
 
 
-class ReActDocstoreAgent(Agent, BaseModel):
+class ReActDocstorePlanner(Planner, BaseModel):
     """Agent for the ReAct chain."""
 
-    prompt: ClassVar[BasePromptTemplate] = WIKI_PROMPT
+    @classmethod
+    def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
+        """Return default prompt."""
+        return WIKI_PROMPT
 
     i: int = 1
 
@@ -72,6 +75,9 @@ class ReActDocstoreAgent(Agent, BaseModel):
         return f"Thought {self.i}:"
 
 
+ReActDocstoreAgent = ReActDocstorePlanner
+
+
 class DocstoreExplorer:
     """Class to assist with exploration of a document store."""
 
@@ -97,12 +103,13 @@ class DocstoreExplorer:
         return self.document.lookup(term)
 
 
-class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel):
+class ReActTextWorldPlanner(ReActDocstorePlanner, BaseModel):
     """Agent for the ReAct TextWorld chain."""
 
-    prompt: ClassVar[BasePromptTemplate] = TEXTWORLD_PROMPT
-
-    i: int = 1
+    @classmethod
+    def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
+        """Return default prompt."""
+        return TEXTWORLD_PROMPT
 
     @classmethod
     def _validate_tools(cls, tools: List[Tool]) -> None:
@@ -113,7 +120,10 @@ class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel):
             raise ValueError(f"Tool name should be Play, got {tool_names}")
 
 
-class ReActChain(ReActDocstoreAgent):
+ReActTextWorldAgent = ReActTextWorldPlanner
+
+
+class ReActChain(Agent):
     """Chain that implements the ReAct paper.
 
     Example:
@@ -130,5 +140,5 @@ class ReActChain(ReActDocstoreAgent):
             Tool(name="Search", func=docstore_explorer.search),
             Tool(name="Lookup", func=docstore_explorer.lookup),
         ]
-        llm_chain = LLMChain(llm=llm, prompt=WIKI_PROMPT)
-        super().__init__(llm_chain=llm_chain, tools=tools, **kwargs)
+        planner = ReActDocstorePlanner.from_llm_and_tools(llm, tools)
+        super().__init__(planner=planner, tools=tools, **kwargs)
@@ -1,19 +1,21 @@
 """Chain that does self ask with search."""
-from typing import Any, ClassVar, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
-from langchain.agents.agent import Agent
+from langchain.agents.agent import Agent, Planner
 from langchain.agents.self_ask_with_search.prompt import PROMPT
 from langchain.agents.tools import Tool
-from langchain.chains.llm import LLMChain
 from langchain.llms.base import LLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.serpapi import SerpAPIWrapper
 
 
-class SelfAskWithSearchAgent(Agent):
+class SelfAskWithSearchPlanner(Planner):
     """Agent for the self-ask-with-search paper."""
 
-    prompt: ClassVar[BasePromptTemplate] = PROMPT
+    @classmethod
+    def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
+        """Prompt does not depend on tools."""
+        return PROMPT
 
     @classmethod
     def _validate_tools(cls, tools: List[Tool]) -> None:
@@ -61,7 +63,10 @@ class SelfAskWithSearchAgent(Agent):
         return "Are follow up questions needed here:"
 
 
-class SelfAskWithSearchChain(SelfAskWithSearchAgent):
+SelfAskWithSearchAgent = SelfAskWithSearchPlanner
+
+
+class SelfAskWithSearchChain(Agent):
     """Chain that does self ask with search.
 
     Example:
@@ -75,5 +80,5 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent):
     def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
         """Initialize with just an LLM and a search chain."""
         search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
-        llm_chain = LLMChain(llm=llm, prompt=PROMPT)
-        super().__init__(llm_chain=llm_chain, tools=[search_tool], **kwargs)
+        planner = SelfAskWithSearchPlanner.from_llm_and_tools(llm, [search_tool])
+        super().__init__(planner=planner, tools=[search_tool], **kwargs)
@@ -38,7 +38,7 @@ Intermediate answer: New Zealand.
 So the final answer is: No
 
 Question: {input}
-{agent_scratchpad}"""
+Are followup questions needed here:{agent_scratchpad}"""
 PROMPT = PromptTemplate(
     input_variables=["input", "agent_scratchpad"], template=_DEFAULT_TEMPLATE
 )
@@ -13,6 +13,7 @@ class AgentAction(NamedTuple):
 
 class AgentFinish(NamedTuple):
     """Agent's return value."""
 
     return_values: dict
     log: str
+
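
For reference, the two schema types as this commit uses them, re-declared here as a self-contained illustration (the AgentFinish fields come from this hunk; the AgentAction fields are inferred from the construction calls and tests elsewhere in the diff):

    from typing import NamedTuple

    class AgentAction(NamedTuple):
        tool: str
        tool_input: str
        log: str

    class AgentFinish(NamedTuple):
        return_values: dict
        log: str

    finish = AgentFinish({"output": "4"}, log="Final Answer: 4")
    assert finish.return_values["output"] == "4"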
@@ -10,6 +10,7 @@ from langchain.docstore.base import Docstore
 from langchain.docstore.document import Document
 from langchain.llms.base import LLM
 from langchain.prompts.prompt import PromptTemplate
+from langchain.schema import AgentAction
 
 _PAGE_CONTENT = """This is a page about LangChain.
 
@@ -61,10 +62,9 @@ def test_predict_until_observation_normal() -> None:
         Tool("Lookup", lambda x: x),
     ]
     agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools)
-    output = agent.get_action("", {"input": ""})
-    assert output.log == outputs[0]
-    assert output.tool == "Search"
-    assert output.tool_input == "foo"
+    output = agent.plan([], input="")
+    expected_output = AgentAction("Search", "foo", outputs[0])
+    assert output == expected_output
 
 
 def test_predict_until_observation_repeat() -> None:
@@ -76,10 +76,9 @@ def test_predict_until_observation_repeat() -> None:
         Tool("Lookup", lambda x: x),
     ]
     agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools)
-    output = agent.get_action("", {"input": ""})
-    assert output.log == "foo\nAction 1: Search[foo]"
-    assert output.tool == "Search"
-    assert output.tool_input == "foo"
+    output = agent.plan([], input="")
+    expected_output = AgentAction("Search", "foo", "foo\nAction 1: Search[foo]")
+    assert output == expected_output
 
 
 def test_react_chain() -> None:
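
The updated tests lean on NamedTuple structural equality: building an `expected_output` and comparing whole tuples replaces three field-by-field asserts. A standalone illustration of why one `==` suffices:

    from typing import NamedTuple

    class AgentAction(NamedTuple):
        tool: str
        tool_input: str
        log: str

    # NamedTuples compare field-wise, so a single assert covers tool,
    # tool_input, and log at once.
    assert AgentAction("Search", "foo", "log") == AgentAction("Search", "foo", "log")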