diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index 16f648bea9e..50797228a53 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -19,6 +19,7 @@ from typing import ( Sequence, Tuple, Union, + cast, ) import yaml @@ -1042,12 +1043,13 @@ class ExceptionTool(BaseTool): NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]] +RunnableAgentType = Union[RunnableAgent, RunnableMultiActionAgent] class AgentExecutor(Chain): """Agent that is using tools.""" - agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] + agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable] """The agent to run for creating a plan and determining actions to take at each step of the execution loop.""" tools: Sequence[BaseTool] @@ -1095,7 +1097,7 @@ class AgentExecutor(Chain): @classmethod def from_agent_and_tools( cls, - agent: Union[BaseSingleActionAgent, BaseMultiActionAgent], + agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable], tools: Sequence[BaseTool], callbacks: Callbacks = None, **kwargs: Any, @@ -1172,6 +1174,21 @@ class AgentExecutor(Chain): ) return values + @property + def _action_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: + """Type cast self.agent. + + The .agent attribute type includes Runnable, but is converted to one of + RunnableAgentType in the validate_runnable_agent root_validator. + + To support instantiating with a Runnable, here we explicitly cast the type + to reflect the changes made in the root_validator. + """ + if isinstance(self.agent, Runnable): + return cast(RunnableAgentType, self.agent) + else: + return self.agent + def save(self, file_path: Union[Path, str]) -> None: """Raise error - saving not supported for Agent Executors. @@ -1193,7 +1210,7 @@ class AgentExecutor(Chain): Args: file_path: Path to save to. """ - return self.agent.save(file_path) + return self._action_agent.save(file_path) def iter( self, @@ -1228,7 +1245,7 @@ class AgentExecutor(Chain): :meta private: """ - return self.agent.input_keys + return self._action_agent.input_keys @property def output_keys(self) -> List[str]: @@ -1237,9 +1254,9 @@ class AgentExecutor(Chain): :meta private: """ if self.return_intermediate_steps: - return self.agent.return_values + ["intermediate_steps"] + return self._action_agent.return_values + ["intermediate_steps"] else: - return self.agent.return_values + return self._action_agent.return_values def lookup_tool(self, name: str) -> BaseTool: """Lookup tool by name. @@ -1339,7 +1356,7 @@ class AgentExecutor(Chain): intermediate_steps = self._prepare_intermediate_steps(intermediate_steps) # Call the LLM to see what to do. - output = self.agent.plan( + output = self._action_agent.plan( intermediate_steps, callbacks=run_manager.get_child() if run_manager else None, **inputs, @@ -1372,7 +1389,7 @@ class AgentExecutor(Chain): output = AgentAction("_Exception", observation, text) if run_manager: run_manager.on_agent_action(output, color="green") - tool_run_kwargs = self.agent.tool_run_logging_kwargs() + tool_run_kwargs = self._action_agent.tool_run_logging_kwargs() observation = ExceptionTool().run( output.tool_input, verbose=self.verbose, @@ -1414,7 +1431,7 @@ class AgentExecutor(Chain): tool = name_to_tool_map[agent_action.tool] return_direct = tool.return_direct color = color_mapping[agent_action.tool] - tool_run_kwargs = self.agent.tool_run_logging_kwargs() + tool_run_kwargs = self._action_agent.tool_run_logging_kwargs() if return_direct: tool_run_kwargs["llm_prefix"] = "" # We then call the tool on the tool input to get an observation @@ -1426,7 +1443,7 @@ class AgentExecutor(Chain): **tool_run_kwargs, ) else: - tool_run_kwargs = self.agent.tool_run_logging_kwargs() + tool_run_kwargs = self._action_agent.tool_run_logging_kwargs() observation = InvalidTool().run( { "requested_tool_name": agent_action.tool, @@ -1476,7 +1493,7 @@ class AgentExecutor(Chain): intermediate_steps = self._prepare_intermediate_steps(intermediate_steps) # Call the LLM to see what to do. - output = await self.agent.aplan( + output = await self._action_agent.aplan( intermediate_steps, callbacks=run_manager.get_child() if run_manager else None, **inputs, @@ -1507,7 +1524,7 @@ class AgentExecutor(Chain): else: raise ValueError("Got unexpected type of `handle_parsing_errors`") output = AgentAction("_Exception", observation, text) - tool_run_kwargs = self.agent.tool_run_logging_kwargs() + tool_run_kwargs = self._action_agent.tool_run_logging_kwargs() observation = await ExceptionTool().arun( output.tool_input, verbose=self.verbose, @@ -1561,7 +1578,7 @@ class AgentExecutor(Chain): tool = name_to_tool_map[agent_action.tool] return_direct = tool.return_direct color = color_mapping[agent_action.tool] - tool_run_kwargs = self.agent.tool_run_logging_kwargs() + tool_run_kwargs = self._action_agent.tool_run_logging_kwargs() if return_direct: tool_run_kwargs["llm_prefix"] = "" # We then call the tool on the tool input to get an observation @@ -1573,7 +1590,7 @@ class AgentExecutor(Chain): **tool_run_kwargs, ) else: - tool_run_kwargs = self.agent.tool_run_logging_kwargs() + tool_run_kwargs = self._action_agent.tool_run_logging_kwargs() observation = await InvalidTool().arun( { "requested_tool_name": agent_action.tool, @@ -1628,7 +1645,7 @@ class AgentExecutor(Chain): ) iterations += 1 time_elapsed = time.time() - start_time - output = self.agent.return_stopped_response( + output = self._action_agent.return_stopped_response( self.early_stopping_method, intermediate_steps, **inputs ) return self._return(output, intermediate_steps, run_manager=run_manager) @@ -1680,7 +1697,7 @@ class AgentExecutor(Chain): iterations += 1 time_elapsed = time.time() - start_time - output = self.agent.return_stopped_response( + output = self._action_agent.return_stopped_response( self.early_stopping_method, intermediate_steps, **inputs ) return await self._areturn( @@ -1688,7 +1705,7 @@ class AgentExecutor(Chain): ) except (TimeoutError, asyncio.TimeoutError): # stop early when interrupted by the async timeout - output = self.agent.return_stopped_response( + output = self._action_agent.return_stopped_response( self.early_stopping_method, intermediate_steps, **inputs ) return await self._areturn( @@ -1702,8 +1719,8 @@ class AgentExecutor(Chain): agent_action, observation = next_step_output name_to_tool_map = {tool.name: tool for tool in self.tools} return_value_key = "output" - if len(self.agent.return_values) > 0: - return_value_key = self.agent.return_values[0] + if len(self._action_agent.return_values) > 0: + return_value_key = self._action_agent.return_values[0] # Invalid tools won't be in the map, so we return False. if agent_action.tool in name_to_tool_map: if name_to_tool_map[agent_action.tool].return_direct: diff --git a/libs/langchain/langchain/agents/agent_iterator.py b/libs/langchain/langchain/agents/agent_iterator.py index a2c51efd652..ddd742e1c67 100644 --- a/libs/langchain/langchain/agents/agent_iterator.py +++ b/libs/langchain/langchain/agents/agent_iterator.py @@ -371,7 +371,7 @@ class AgentExecutorIterator: """ logger.warning("Stopping agent prematurely due to triggering stop condition") # this manually constructs agent finish with output key - output = self.agent_executor.agent.return_stopped_response( + output = self.agent_executor._action_agent.return_stopped_response( self.agent_executor.early_stopping_method, self.intermediate_steps, **self.inputs, @@ -384,7 +384,7 @@ class AgentExecutorIterator: the stopped response. """ logger.warning("Stopping agent prematurely due to triggering stop condition") - output = self.agent_executor.agent.return_stopped_response( + output = self.agent_executor._action_agent.return_stopped_response( self.agent_executor.early_stopping_method, self.intermediate_steps, **self.inputs, diff --git a/libs/langchain/tests/unit_tests/agents/test_initialize.py b/libs/langchain/tests/unit_tests/agents/test_initialize.py index b83a549b53c..f898208a292 100644 --- a/libs/langchain/tests/unit_tests/agents/test_initialize.py +++ b/libs/langchain/tests/unit_tests/agents/test_initialize.py @@ -21,6 +21,9 @@ def test_initialize_agent_with_str_agent_type() -> None: fake_llm, "zero-shot-react-description", # type: ignore[arg-type] ) - assert agent_executor.agent._agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION + assert ( + agent_executor._action_agent._agent_type + == AgentType.ZERO_SHOT_REACT_DESCRIPTION + ) assert isinstance(agent_executor.tags, list) assert "zero-shot-react-description" in agent_executor.tags