This commit is contained in:
Harrison Chase 2022-12-17 14:21:55 -08:00
parent 3dd367cc60
commit 85e7c5fd6c
4 changed files with 162 additions and 45 deletions

View File

@ -13,7 +13,158 @@ from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping from langchain.input import get_color_mapping
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import AgentAction from langchain.schema import AgentAction, AgentFinish
import langchain
from typing import NamedTuple
class Planner(BaseModel):
"""Class responsible for calling the language model and deciding the action.
This is driven by an LLMChain. The prompt in the LLMChain MUST include
a variable called "agent_scratchpad" where the agent can put its
intermediary work.
"""
llm_chain: LLMChain
return_values: List[str]
@abstractmethod
def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
"""Extract tool and tool input from llm output."""
def _fix_text(self, text: str) -> str:
"""Fix the text."""
raise ValueError("fix_text not implemented for this agent.")
@property
def _stop(self) -> List[str]:
return [f"\n{self.observation_prefix}"]
def plan(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any) -> AgentAction:
"""Given input, decided what to do.
Args:
thoughts: LLM thoughts
inputs: user inputs
Returns:
Action specifying what tool to use.
"""
thoughts = ""
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
full_inputs = {**kwargs, **new_inputs}
full_output = self.llm_chain.predict(**full_inputs)
parsed_output = self._extract_tool_and_input(full_output)
while parsed_output is None:
full_output = self._fix_text(full_output)
full_inputs["agent_scratchpad"] += full_output
output = self.llm_chain.predict(**full_inputs)
full_output += output
parsed_output = self._extract_tool_and_input(full_output)
tool, tool_input = parsed_output
return AgentAction(tool, tool_input, full_output)
def prepare_for_new_call(self):
pass
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"})
@root_validator()
def validate_prompt(cls, values: Dict) -> Dict:
"""Validate that prompt matches format."""
prompt = values["llm_chain"].prompt
if "agent_scratchpad" not in prompt.input_variables:
raise ValueError(
"`agent_scratchpad` should be a variable in prompt.input_variables"
)
return values
@property
@abstractmethod
def observation_prefix(self) -> str:
"""Prefix to append the observation with."""
@property
@abstractmethod
def llm_prefix(self) -> str:
"""Prefix to append the LLM call with."""
class NewAgent(Chain, BaseModel):
planner: Planner
tools: List[Tool]
return_intermediate_steps: bool = False
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return self.planner.input_keys
@property
def output_keys(self) -> List[str]:
"""Return the singular output key.
:meta private:
"""
if self.return_intermediate_steps:
return self.planner.return_values + ["intermediate_steps"]
else:
return self.planner.return_values
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Do any preparation necessary when receiving a new input.
self.planner.prepare_for_new_call()
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool.func for tool in self.tools}
# We construct a mapping from each tool to a color, used for logging.
color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green"]
)
planner_inputs = inputs.copy()
intermediate_steps = []
# We now enter the agent loop (until it returns something).
while True:
# Call the LLM to see what to do.
output = self.planner.plan(intermediate_steps, **planner_inputs)
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
if self.verbose:
langchain.logger.log_agent_end(output, color="green")
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps
return final_output
if self.verbose:
langchain.logger.log_agent_action(output, color="green")
# And then we lookup the tool
if output.tool in name_to_tool_map:
chain = name_to_tool_map[output.tool]
# We then call the tool on the tool input to get an observation
observation = chain(output.tool_input)
color = color_mapping[output.tool]
else:
observation = f"{output.tool} is not a valid tool, try another one."
color = None
if self.verbose:
langchain.logger.log_agent_observation(observation, color=color)
intermediate_steps.append((output, observation))
class Agent(Chain, BaseModel, ABC): class Agent(Chain, BaseModel, ABC):
@ -26,14 +177,6 @@ class Agent(Chain, BaseModel, ABC):
input_key: str = "input" #: :meta private: input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private: output_key: str = "output" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"})
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> List[str]:
"""Return the singular output key. """Return the singular output key.
@ -75,17 +218,9 @@ class Agent(Chain, BaseModel, ABC):
"""Put this string after user input but before first LLM call.""" """Put this string after user input but before first LLM call."""
return "\n" return "\n"
@abstractmethod
def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
"""Extract tool and tool input from llm output."""
def _fix_text(self, text: str) -> str:
"""Fix the text."""
raise ValueError("fix_text not implemented for this agent.")
@property
def _stop(self) -> List[str]:
return [f"\n{self.observation_prefix}"]
@classmethod @classmethod
def _validate_tools(cls, tools: List[Tool]) -> None: def _validate_tools(cls, tools: List[Tool]) -> None:
@ -107,29 +242,6 @@ class Agent(Chain, BaseModel, ABC):
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools)) llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
return cls(llm_chain=llm_chain, tools=tools, **kwargs) return cls(llm_chain=llm_chain, tools=tools, **kwargs)
def get_action(self, thoughts: str, inputs: dict) -> AgentAction:
"""Given input, decided what to do.
Args:
thoughts: LLM thoughts
inputs: user inputs
Returns:
Action specifying what tool to use.
"""
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
full_inputs = {**inputs, **new_inputs}
full_output = self.llm_chain.predict(**full_inputs)
parsed_output = self._extract_tool_and_input(full_output)
while parsed_output is None:
full_output = self._fix_text(full_output)
full_inputs["agent_scratchpad"] += full_output
output = self.llm_chain.predict(**full_inputs)
full_output += output
parsed_output = self._extract_tool_and_input(full_output)
tool, tool_input = parsed_output
return AgentAction(tool, tool_input, full_output)
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run text through and get agent response.""" """Run text through and get agent response."""
# Do any preparation necessary when receiving a new input. # Do any preparation necessary when receiving a new input.

View File

@ -33,8 +33,7 @@ class ChainedInput:
def add_action(self, action: AgentAction, color: Optional[str] = None) -> None: def add_action(self, action: AgentAction, color: Optional[str] = None) -> None:
"""Add text to input, print if in verbose mode.""" """Add text to input, print if in verbose mode."""
if self._verbose:
langchain.logger.log_agent_action(action, color=color)
self._input += action.log self._input += action.log
self._intermediate_actions.append(action) self._intermediate_actions.append(action)

View File

@ -2,7 +2,7 @@
from typing import Any, Optional from typing import Any, Optional
from langchain.input import print_text from langchain.input import print_text
from langchain.schema import AgentAction from langchain.schema import AgentAction, AgentFinish
class BaseLogger: class BaseLogger:
@ -12,7 +12,7 @@ class BaseLogger:
"""Log the start of an agent interaction.""" """Log the start of an agent interaction."""
pass pass
def log_agent_end(self, text: str, **kwargs: Any) -> None: def log_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Log the end of an agent interaction.""" """Log the end of an agent interaction."""
pass pass

View File

@ -11,6 +11,12 @@ class AgentAction(NamedTuple):
log: str log: str
class AgentFinish(NamedTuple):
"""Agent's return value."""
return_values: dict
log: str
class Generation(NamedTuple): class Generation(NamedTuple):
"""Output of a single generation.""" """Output of a single generation."""