mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
langchain: Fix type warnings when passing Runnable as agent to AgentExecutor (#24750)
Fix for https://github.com/langchain-ai/langchain/issues/13075 --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
61228da1c4
commit
b4fcda7657
@ -19,6 +19,7 @@ from typing import (
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import yaml
|
||||
@ -1042,12 +1043,13 @@ class ExceptionTool(BaseTool):
|
||||
|
||||
|
||||
NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]]
|
||||
RunnableAgentType = Union[RunnableAgent, RunnableMultiActionAgent]
|
||||
|
||||
|
||||
class AgentExecutor(Chain):
|
||||
"""Agent that is using tools."""
|
||||
|
||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
|
||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable]
|
||||
"""The agent to run for creating a plan and determining actions
|
||||
to take at each step of the execution loop."""
|
||||
tools: Sequence[BaseTool]
|
||||
@ -1095,7 +1097,7 @@ class AgentExecutor(Chain):
|
||||
@classmethod
|
||||
def from_agent_and_tools(
|
||||
cls,
|
||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
|
||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable],
|
||||
tools: Sequence[BaseTool],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
@ -1172,6 +1174,21 @@ class AgentExecutor(Chain):
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _action_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
|
||||
"""Type cast self.agent.
|
||||
|
||||
The .agent attribute type includes Runnable, but is converted to one of
|
||||
RunnableAgentType in the validate_runnable_agent root_validator.
|
||||
|
||||
To support instantiating with a Runnable, here we explicitly cast the type
|
||||
to reflect the changes made in the root_validator.
|
||||
"""
|
||||
if isinstance(self.agent, Runnable):
|
||||
return cast(RunnableAgentType, self.agent)
|
||||
else:
|
||||
return self.agent
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Raise error - saving not supported for Agent Executors.
|
||||
|
||||
@ -1193,7 +1210,7 @@ class AgentExecutor(Chain):
|
||||
Args:
|
||||
file_path: Path to save to.
|
||||
"""
|
||||
return self.agent.save(file_path)
|
||||
return self._action_agent.save(file_path)
|
||||
|
||||
def iter(
|
||||
self,
|
||||
@ -1228,7 +1245,7 @@ class AgentExecutor(Chain):
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.agent.input_keys
|
||||
return self._action_agent.input_keys
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
@ -1237,9 +1254,9 @@ class AgentExecutor(Chain):
|
||||
:meta private:
|
||||
"""
|
||||
if self.return_intermediate_steps:
|
||||
return self.agent.return_values + ["intermediate_steps"]
|
||||
return self._action_agent.return_values + ["intermediate_steps"]
|
||||
else:
|
||||
return self.agent.return_values
|
||||
return self._action_agent.return_values
|
||||
|
||||
def lookup_tool(self, name: str) -> BaseTool:
|
||||
"""Lookup tool by name.
|
||||
@ -1339,7 +1356,7 @@ class AgentExecutor(Chain):
|
||||
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
|
||||
|
||||
# Call the LLM to see what to do.
|
||||
output = self.agent.plan(
|
||||
output = self._action_agent.plan(
|
||||
intermediate_steps,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**inputs,
|
||||
@ -1372,7 +1389,7 @@ class AgentExecutor(Chain):
|
||||
output = AgentAction("_Exception", observation, text)
|
||||
if run_manager:
|
||||
run_manager.on_agent_action(output, color="green")
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||
observation = ExceptionTool().run(
|
||||
output.tool_input,
|
||||
verbose=self.verbose,
|
||||
@ -1414,7 +1431,7 @@ class AgentExecutor(Chain):
|
||||
tool = name_to_tool_map[agent_action.tool]
|
||||
return_direct = tool.return_direct
|
||||
color = color_mapping[agent_action.tool]
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||
if return_direct:
|
||||
tool_run_kwargs["llm_prefix"] = ""
|
||||
# We then call the tool on the tool input to get an observation
|
||||
@ -1426,7 +1443,7 @@ class AgentExecutor(Chain):
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
else:
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||
observation = InvalidTool().run(
|
||||
{
|
||||
"requested_tool_name": agent_action.tool,
|
||||
@ -1476,7 +1493,7 @@ class AgentExecutor(Chain):
|
||||
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
|
||||
|
||||
# Call the LLM to see what to do.
|
||||
output = await self.agent.aplan(
|
||||
output = await self._action_agent.aplan(
|
||||
intermediate_steps,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**inputs,
|
||||
@ -1507,7 +1524,7 @@ class AgentExecutor(Chain):
|
||||
else:
|
||||
raise ValueError("Got unexpected type of `handle_parsing_errors`")
|
||||
output = AgentAction("_Exception", observation, text)
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||
observation = await ExceptionTool().arun(
|
||||
output.tool_input,
|
||||
verbose=self.verbose,
|
||||
@ -1561,7 +1578,7 @@ class AgentExecutor(Chain):
|
||||
tool = name_to_tool_map[agent_action.tool]
|
||||
return_direct = tool.return_direct
|
||||
color = color_mapping[agent_action.tool]
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||
if return_direct:
|
||||
tool_run_kwargs["llm_prefix"] = ""
|
||||
# We then call the tool on the tool input to get an observation
|
||||
@ -1573,7 +1590,7 @@ class AgentExecutor(Chain):
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
else:
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||
observation = await InvalidTool().arun(
|
||||
{
|
||||
"requested_tool_name": agent_action.tool,
|
||||
@ -1628,7 +1645,7 @@ class AgentExecutor(Chain):
|
||||
)
|
||||
iterations += 1
|
||||
time_elapsed = time.time() - start_time
|
||||
output = self.agent.return_stopped_response(
|
||||
output = self._action_agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
return self._return(output, intermediate_steps, run_manager=run_manager)
|
||||
@ -1680,7 +1697,7 @@ class AgentExecutor(Chain):
|
||||
|
||||
iterations += 1
|
||||
time_elapsed = time.time() - start_time
|
||||
output = self.agent.return_stopped_response(
|
||||
output = self._action_agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
return await self._areturn(
|
||||
@ -1688,7 +1705,7 @@ class AgentExecutor(Chain):
|
||||
)
|
||||
except (TimeoutError, asyncio.TimeoutError):
|
||||
# stop early when interrupted by the async timeout
|
||||
output = self.agent.return_stopped_response(
|
||||
output = self._action_agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
return await self._areturn(
|
||||
@ -1702,8 +1719,8 @@ class AgentExecutor(Chain):
|
||||
agent_action, observation = next_step_output
|
||||
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||
return_value_key = "output"
|
||||
if len(self.agent.return_values) > 0:
|
||||
return_value_key = self.agent.return_values[0]
|
||||
if len(self._action_agent.return_values) > 0:
|
||||
return_value_key = self._action_agent.return_values[0]
|
||||
# Invalid tools won't be in the map, so we return False.
|
||||
if agent_action.tool in name_to_tool_map:
|
||||
if name_to_tool_map[agent_action.tool].return_direct:
|
||||
|
@ -371,7 +371,7 @@ class AgentExecutorIterator:
|
||||
"""
|
||||
logger.warning("Stopping agent prematurely due to triggering stop condition")
|
||||
# this manually constructs agent finish with output key
|
||||
output = self.agent_executor.agent.return_stopped_response(
|
||||
output = self.agent_executor._action_agent.return_stopped_response(
|
||||
self.agent_executor.early_stopping_method,
|
||||
self.intermediate_steps,
|
||||
**self.inputs,
|
||||
@ -384,7 +384,7 @@ class AgentExecutorIterator:
|
||||
the stopped response.
|
||||
"""
|
||||
logger.warning("Stopping agent prematurely due to triggering stop condition")
|
||||
output = self.agent_executor.agent.return_stopped_response(
|
||||
output = self.agent_executor._action_agent.return_stopped_response(
|
||||
self.agent_executor.early_stopping_method,
|
||||
self.intermediate_steps,
|
||||
**self.inputs,
|
||||
|
@ -21,6 +21,9 @@ def test_initialize_agent_with_str_agent_type() -> None:
|
||||
fake_llm,
|
||||
"zero-shot-react-description", # type: ignore[arg-type]
|
||||
)
|
||||
assert agent_executor.agent._agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION
|
||||
assert (
|
||||
agent_executor._action_agent._agent_type
|
||||
== AgentType.ZERO_SHOT_REACT_DESCRIPTION
|
||||
)
|
||||
assert isinstance(agent_executor.tags, list)
|
||||
assert "zero-shot-react-description" in agent_executor.tags
|
||||
|
Loading…
Reference in New Issue
Block a user