This commit is contained in:
Harrison Chase 2022-12-17 14:21:55 -08:00
parent 3dd367cc60
commit 85e7c5fd6c
4 changed files with 162 additions and 45 deletions

View File

@ -13,7 +13,158 @@ from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping
from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import AgentAction
from langchain.schema import AgentAction, AgentFinish
import langchain
from typing import NamedTuple
class Planner(BaseModel):
"""Class responsible for calling the language model and deciding the action.
This is driven by an LLMChain. The prompt in the LLMChain MUST include
a variable called "agent_scratchpad" where the agent can put its
intermediary work.
"""
llm_chain: LLMChain
return_values: List[str]
@abstractmethod
def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
"""Extract tool and tool input from llm output."""
def _fix_text(self, text: str) -> str:
"""Fix the text."""
raise ValueError("fix_text not implemented for this agent.")
@property
def _stop(self) -> List[str]:
return [f"\n{self.observation_prefix}"]
def plan(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> AgentAction:
"""Given input, decided what to do.
Args:
thoughts: LLM thoughts
inputs: user inputs
Returns:
Action specifying what tool to use.
"""
thoughts = ""
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
full_inputs = {**kwargs, **new_inputs}
full_output = self.llm_chain.predict(**full_inputs)
parsed_output = self._extract_tool_and_input(full_output)
while parsed_output is None:
full_output = self._fix_text(full_output)
full_inputs["agent_scratchpad"] += full_output
output = self.llm_chain.predict(**full_inputs)
full_output += output
parsed_output = self._extract_tool_and_input(full_output)
tool, tool_input = parsed_output
return AgentAction(tool, tool_input, full_output)
def prepare_for_new_call(self):
pass
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"})
@root_validator()
def validate_prompt(cls, values: Dict) -> Dict:
"""Validate that prompt matches format."""
prompt = values["llm_chain"].prompt
if "agent_scratchpad" not in prompt.input_variables:
raise ValueError(
"`agent_scratchpad` should be a variable in prompt.input_variables"
)
return values
@property
@abstractmethod
def observation_prefix(self) -> str:
"""Prefix to append the observation with."""
@property
@abstractmethod
def llm_prefix(self) -> str:
"""Prefix to append the LLM call with."""
class NewAgent(Chain, BaseModel):
planner: Planner
tools: List[Tool]
return_intermediate_steps: bool = False
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return self.planner.input_keys
@property
def output_keys(self) -> List[str]:
"""Return the singular output key.
:meta private:
"""
if self.return_intermediate_steps:
return self.planner.return_values + ["intermediate_steps"]
else:
return self.planner.return_values
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Do any preparation necessary when receiving a new input.
self.planner.prepare_for_new_call()
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool.func for tool in self.tools}
# We construct a mapping from each tool to a color, used for logging.
color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green"]
)
planner_inputs = inputs.copy()
intermediate_steps = []
# We now enter the agent loop (until it returns something).
while True:
# Call the LLM to see what to do.
output = self.planner.plan(intermediate_steps, **planner_inputs)
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
if self.verbose:
langchain.logger.log_agent_end(output, color="green")
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps
return final_output
if self.verbose:
langchain.logger.log_agent_action(output, color="green")
# And then we lookup the tool
if output.tool in name_to_tool_map:
chain = name_to_tool_map[output.tool]
# We then call the tool on the tool input to get an observation
observation = chain(output.tool_input)
color = color_mapping[output.tool]
else:
observation = f"{output.tool} is not a valid tool, try another one."
color = None
if self.verbose:
langchain.logger.log_agent_observation(observation, color=color)
intermediate_steps.append((output, observation))
class Agent(Chain, BaseModel, ABC):
@ -26,14 +177,6 @@ class Agent(Chain, BaseModel, ABC):
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"})
@property
def output_keys(self) -> List[str]:
"""Return the singular output key.
@ -75,17 +218,9 @@ class Agent(Chain, BaseModel, ABC):
"""Put this string after user input but before first LLM call."""
return "\n"
@abstractmethod
def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
"""Extract tool and tool input from llm output."""
def _fix_text(self, text: str) -> str:
"""Fix the text."""
raise ValueError("fix_text not implemented for this agent.")
@property
def _stop(self) -> List[str]:
return [f"\n{self.observation_prefix}"]
@classmethod
def _validate_tools(cls, tools: List[Tool]) -> None:
@ -107,29 +242,6 @@ class Agent(Chain, BaseModel, ABC):
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
return cls(llm_chain=llm_chain, tools=tools, **kwargs)
def get_action(self, thoughts: str, inputs: dict) -> AgentAction:
"""Given input, decided what to do.
Args:
thoughts: LLM thoughts
inputs: user inputs
Returns:
Action specifying what tool to use.
"""
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
full_inputs = {**inputs, **new_inputs}
full_output = self.llm_chain.predict(**full_inputs)
parsed_output = self._extract_tool_and_input(full_output)
while parsed_output is None:
full_output = self._fix_text(full_output)
full_inputs["agent_scratchpad"] += full_output
output = self.llm_chain.predict(**full_inputs)
full_output += output
parsed_output = self._extract_tool_and_input(full_output)
tool, tool_input = parsed_output
return AgentAction(tool, tool_input, full_output)
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Do any preparation necessary when receiving a new input.

View File

@ -33,8 +33,7 @@ class ChainedInput:
def add_action(self, action: AgentAction, color: Optional[str] = None) -> None:
"""Add text to input, print if in verbose mode."""
if self._verbose:
langchain.logger.log_agent_action(action, color=color)
self._input += action.log
self._intermediate_actions.append(action)

View File

@ -2,7 +2,7 @@
from typing import Any, Optional
from langchain.input import print_text
from langchain.schema import AgentAction
from langchain.schema import AgentAction, AgentFinish
class BaseLogger:
@ -12,7 +12,7 @@ class BaseLogger:
"""Log the start of an agent interaction."""
pass
def log_agent_end(self, text: str, **kwargs: Any) -> None:
def log_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Log the end of an agent interaction."""
pass

View File

@ -11,6 +11,12 @@ class AgentAction(NamedTuple):
log: str
class AgentFinish(NamedTuple):
"""Agent's return value."""
return_values: dict
log: str
class Generation(NamedTuple):
"""Output of a single generation."""