diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py
index 62cc4641120..99137fd4b60 100644
--- a/langchain/agents/agent.py
+++ b/langchain/agents/agent.py
@@ -13,7 +13,158 @@ from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping
 from langchain.llms.base import LLM
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import AgentAction
+from langchain.schema import AgentAction, AgentFinish
+import langchain
+
+
+from typing import NamedTuple, Union
+
+
+
+class Planner(BaseModel):
+    """Class responsible for calling the language model and deciding the action.
+
+    This is driven by an LLMChain. The prompt in the LLMChain MUST include
+    a variable called "agent_scratchpad" where the agent can put its
+    intermediary work.
+    """
+    llm_chain: LLMChain
+    return_values: List[str]
+
+    @abstractmethod
+    def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
+        """Extract tool and tool input from llm output."""
+
+    def _fix_text(self, text: str) -> str:
+        """Fix the text."""
+        raise ValueError("fix_text not implemented for this agent.")
+
+    @property
+    def _stop(self) -> List[str]:
+        return [f"\n{self.observation_prefix}"]
+
+    def plan(
+        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decide what to do.
+
+        Args:
+            intermediate_steps: steps the agent has taken so far,
+                along with the observations they produced
+            **kwargs: user inputs
+
+        Returns:
+            Action specifying what tool to use, or an AgentFinish
+            once the agent is done.
+        """
+        thoughts = ""
+        for action, observation in intermediate_steps:
+            thoughts += action.log
+            thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
+        new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
+        full_inputs = {**kwargs, **new_inputs}
+        full_output = self.llm_chain.predict(**full_inputs)
+        parsed_output = self._extract_tool_and_input(full_output)
+        while parsed_output is None:
+            full_output = self._fix_text(full_output)
+            full_inputs["agent_scratchpad"] += full_output
+            output = self.llm_chain.predict(**full_inputs)
+            full_output += output
+            parsed_output = self._extract_tool_and_input(full_output)
+        tool, tool_input = parsed_output
+        return AgentAction(tool, tool_input, full_output)
+
+    def prepare_for_new_call(self) -> None:
+        pass
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Return the input keys.
+
+        :meta private:
+        """
+        return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"})
+
+    @root_validator()
+    def validate_prompt(cls, values: Dict) -> Dict:
+        """Validate that prompt matches format."""
+        prompt = values["llm_chain"].prompt
+        if "agent_scratchpad" not in prompt.input_variables:
+            raise ValueError(
+                "`agent_scratchpad` should be a variable in prompt.input_variables"
+            )
+        return values
+
+    @property
+    @abstractmethod
+    def observation_prefix(self) -> str:
+        """Prefix to append the observation with."""
+
+    @property
+    @abstractmethod
+    def llm_prefix(self) -> str:
+        """Prefix to append the LLM call with."""
+
+
+class NewAgent(Chain, BaseModel):
+    """Chain that runs a Planner against a set of tools until it finishes."""
+
+    planner: Planner
+    tools: List[Tool]
+    return_intermediate_steps: bool = False
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Return the input keys.
+
+        :meta private:
+        """
+        return self.planner.input_keys
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Return the singular output key.
+
+        :meta private:
+        """
+        if self.return_intermediate_steps:
+            return self.planner.return_values + ["intermediate_steps"]
+        else:
+            return self.planner.return_values
+
+    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
+        """Run text through and get agent response."""
+        # Do any preparation necessary when receiving a new input.
+        self.planner.prepare_for_new_call()
+        # Construct a mapping of tool name to tool for easy lookup
+        name_to_tool_map = {tool.name: tool.func for tool in self.tools}
+        # We construct a mapping from each tool to a color, used for logging.
+        color_mapping = get_color_mapping(
+            [tool.name for tool in self.tools], excluded_colors=["green"]
+        )
+        planner_inputs = inputs.copy()
+        intermediate_steps = []
+        # We now enter the agent loop (until it returns something).
+        while True:
+            # Call the LLM to see what to do.
+            output = self.planner.plan(intermediate_steps, **planner_inputs)
+            # If the tool chosen is the finishing tool, then we end and return.
+            if isinstance(output, AgentFinish):
+                if self.verbose:
+                    langchain.logger.log_agent_end(output, color="green")
+                final_output = output.return_values
+                if self.return_intermediate_steps:
+                    final_output["intermediate_steps"] = intermediate_steps
+                return final_output
+            if self.verbose:
+                langchain.logger.log_agent_action(output, color="green")
+            # And then we look up the tool
+            if output.tool in name_to_tool_map:
+                chain = name_to_tool_map[output.tool]
+                # We then call the tool on the tool input to get an observation
+                observation = chain(output.tool_input)
+                color = color_mapping[output.tool]
+            else:
+                observation = f"{output.tool} is not a valid tool, try another one."
+                color = None
+            if self.verbose:
+                langchain.logger.log_agent_observation(observation, color=color)
+            intermediate_steps.append((output, observation))
 
 
 class Agent(Chain, BaseModel, ABC):
@@ -26,14 +177,6 @@ class Agent(Chain, BaseModel, ABC):
     input_key: str = "input"  #: :meta private:
     output_key: str = "output"  #: :meta private:
 
-    @property
-    def input_keys(self) -> List[str]:
-        """Return the input keys.
-
-        :meta private:
-        """
-        return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"})
-
     @property
     def output_keys(self) -> List[str]:
         """Return the singular output key.
@@ -75,17 +218,9 @@ class Agent(Chain, BaseModel, ABC):
         """Put this string after user input but before first LLM call."""
         return "\n"
 
-    @abstractmethod
-    def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
-        """Extract tool and tool input from llm output."""
 
-    def _fix_text(self, text: str) -> str:
-        """Fix the text."""
-        raise ValueError("fix_text not implemented for this agent.")
 
-    @property
-    def _stop(self) -> List[str]:
-        return [f"\n{self.observation_prefix}"]
+
 
     @classmethod
     def _validate_tools(cls, tools: List[Tool]) -> None:
@@ -107,29 +242,6 @@ class Agent(Chain, BaseModel, ABC):
         llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
         return cls(llm_chain=llm_chain, tools=tools, **kwargs)
 
-    def get_action(self, thoughts: str, inputs: dict) -> AgentAction:
-        """Given input, decided what to do.
-
-        Args:
-            thoughts: LLM thoughts
-            inputs: user inputs
-
-        Returns:
-            Action specifying what tool to use.
- """ - new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop} - full_inputs = {**inputs, **new_inputs} - full_output = self.llm_chain.predict(**full_inputs) - parsed_output = self._extract_tool_and_input(full_output) - while parsed_output is None: - full_output = self._fix_text(full_output) - full_inputs["agent_scratchpad"] += full_output - output = self.llm_chain.predict(**full_inputs) - full_output += output - parsed_output = self._extract_tool_and_input(full_output) - tool, tool_input = parsed_output - return AgentAction(tool, tool_input, full_output) - def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: """Run text through and get agent response.""" # Do any preparation necessary when receiving a new input. diff --git a/langchain/agents/input.py b/langchain/agents/input.py index 5e137882e73..24a2e41d0c6 100644 --- a/langchain/agents/input.py +++ b/langchain/agents/input.py @@ -33,8 +33,7 @@ class ChainedInput: def add_action(self, action: AgentAction, color: Optional[str] = None) -> None: """Add text to input, print if in verbose mode.""" - if self._verbose: - langchain.logger.log_agent_action(action, color=color) + self._input += action.log self._intermediate_actions.append(action) diff --git a/langchain/logger.py b/langchain/logger.py index 27b8fd84af1..af7a0f7bbaf 100644 --- a/langchain/logger.py +++ b/langchain/logger.py @@ -2,7 +2,7 @@ from typing import Any, Optional from langchain.input import print_text -from langchain.schema import AgentAction +from langchain.schema import AgentAction, AgentFinish class BaseLogger: @@ -12,7 +12,7 @@ class BaseLogger: """Log the start of an agent interaction.""" pass - def log_agent_end(self, text: str, **kwargs: Any) -> None: + def log_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None: """Log the end of an agent interaction.""" pass diff --git a/langchain/schema.py b/langchain/schema.py index 4e255e3e5ea..b620dc17313 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -11,6 +11,12 @@ class AgentAction(NamedTuple): log: str +class AgentFinish(NamedTuple): + """Agent's return value.""" + return_values: dict + log: str + + class Generation(NamedTuple): """Output of a single generation."""