langchain: Fix type warnings when passing Runnable as agent to AgentExecutor (#24750)

Fix for https://github.com/langchain-ai/langchain/issues/13075

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Hasan Kumar 2024-08-22 08:02:02 -10:00 committed by GitHub
parent 61228da1c4
commit b4fcda7657
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 42 additions and 22 deletions

View File

@ -19,6 +19,7 @@ from typing import (
Sequence, Sequence,
Tuple, Tuple,
Union, Union,
cast,
) )
import yaml import yaml
@ -1042,12 +1043,13 @@ class ExceptionTool(BaseTool):
NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]] NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]]
RunnableAgentType = Union[RunnableAgent, RunnableMultiActionAgent]
class AgentExecutor(Chain): class AgentExecutor(Chain):
"""Agent that is using tools.""" """Agent that is using tools."""
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable]
"""The agent to run for creating a plan and determining actions """The agent to run for creating a plan and determining actions
to take at each step of the execution loop.""" to take at each step of the execution loop."""
tools: Sequence[BaseTool] tools: Sequence[BaseTool]
@ -1095,7 +1097,7 @@ class AgentExecutor(Chain):
@classmethod @classmethod
def from_agent_and_tools( def from_agent_and_tools(
cls, cls,
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent], agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable],
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
@ -1172,6 +1174,21 @@ class AgentExecutor(Chain):
) )
return values return values
@property
def _action_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
"""Type cast self.agent.
The .agent attribute type includes Runnable, but is converted to one of
RunnableAgentType in the validate_runnable_agent root_validator.
To support instantiating with a Runnable, here we explicitly cast the type
to reflect the changes made in the root_validator.
"""
if isinstance(self.agent, Runnable):
return cast(RunnableAgentType, self.agent)
else:
return self.agent
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
"""Raise error - saving not supported for Agent Executors. """Raise error - saving not supported for Agent Executors.
@ -1193,7 +1210,7 @@ class AgentExecutor(Chain):
Args: Args:
file_path: Path to save to. file_path: Path to save to.
""" """
return self.agent.save(file_path) return self._action_agent.save(file_path)
def iter( def iter(
self, self,
@ -1228,7 +1245,7 @@ class AgentExecutor(Chain):
:meta private: :meta private:
""" """
return self.agent.input_keys return self._action_agent.input_keys
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> List[str]:
@ -1237,9 +1254,9 @@ class AgentExecutor(Chain):
:meta private: :meta private:
""" """
if self.return_intermediate_steps: if self.return_intermediate_steps:
return self.agent.return_values + ["intermediate_steps"] return self._action_agent.return_values + ["intermediate_steps"]
else: else:
return self.agent.return_values return self._action_agent.return_values
def lookup_tool(self, name: str) -> BaseTool: def lookup_tool(self, name: str) -> BaseTool:
"""Lookup tool by name. """Lookup tool by name.
@ -1339,7 +1356,7 @@ class AgentExecutor(Chain):
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps) intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
# Call the LLM to see what to do. # Call the LLM to see what to do.
output = self.agent.plan( output = self._action_agent.plan(
intermediate_steps, intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None, callbacks=run_manager.get_child() if run_manager else None,
**inputs, **inputs,
@ -1372,7 +1389,7 @@ class AgentExecutor(Chain):
output = AgentAction("_Exception", observation, text) output = AgentAction("_Exception", observation, text)
if run_manager: if run_manager:
run_manager.on_agent_action(output, color="green") run_manager.on_agent_action(output, color="green")
tool_run_kwargs = self.agent.tool_run_logging_kwargs() tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
observation = ExceptionTool().run( observation = ExceptionTool().run(
output.tool_input, output.tool_input,
verbose=self.verbose, verbose=self.verbose,
@ -1414,7 +1431,7 @@ class AgentExecutor(Chain):
tool = name_to_tool_map[agent_action.tool] tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct return_direct = tool.return_direct
color = color_mapping[agent_action.tool] color = color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs() tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
if return_direct: if return_direct:
tool_run_kwargs["llm_prefix"] = "" tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
@ -1426,7 +1443,7 @@ class AgentExecutor(Chain):
**tool_run_kwargs, **tool_run_kwargs,
) )
else: else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs() tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
observation = InvalidTool().run( observation = InvalidTool().run(
{ {
"requested_tool_name": agent_action.tool, "requested_tool_name": agent_action.tool,
@ -1476,7 +1493,7 @@ class AgentExecutor(Chain):
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps) intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
# Call the LLM to see what to do. # Call the LLM to see what to do.
output = await self.agent.aplan( output = await self._action_agent.aplan(
intermediate_steps, intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None, callbacks=run_manager.get_child() if run_manager else None,
**inputs, **inputs,
@ -1507,7 +1524,7 @@ class AgentExecutor(Chain):
else: else:
raise ValueError("Got unexpected type of `handle_parsing_errors`") raise ValueError("Got unexpected type of `handle_parsing_errors`")
output = AgentAction("_Exception", observation, text) output = AgentAction("_Exception", observation, text)
tool_run_kwargs = self.agent.tool_run_logging_kwargs() tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
observation = await ExceptionTool().arun( observation = await ExceptionTool().arun(
output.tool_input, output.tool_input,
verbose=self.verbose, verbose=self.verbose,
@ -1561,7 +1578,7 @@ class AgentExecutor(Chain):
tool = name_to_tool_map[agent_action.tool] tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct return_direct = tool.return_direct
color = color_mapping[agent_action.tool] color = color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs() tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
if return_direct: if return_direct:
tool_run_kwargs["llm_prefix"] = "" tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
@ -1573,7 +1590,7 @@ class AgentExecutor(Chain):
**tool_run_kwargs, **tool_run_kwargs,
) )
else: else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs() tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
observation = await InvalidTool().arun( observation = await InvalidTool().arun(
{ {
"requested_tool_name": agent_action.tool, "requested_tool_name": agent_action.tool,
@ -1628,7 +1645,7 @@ class AgentExecutor(Chain):
) )
iterations += 1 iterations += 1
time_elapsed = time.time() - start_time time_elapsed = time.time() - start_time
output = self.agent.return_stopped_response( output = self._action_agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs self.early_stopping_method, intermediate_steps, **inputs
) )
return self._return(output, intermediate_steps, run_manager=run_manager) return self._return(output, intermediate_steps, run_manager=run_manager)
@ -1680,7 +1697,7 @@ class AgentExecutor(Chain):
iterations += 1 iterations += 1
time_elapsed = time.time() - start_time time_elapsed = time.time() - start_time
output = self.agent.return_stopped_response( output = self._action_agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs self.early_stopping_method, intermediate_steps, **inputs
) )
return await self._areturn( return await self._areturn(
@ -1688,7 +1705,7 @@ class AgentExecutor(Chain):
) )
except (TimeoutError, asyncio.TimeoutError): except (TimeoutError, asyncio.TimeoutError):
# stop early when interrupted by the async timeout # stop early when interrupted by the async timeout
output = self.agent.return_stopped_response( output = self._action_agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs self.early_stopping_method, intermediate_steps, **inputs
) )
return await self._areturn( return await self._areturn(
@ -1702,8 +1719,8 @@ class AgentExecutor(Chain):
agent_action, observation = next_step_output agent_action, observation = next_step_output
name_to_tool_map = {tool.name: tool for tool in self.tools} name_to_tool_map = {tool.name: tool for tool in self.tools}
return_value_key = "output" return_value_key = "output"
if len(self.agent.return_values) > 0: if len(self._action_agent.return_values) > 0:
return_value_key = self.agent.return_values[0] return_value_key = self._action_agent.return_values[0]
# Invalid tools won't be in the map, so we return False. # Invalid tools won't be in the map, so we return False.
if agent_action.tool in name_to_tool_map: if agent_action.tool in name_to_tool_map:
if name_to_tool_map[agent_action.tool].return_direct: if name_to_tool_map[agent_action.tool].return_direct:

View File

@ -371,7 +371,7 @@ class AgentExecutorIterator:
""" """
logger.warning("Stopping agent prematurely due to triggering stop condition") logger.warning("Stopping agent prematurely due to triggering stop condition")
# this manually constructs agent finish with output key # this manually constructs agent finish with output key
output = self.agent_executor.agent.return_stopped_response( output = self.agent_executor._action_agent.return_stopped_response(
self.agent_executor.early_stopping_method, self.agent_executor.early_stopping_method,
self.intermediate_steps, self.intermediate_steps,
**self.inputs, **self.inputs,
@ -384,7 +384,7 @@ class AgentExecutorIterator:
the stopped response. the stopped response.
""" """
logger.warning("Stopping agent prematurely due to triggering stop condition") logger.warning("Stopping agent prematurely due to triggering stop condition")
output = self.agent_executor.agent.return_stopped_response( output = self.agent_executor._action_agent.return_stopped_response(
self.agent_executor.early_stopping_method, self.agent_executor.early_stopping_method,
self.intermediate_steps, self.intermediate_steps,
**self.inputs, **self.inputs,

View File

@ -21,6 +21,9 @@ def test_initialize_agent_with_str_agent_type() -> None:
fake_llm, fake_llm,
"zero-shot-react-description", # type: ignore[arg-type] "zero-shot-react-description", # type: ignore[arg-type]
) )
assert agent_executor.agent._agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION assert (
agent_executor._action_agent._agent_type
== AgentType.ZERO_SHOT_REACT_DESCRIPTION
)
assert isinstance(agent_executor.tags, list) assert isinstance(agent_executor.tags, list)
assert "zero-shot-react-description" in agent_executor.tags assert "zero-shot-react-description" in agent_executor.tags