langchain: docstrings in agents root (#23561)

Added missed docstrings. Formatted docstrings to the consistent form.
This commit is contained in:
Leonid Ganeline 2024-06-27 12:52:18 -07:00 committed by GitHub
parent b64c4b4750
commit c0fdbaac85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 324 additions and 54 deletions

View File

@ -15,19 +15,20 @@ Agents select and use **Tools** and **Toolkits** for actions.
OpenAIFunctionsAgent OpenAIFunctionsAgent
XMLAgent XMLAgent
Agent --> <name>Agent # Examples: ZeroShotAgent, ChatAgent Agent --> <name>Agent # Examples: ZeroShotAgent, ChatAgent
BaseMultiActionAgent --> OpenAIMultiFunctionsAgent BaseMultiActionAgent --> OpenAIMultiFunctionsAgent
**Main helpers:** **Main helpers:**
.. code-block:: .. code-block::
AgentType, AgentExecutor, AgentOutputParser, AgentExecutorIterator, AgentType, AgentExecutor, AgentOutputParser, AgentExecutorIterator,
AgentAction, AgentFinish AgentAction, AgentFinish
""" # noqa: E501 """ # noqa: E501
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any

View File

@ -77,7 +77,7 @@ class BaseSingleActionAgent(BaseModel):
Args: Args:
intermediate_steps: Steps the LLM has taken to date, intermediate_steps: Steps the LLM has taken to date,
along with observations along with observations.
callbacks: Callbacks to run. callbacks: Callbacks to run.
**kwargs: User inputs. **kwargs: User inputs.
@ -92,11 +92,11 @@ class BaseSingleActionAgent(BaseModel):
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do. """Async given input, decided what to do.
Args: Args:
intermediate_steps: Steps the LLM has taken to date, intermediate_steps: Steps the LLM has taken to date,
along with observations along with observations.
callbacks: Callbacks to run. callbacks: Callbacks to run.
**kwargs: User inputs. **kwargs: User inputs.
@ -118,7 +118,20 @@ class BaseSingleActionAgent(BaseModel):
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: List[Tuple[AgentAction, str]],
**kwargs: Any, **kwargs: Any,
) -> AgentFinish: ) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations.""" """Return response when agent has been stopped due to max iterations.
Args:
early_stopping_method: Method to use for early stopping.
intermediate_steps: Steps the LLM has taken to date,
along with observations.
**kwargs: User inputs.
Returns:
AgentFinish: Agent finish object.
Raises:
ValueError: If `early_stopping_method` is not supported.
"""
if early_stopping_method == "force": if early_stopping_method == "force":
# `force` just returns a constant string # `force` just returns a constant string
return AgentFinish( return AgentFinish(
@ -137,15 +150,30 @@ class BaseSingleActionAgent(BaseModel):
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any, **kwargs: Any,
) -> BaseSingleActionAgent: ) -> BaseSingleActionAgent:
"""Construct an agent from an LLM and tools.
Args:
llm: Language model to use.
tools: Tools to use.
callback_manager: Callback manager to use.
**kwargs: Additional arguments.
Returns:
BaseSingleActionAgent: Agent object.
"""
raise NotImplementedError raise NotImplementedError
@property @property
def _agent_type(self) -> str: def _agent_type(self) -> str:
"""Return Identifier of agent type.""" """Return Identifier of an agent type."""
raise NotImplementedError raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent.
Returns:
Dict: Dictionary representation of agent.
"""
_dict = super().dict() _dict = super().dict()
try: try:
_type = self._agent_type _type = self._agent_type
@ -193,6 +221,7 @@ class BaseSingleActionAgent(BaseModel):
raise ValueError(f"{save_path} must be json or yaml") raise ValueError(f"{save_path} must be json or yaml")
def tool_run_logging_kwargs(self) -> Dict: def tool_run_logging_kwargs(self) -> Dict:
"""Return logging kwargs for tool run."""
return {} return {}
@ -205,6 +234,11 @@ class BaseMultiActionAgent(BaseModel):
return ["output"] return ["output"]
def get_allowed_tools(self) -> Optional[List[str]]: def get_allowed_tools(self) -> Optional[List[str]]:
"""Get allowed tools.
Returns:
Optional[List[str]]: Allowed tools.
"""
return None return None
@abstractmethod @abstractmethod
@ -233,7 +267,7 @@ class BaseMultiActionAgent(BaseModel):
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]: ) -> Union[List[AgentAction], AgentFinish]:
"""Given input, decided what to do. """Async given input, decided what to do.
Args: Args:
intermediate_steps: Steps the LLM has taken to date, intermediate_steps: Steps the LLM has taken to date,
@ -259,7 +293,20 @@ class BaseMultiActionAgent(BaseModel):
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: List[Tuple[AgentAction, str]],
**kwargs: Any, **kwargs: Any,
) -> AgentFinish: ) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations.""" """Return response when agent has been stopped due to max iterations.
Args:
early_stopping_method: Method to use for early stopping.
intermediate_steps: Steps the LLM has taken to date,
along with observations.
**kwargs: User inputs.
Returns:
AgentFinish: Agent finish object.
Raises:
ValueError: If `early_stopping_method` is not supported.
"""
if early_stopping_method == "force": if early_stopping_method == "force":
# `force` just returns a constant string # `force` just returns a constant string
return AgentFinish({"output": "Agent stopped due to max iterations."}, "") return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
@ -270,7 +317,7 @@ class BaseMultiActionAgent(BaseModel):
@property @property
def _agent_type(self) -> str: def _agent_type(self) -> str:
"""Return Identifier of agent type.""" """Return Identifier of an agent type."""
raise NotImplementedError raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> Dict:
@ -288,6 +335,10 @@ class BaseMultiActionAgent(BaseModel):
Args: Args:
file_path: Path to file to save the agent to. file_path: Path to file to save the agent to.
Raises:
NotImplementedError: If agent does not support saving.
ValueError: If file_path is not json or yaml.
Example: Example:
.. code-block:: python .. code-block:: python
@ -318,6 +369,8 @@ class BaseMultiActionAgent(BaseModel):
raise ValueError(f"{save_path} must be json or yaml") raise ValueError(f"{save_path} must be json or yaml")
def tool_run_logging_kwargs(self) -> Dict: def tool_run_logging_kwargs(self) -> Dict:
"""Return logging kwargs for tool run."""
return {} return {}
@ -332,15 +385,26 @@ class AgentOutputParser(BaseOutputParser[Union[AgentAction, AgentFinish]]):
class MultiActionAgentOutputParser( class MultiActionAgentOutputParser(
BaseOutputParser[Union[List[AgentAction], AgentFinish]] BaseOutputParser[Union[List[AgentAction], AgentFinish]]
): ):
"""Base class for parsing agent output into agent actions/finish.""" """Base class for parsing agent output into agent actions/finish.
This is used for agents that can return multiple actions.
"""
@abstractmethod @abstractmethod
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]: def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
"""Parse text into agent actions/finish.""" """Parse text into agent actions/finish.
Args:
text: Text to parse.
Returns:
Union[List[AgentAction], AgentFinish]:
List of agent actions or agent finish.
"""
class RunnableAgent(BaseSingleActionAgent): class RunnableAgent(BaseSingleActionAgent):
"""Agent powered by runnables.""" """Agent powered by Runnables."""
runnable: Runnable[dict, Union[AgentAction, AgentFinish]] runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
"""Runnable to call to get agent action.""" """Runnable to call to get agent action."""
@ -367,6 +431,7 @@ class RunnableAgent(BaseSingleActionAgent):
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:
"""Return the input keys."""
return self.input_keys_arg return self.input_keys_arg
def plan( def plan(
@ -414,13 +479,13 @@ class RunnableAgent(BaseSingleActionAgent):
AgentAction, AgentAction,
AgentFinish, AgentFinish,
]: ]:
"""Based on past history and current inputs, decide what to do. """Async based on past history and current inputs, decide what to do.
Args: Args:
intermediate_steps: Steps the LLM has taken to date, intermediate_steps: Steps the LLM has taken to date,
along with observations along with observations.
callbacks: Callbacks to run. callbacks: Callbacks to run.
**kwargs: User inputs **kwargs: User inputs.
Returns: Returns:
Action specifying what tool to use. Action specifying what tool to use.
@ -449,7 +514,7 @@ class RunnableAgent(BaseSingleActionAgent):
class RunnableMultiActionAgent(BaseMultiActionAgent): class RunnableMultiActionAgent(BaseMultiActionAgent):
"""Agent powered by runnables.""" """Agent powered by Runnables."""
runnable: Runnable[dict, Union[List[AgentAction], AgentFinish]] runnable: Runnable[dict, Union[List[AgentAction], AgentFinish]]
"""Runnable to call to get agent actions.""" """Runnable to call to get agent actions."""
@ -531,11 +596,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
List[AgentAction], List[AgentAction],
AgentFinish, AgentFinish,
]: ]:
"""Based on past history and current inputs, decide what to do. """Async based on past history and current inputs, decide what to do.
Args: Args:
intermediate_steps: Steps the LLM has taken to date, intermediate_steps: Steps the LLM has taken to date,
along with observations along with observations.
callbacks: Callbacks to run. callbacks: Callbacks to run.
**kwargs: User inputs. **kwargs: User inputs.
@ -630,11 +695,11 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do. """Async given input, decided what to do.
Args: Args:
intermediate_steps: Steps the LLM has taken to date, intermediate_steps: Steps the LLM has taken to date,
along with observations along with observations.
callbacks: Callbacks to run. callbacks: Callbacks to run.
**kwargs: User inputs. **kwargs: User inputs.
@ -650,6 +715,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
return self.output_parser.parse(output) return self.output_parser.parse(output)
def tool_run_logging_kwargs(self) -> Dict: def tool_run_logging_kwargs(self) -> Dict:
"""Return logging kwargs for tool run."""
return { return {
"llm_prefix": "", "llm_prefix": "",
"observation_prefix": "" if len(self.stop) == 0 else self.stop[0], "observation_prefix": "" if len(self.stop) == 0 else self.stop[0],
@ -667,14 +733,17 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
class Agent(BaseSingleActionAgent): class Agent(BaseSingleActionAgent):
"""Agent that calls the language model and deciding the action. """Agent that calls the language model and deciding the action.
This is driven by an LLMChain. The prompt in the LLMChain MUST include This is driven by a LLMChain. The prompt in the LLMChain MUST include
a variable called "agent_scratchpad" where the agent can put its a variable called "agent_scratchpad" where the agent can put its
intermediary work. intermediary work.
""" """
llm_chain: LLMChain llm_chain: LLMChain
"""LLMChain to use for agent."""
output_parser: AgentOutputParser output_parser: AgentOutputParser
"""Output parser to use for agent."""
allowed_tools: Optional[List[str]] = None allowed_tools: Optional[List[str]] = None
"""Allowed tools for the agent. If None, all tools are allowed."""
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent."""
@ -683,14 +752,23 @@ class Agent(BaseSingleActionAgent):
return _dict return _dict
def get_allowed_tools(self) -> Optional[List[str]]: def get_allowed_tools(self) -> Optional[List[str]]:
"""Get allowed tools."""
return self.allowed_tools return self.allowed_tools
@property @property
def return_values(self) -> List[str]: def return_values(self) -> List[str]:
"""Return values of the agent."""
return ["output"] return ["output"]
def _fix_text(self, text: str) -> str: def _fix_text(self, text: str) -> str:
"""Fix the text.""" """Fix the text.
Args:
text: Text to fix.
Returns:
str: Fixed text.
"""
raise ValueError("fix_text not implemented for this agent.") raise ValueError("fix_text not implemented for this agent.")
@property @property
@ -720,7 +798,7 @@ class Agent(BaseSingleActionAgent):
Args: Args:
intermediate_steps: Steps the LLM has taken to date, intermediate_steps: Steps the LLM has taken to date,
along with observations along with observations.
callbacks: Callbacks to run. callbacks: Callbacks to run.
**kwargs: User inputs. **kwargs: User inputs.
@ -737,11 +815,11 @@ class Agent(BaseSingleActionAgent):
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do. """Async given input, decided what to do.
Args: Args:
intermediate_steps: Steps the LLM has taken to date, intermediate_steps: Steps the LLM has taken to date,
along with observations along with observations.
callbacks: Callbacks to run. callbacks: Callbacks to run.
**kwargs: User inputs. **kwargs: User inputs.
@ -756,7 +834,16 @@ class Agent(BaseSingleActionAgent):
def get_full_inputs( def get_full_inputs(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Create the full inputs for the LLMChain from intermediate steps.""" """Create the full inputs for the LLMChain from intermediate steps.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations.
**kwargs: User inputs.
Returns:
Dict[str, Any]: Full inputs for the LLMChain.
"""
thoughts = self._construct_scratchpad(intermediate_steps) thoughts = self._construct_scratchpad(intermediate_steps)
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop} new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
full_inputs = {**kwargs, **new_inputs} full_inputs = {**kwargs, **new_inputs}
@ -772,7 +859,18 @@ class Agent(BaseSingleActionAgent):
@root_validator(pre=False, skip_on_failure=True) @root_validator(pre=False, skip_on_failure=True)
def validate_prompt(cls, values: Dict) -> Dict: def validate_prompt(cls, values: Dict) -> Dict:
"""Validate that prompt matches format.""" """Validate that prompt matches format.
Args:
values: Values to validate.
Returns:
Dict: Validated values.
Raises:
ValueError: If `agent_scratchpad` is not in prompt.input_variables
and prompt is not a FewShotPromptTemplate or a PromptTemplate.
"""
prompt = values["llm_chain"].prompt prompt = values["llm_chain"].prompt
if "agent_scratchpad" not in prompt.input_variables: if "agent_scratchpad" not in prompt.input_variables:
logger.warning( logger.warning(
@ -801,11 +899,23 @@ class Agent(BaseSingleActionAgent):
@classmethod @classmethod
@abstractmethod @abstractmethod
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate: def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
"""Create a prompt for this class.""" """Create a prompt for this class.
Args:
tools: Tools to use.
Returns:
BasePromptTemplate: Prompt template.
"""
@classmethod @classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
"""Validate that appropriate tools are passed in.""" """Validate that appropriate tools are passed in.
Args:
tools: Tools to use.
"""
pass pass
@classmethod @classmethod
@ -822,7 +932,18 @@ class Agent(BaseSingleActionAgent):
output_parser: Optional[AgentOutputParser] = None, output_parser: Optional[AgentOutputParser] = None,
**kwargs: Any, **kwargs: Any,
) -> Agent: ) -> Agent:
"""Construct an agent from an LLM and tools.""" """Construct an agent from an LLM and tools.
Args:
llm: Language model to use.
tools: Tools to use.
callback_manager: Callback manager to use.
output_parser: Output parser to use.
**kwargs: Additional arguments.
Returns:
Agent: Agent object.
"""
cls._validate_tools(tools) cls._validate_tools(tools)
llm_chain = LLMChain( llm_chain = LLMChain(
llm=llm, llm=llm,
@ -844,7 +965,20 @@ class Agent(BaseSingleActionAgent):
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: List[Tuple[AgentAction, str]],
**kwargs: Any, **kwargs: Any,
) -> AgentFinish: ) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations.""" """Return response when agent has been stopped due to max iterations.
Args:
early_stopping_method: Method to use for early stopping.
intermediate_steps: Steps the LLM has taken to date,
along with observations.
**kwargs: User inputs.
Returns:
AgentFinish: Agent finish object.
Raises:
ValueError: If `early_stopping_method` is not in ['force', 'generate'].
"""
if early_stopping_method == "force": if early_stopping_method == "force":
# `force` just returns a constant string # `force` just returns a constant string
return AgentFinish( return AgentFinish(
@ -881,6 +1015,7 @@ class Agent(BaseSingleActionAgent):
) )
def tool_run_logging_kwargs(self) -> Dict: def tool_run_logging_kwargs(self) -> Dict:
"""Return logging kwargs for tool run."""
return { return {
"llm_prefix": self.llm_prefix, "llm_prefix": self.llm_prefix,
"observation_prefix": self.observation_prefix, "observation_prefix": self.observation_prefix,
@ -957,6 +1092,9 @@ class AgentExecutor(Chain):
trim_intermediate_steps: Union[ trim_intermediate_steps: Union[
int, Callable[[List[Tuple[AgentAction, str]]], List[Tuple[AgentAction, str]]] int, Callable[[List[Tuple[AgentAction, str]]], List[Tuple[AgentAction, str]]]
] = -1 ] = -1
"""How to trim the intermediate steps before returning them.
Defaults to -1, which means no trimming.
"""
@classmethod @classmethod
def from_agent_and_tools( def from_agent_and_tools(
@ -966,7 +1104,17 @@ class AgentExecutor(Chain):
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> AgentExecutor: ) -> AgentExecutor:
"""Create from agent and tools.""" """Create from agent and tools.
Args:
agent: Agent to use.
tools: Tools to use.
callbacks: Callbacks to use.
**kwargs: Additional arguments.
Returns:
AgentExecutor: Agent executor object.
"""
return cls( return cls(
agent=agent, agent=agent,
tools=tools, tools=tools,
@ -976,7 +1124,17 @@ class AgentExecutor(Chain):
@root_validator(pre=False, skip_on_failure=True) @root_validator(pre=False, skip_on_failure=True)
def validate_tools(cls, values: Dict) -> Dict: def validate_tools(cls, values: Dict) -> Dict:
"""Validate that tools are compatible with agent.""" """Validate that tools are compatible with agent.
Args:
values: Values to validate.
Returns:
Dict: Validated values.
Raises:
ValueError: If allowed tools are different than provided tools.
"""
agent = values["agent"] agent = values["agent"]
tools = values["tools"] tools = values["tools"]
allowed_tools = agent.get_allowed_tools() allowed_tools = agent.get_allowed_tools()
@ -990,7 +1148,17 @@ class AgentExecutor(Chain):
@root_validator(pre=False, skip_on_failure=True) @root_validator(pre=False, skip_on_failure=True)
def validate_return_direct_tool(cls, values: Dict) -> Dict: def validate_return_direct_tool(cls, values: Dict) -> Dict:
"""Validate that tools are compatible with agent.""" """Validate that tools are compatible with agent.
Args:
values: Values to validate.
Returns:
Dict: Validated values.
Raises:
ValueError: If tools that have `return_direct=True` are not allowed.
"""
agent = values["agent"] agent = values["agent"]
tools = values["tools"] tools = values["tools"]
if isinstance(agent, BaseMultiActionAgent): if isinstance(agent, BaseMultiActionAgent):
@ -1004,7 +1172,14 @@ class AgentExecutor(Chain):
@root_validator(pre=True) @root_validator(pre=True)
def validate_runnable_agent(cls, values: Dict) -> Dict: def validate_runnable_agent(cls, values: Dict) -> Dict:
"""Convert runnable to agent if passed in.""" """Convert runnable to agent if passed in.
Args:
values: Values to validate.
Returns:
Dict: Validated values.
"""
agent = values.get("agent") agent = values.get("agent")
if agent and isinstance(agent, Runnable): if agent and isinstance(agent, Runnable):
try: try:
@ -1026,7 +1201,14 @@ class AgentExecutor(Chain):
return values return values
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
"""Raise error - saving not supported for Agent Executors.""" """Raise error - saving not supported for Agent Executors.
Args:
file_path: Path to save to.
Raises:
ValueError: Saving not supported for agent executors.
"""
raise ValueError( raise ValueError(
"Saving not supported for agent executors. " "Saving not supported for agent executors. "
"If you are trying to save the agent, please use the " "If you are trying to save the agent, please use the "
@ -1034,7 +1216,11 @@ class AgentExecutor(Chain):
) )
def save_agent(self, file_path: Union[Path, str]) -> None: def save_agent(self, file_path: Union[Path, str]) -> None:
"""Save the underlying agent.""" """Save the underlying agent.
Args:
file_path: Path to save to.
"""
return self.agent.save(file_path) return self.agent.save(file_path)
def iter( def iter(
@ -1045,7 +1231,17 @@ class AgentExecutor(Chain):
include_run_info: bool = False, include_run_info: bool = False,
async_: bool = False, # arg kept for backwards compat, but ignored async_: bool = False, # arg kept for backwards compat, but ignored
) -> AgentExecutorIterator: ) -> AgentExecutorIterator:
"""Enables iteration over steps taken to reach final output.""" """Enables iteration over steps taken to reach final output.
Args:
inputs: Inputs to the agent.
callbacks: Callbacks to run.
include_run_info: Whether to include run info.
async_: Whether to run async. (Ignored)
Returns:
AgentExecutorIterator: Agent executor iterator object.
"""
return AgentExecutorIterator( return AgentExecutorIterator(
self, self,
inputs, inputs,
@ -1074,7 +1270,14 @@ class AgentExecutor(Chain):
return self.agent.return_values return self.agent.return_values
def lookup_tool(self, name: str) -> BaseTool: def lookup_tool(self, name: str) -> BaseTool:
"""Lookup tool by name.""" """Lookup tool by name.
Args:
name: Name of tool.
Returns:
BaseTool: Tool object.
"""
return {tool.name: tool for tool in self.tools}[name] return {tool.name: tool for tool in self.tools}[name]
def _should_continue(self, iterations: int, time_elapsed: float) -> bool: def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
@ -1463,7 +1666,7 @@ class AgentExecutor(Chain):
inputs: Dict[str, str], inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Run text through and get agent response.""" """Async run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup # Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools} name_to_tool_map = {tool.name: tool for tool in self.tools}
# We construct a mapping from each tool to a color, used for logging. # We construct a mapping from each tool to a color, used for logging.
@ -1557,7 +1760,16 @@ class AgentExecutor(Chain):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[AddableDict]: ) -> Iterator[AddableDict]:
"""Enables streaming over steps taken to reach final output.""" """Enables streaming over steps taken to reach final output.
Args:
input: Input to the agent.
config: Config to use.
**kwargs: Additional arguments.
Yields:
AddableDict: Addable dictionary.
"""
config = ensure_config(config) config = ensure_config(config)
iterator = AgentExecutorIterator( iterator = AgentExecutorIterator(
self, self,
@ -1579,7 +1791,17 @@ class AgentExecutor(Chain):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[AddableDict]: ) -> AsyncIterator[AddableDict]:
"""Enables streaming over steps taken to reach final output.""" """Async enables streaming over steps taken to reach final output.
Args:
input: Input to the agent.
config: Config to use.
**kwargs: Additional arguments.
Yields:
AddableDict: Addable dictionary.
"""
config = ensure_config(config) config = ensure_config(config)
iterator = AgentExecutorIterator( iterator = AgentExecutorIterator(
self, self,

View File

@ -62,6 +62,22 @@ class AgentExecutorIterator:
""" """
Initialize the AgentExecutorIterator with the given AgentExecutor, Initialize the AgentExecutorIterator with the given AgentExecutor,
inputs, and optional callbacks. inputs, and optional callbacks.
Args:
agent_executor (AgentExecutor): The AgentExecutor to iterate over.
inputs (Any): The inputs to the AgentExecutor.
callbacks (Callbacks, optional): The callbacks to use during iteration.
Defaults to None.
tags (Optional[list[str]], optional): The tags to use during iteration.
Defaults to None.
metadata (Optional[Dict[str, Any]], optional): The metadata to use
during iteration. Defaults to None.
run_name (Optional[str], optional): The name of the run. Defaults to None.
run_id (Optional[UUID], optional): The ID of the run. Defaults to None.
include_run_info (bool, optional): Whether to include run info
in the output. Defaults to False.
yield_actions (bool, optional): Whether to yield actions as they
are generated. Defaults to False.
""" """
self._agent_executor = agent_executor self._agent_executor = agent_executor
self.inputs = inputs self.inputs = inputs
@ -85,6 +101,7 @@ class AgentExecutorIterator:
@property @property
def inputs(self) -> Dict[str, str]: def inputs(self) -> Dict[str, str]:
"""The inputs to the AgentExecutor."""
return self._inputs return self._inputs
@inputs.setter @inputs.setter
@ -93,6 +110,7 @@ class AgentExecutorIterator:
@property @property
def agent_executor(self) -> AgentExecutor: def agent_executor(self) -> AgentExecutor:
"""The AgentExecutor to iterate over."""
return self._agent_executor return self._agent_executor
@agent_executor.setter @agent_executor.setter
@ -103,10 +121,12 @@ class AgentExecutorIterator:
@property @property
def name_to_tool_map(self) -> Dict[str, BaseTool]: def name_to_tool_map(self) -> Dict[str, BaseTool]:
"""A mapping of tool names to tools."""
return {tool.name: tool for tool in self.agent_executor.tools} return {tool.name: tool for tool in self.agent_executor.tools}
@property @property
def color_mapping(self) -> Dict[str, str]: def color_mapping(self) -> Dict[str, str]:
"""A mapping of tool names to colors."""
return get_color_mapping( return get_color_mapping(
[tool.name for tool in self.agent_executor.tools], [tool.name for tool in self.agent_executor.tools],
excluded_colors=["green", "red"], excluded_colors=["green", "red"],

View File

@ -1,4 +1,5 @@
"""Module definitions of agent types together with corresponding agents.""" """Module definitions of agent types together with corresponding agents."""
from enum import Enum from enum import Enum
from langchain_core._api import deprecated from langchain_core._api import deprecated

View File

@ -1,4 +1,5 @@
"""Load agent.""" """Load agent."""
from typing import Any, Optional, Sequence from typing import Any, Optional, Sequence
from langchain_core._api import deprecated from langchain_core._api import deprecated
@ -35,17 +36,24 @@ def initialize_agent(
Args: Args:
tools: List of tools this agent has access to. tools: List of tools this agent has access to.
llm: Language model to use as the agent. llm: Language model to use as the agent.
agent: Agent type to use. If None and agent_path is also None, will default to agent: Agent type to use. If None and agent_path is also None, will default
AgentType.ZERO_SHOT_REACT_DESCRIPTION. to AgentType.ZERO_SHOT_REACT_DESCRIPTION. Defaults to None.
callback_manager: CallbackManager to use. Global callback manager is used if callback_manager: CallbackManager to use. Global callback manager is used if
not provided. Defaults to None. not provided. Defaults to None.
agent_path: Path to serialized agent to use. agent_path: Path to serialized agent to use. If None and agent is also None,
agent_kwargs: Additional keyword arguments to pass to the underlying agent will default to AgentType.ZERO_SHOT_REACT_DESCRIPTION. Defaults to None.
tags: Tags to apply to the traced runs. agent_kwargs: Additional keyword arguments to pass to the underlying agent.
**kwargs: Additional keyword arguments passed to the agent executor Defaults to None.
tags: Tags to apply to the traced runs. Defaults to None.
**kwargs: Additional keyword arguments passed to the agent executor.
Returns: Returns:
An agent executor An agent executor.
Raises:
ValueError: If both `agent` and `agent_path` are specified.
ValueError: If `agent` is not a valid agent type.
ValueError: If both `agent` and `agent_path` are None.
""" """
tags_ = list(tags) if tags else [] tags_ = list(tags) if tags else []
if agent is None and agent_path is None: if agent is None and agent_path is None:

View File

@ -48,6 +48,9 @@ def load_agent_from_config(
Returns: Returns:
An agent executor. An agent executor.
Raises:
ValueError: If agent type is not specified in the config.
""" """
if "_type" not in config: if "_type" not in config:
raise ValueError("Must specify an agent Type in config") raise ValueError("Must specify an agent Type in config")
@ -99,6 +102,10 @@ def load_agent(
Returns: Returns:
An agent executor. An agent executor.
Raises:
RuntimeError: If loading from the deprecated github-based
Hub is attempted.
""" """
if isinstance(path, str) and path.startswith("lc://"): if isinstance(path, str) and path.startswith("lc://"):
raise RuntimeError( raise RuntimeError(

View File

@ -1,4 +1,5 @@
"""Interface for tools.""" """Interface for tools."""
from typing import List, Optional from typing import List, Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -12,7 +13,9 @@ class InvalidTool(BaseTool):
"""Tool that is run when invalid tool name is encountered by agent.""" """Tool that is run when invalid tool name is encountered by agent."""
name: str = "invalid_tool" name: str = "invalid_tool"
"""Name of the tool."""
description: str = "Called when tool name is invalid. Suggests valid tool names." description: str = "Called when tool name is invalid. Suggests valid tool names."
"""Description of the tool."""
def _run( def _run(
self, self,

View File

@ -4,7 +4,15 @@ from langchain_core.tools import BaseTool
def validate_tools_single_input(class_name: str, tools: Sequence[BaseTool]) -> None: def validate_tools_single_input(class_name: str, tools: Sequence[BaseTool]) -> None:
"""Validate tools for single input.""" """Validate tools for single input.
Args:
class_name: Name of the class.
tools: List of tools to validate.
Raises:
ValueError: If a multi-input tool is found in tools.
"""
for tool in tools: for tool in tools:
if not tool.is_single_input: if not tool.is_single_input:
raise ValueError( raise ValueError(