langchain: Fix type warnings when passing Runnable as agent to AgentExecutor (#24750)

Fix for https://github.com/langchain-ai/langchain/issues/13075

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Hasan Kumar 2024-08-22 08:02:02 -10:00 committed by GitHub
parent 61228da1c4
commit b4fcda7657
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 42 additions and 22 deletions

View File

@ -19,6 +19,7 @@ from typing import (
Sequence,
Tuple,
Union,
cast,
)
import yaml
@ -1042,12 +1043,13 @@ class ExceptionTool(BaseTool):
NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]]
RunnableAgentType = Union[RunnableAgent, RunnableMultiActionAgent]
class AgentExecutor(Chain):
"""Agent that is using tools."""
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable]
"""The agent to run for creating a plan and determining actions
to take at each step of the execution loop."""
tools: Sequence[BaseTool]
@ -1095,7 +1097,7 @@ class AgentExecutor(Chain):
@classmethod
def from_agent_and_tools(
cls,
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable],
tools: Sequence[BaseTool],
callbacks: Callbacks = None,
**kwargs: Any,
@ -1172,6 +1174,21 @@ class AgentExecutor(Chain):
)
return values
@property
def _action_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
"""Type cast self.agent.
The .agent attribute type includes Runnable, but is converted to one of
RunnableAgentType in the validate_runnable_agent root_validator.
To support instantiating with a Runnable, here we explicitly cast the type
to reflect the changes made in the root_validator.
"""
if isinstance(self.agent, Runnable):
return cast(RunnableAgentType, self.agent)
else:
return self.agent
def save(self, file_path: Union[Path, str]) -> None:
"""Raise error - saving not supported for Agent Executors.
@ -1193,7 +1210,7 @@ class AgentExecutor(Chain):
Args:
file_path: Path to save to.
"""
return self.agent.save(file_path)
return self._action_agent.save(file_path)
def iter(
self,
@ -1228,7 +1245,7 @@ class AgentExecutor(Chain):
:meta private:
"""
return self.agent.input_keys
return self._action_agent.input_keys
@property
def output_keys(self) -> List[str]:
@ -1237,9 +1254,9 @@ class AgentExecutor(Chain):
:meta private:
"""
if self.return_intermediate_steps:
return self.agent.return_values + ["intermediate_steps"]
return self._action_agent.return_values + ["intermediate_steps"]
else:
return self.agent.return_values
return self._action_agent.return_values
def lookup_tool(self, name: str) -> BaseTool:
"""Lookup tool by name.
@ -1339,7 +1356,7 @@ class AgentExecutor(Chain):
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
# Call the LLM to see what to do.
output = self.agent.plan(
output = self._action_agent.plan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
@ -1372,7 +1389,7 @@ class AgentExecutor(Chain):
output = AgentAction("_Exception", observation, text)
if run_manager:
run_manager.on_agent_action(output, color="green")
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
observation = ExceptionTool().run(
output.tool_input,
verbose=self.verbose,
@ -1414,7 +1431,7 @@ class AgentExecutor(Chain):
tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct
color = color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
@ -1426,7 +1443,7 @@ class AgentExecutor(Chain):
**tool_run_kwargs,
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
observation = InvalidTool().run(
{
"requested_tool_name": agent_action.tool,
@ -1476,7 +1493,7 @@ class AgentExecutor(Chain):
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
# Call the LLM to see what to do.
output = await self.agent.aplan(
output = await self._action_agent.aplan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
@ -1507,7 +1524,7 @@ class AgentExecutor(Chain):
else:
raise ValueError("Got unexpected type of `handle_parsing_errors`")
output = AgentAction("_Exception", observation, text)
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
observation = await ExceptionTool().arun(
output.tool_input,
verbose=self.verbose,
@ -1561,7 +1578,7 @@ class AgentExecutor(Chain):
tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct
color = color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
@ -1573,7 +1590,7 @@ class AgentExecutor(Chain):
**tool_run_kwargs,
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
observation = await InvalidTool().arun(
{
"requested_tool_name": agent_action.tool,
@ -1628,7 +1645,7 @@ class AgentExecutor(Chain):
)
iterations += 1
time_elapsed = time.time() - start_time
output = self.agent.return_stopped_response(
output = self._action_agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return self._return(output, intermediate_steps, run_manager=run_manager)
@ -1680,7 +1697,7 @@ class AgentExecutor(Chain):
iterations += 1
time_elapsed = time.time() - start_time
output = self.agent.return_stopped_response(
output = self._action_agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return await self._areturn(
@ -1688,7 +1705,7 @@ class AgentExecutor(Chain):
)
except (TimeoutError, asyncio.TimeoutError):
# stop early when interrupted by the async timeout
output = self.agent.return_stopped_response(
output = self._action_agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return await self._areturn(
@ -1702,8 +1719,8 @@ class AgentExecutor(Chain):
agent_action, observation = next_step_output
name_to_tool_map = {tool.name: tool for tool in self.tools}
return_value_key = "output"
if len(self.agent.return_values) > 0:
return_value_key = self.agent.return_values[0]
if len(self._action_agent.return_values) > 0:
return_value_key = self._action_agent.return_values[0]
# Invalid tools won't be in the map, so we return False.
if agent_action.tool in name_to_tool_map:
if name_to_tool_map[agent_action.tool].return_direct:

View File

@ -371,7 +371,7 @@ class AgentExecutorIterator:
"""
logger.warning("Stopping agent prematurely due to triggering stop condition")
# this manually constructs agent finish with output key
output = self.agent_executor.agent.return_stopped_response(
output = self.agent_executor._action_agent.return_stopped_response(
self.agent_executor.early_stopping_method,
self.intermediate_steps,
**self.inputs,
@ -384,7 +384,7 @@ class AgentExecutorIterator:
the stopped response.
"""
logger.warning("Stopping agent prematurely due to triggering stop condition")
output = self.agent_executor.agent.return_stopped_response(
output = self.agent_executor._action_agent.return_stopped_response(
self.agent_executor.early_stopping_method,
self.intermediate_steps,
**self.inputs,

View File

@ -21,6 +21,9 @@ def test_initialize_agent_with_str_agent_type() -> None:
fake_llm,
"zero-shot-react-description", # type: ignore[arg-type]
)
assert agent_executor.agent._agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION
assert (
agent_executor._action_agent._agent_type
== AgentType.ZERO_SHOT_REACT_DESCRIPTION
)
assert isinstance(agent_executor.tags, list)
assert "zero-shot-react-description" in agent_executor.tags