mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-03 18:24:10 +00:00
langchain: Fix type warnings when passing Runnable as agent to AgentExecutor (#24750)
Fix for https://github.com/langchain-ai/langchain/issues/13075 --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
61228da1c4
commit
b4fcda7657
@ -19,6 +19,7 @@ from typing import (
|
|||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
@ -1042,12 +1043,13 @@ class ExceptionTool(BaseTool):
|
|||||||
|
|
||||||
|
|
||||||
NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]]
|
NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]]
|
||||||
|
RunnableAgentType = Union[RunnableAgent, RunnableMultiActionAgent]
|
||||||
|
|
||||||
|
|
||||||
class AgentExecutor(Chain):
|
class AgentExecutor(Chain):
|
||||||
"""Agent that is using tools."""
|
"""Agent that is using tools."""
|
||||||
|
|
||||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
|
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable]
|
||||||
"""The agent to run for creating a plan and determining actions
|
"""The agent to run for creating a plan and determining actions
|
||||||
to take at each step of the execution loop."""
|
to take at each step of the execution loop."""
|
||||||
tools: Sequence[BaseTool]
|
tools: Sequence[BaseTool]
|
||||||
@ -1095,7 +1097,7 @@ class AgentExecutor(Chain):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_agent_and_tools(
|
def from_agent_and_tools(
|
||||||
cls,
|
cls,
|
||||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
|
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent, Runnable],
|
||||||
tools: Sequence[BaseTool],
|
tools: Sequence[BaseTool],
|
||||||
callbacks: Callbacks = None,
|
callbacks: Callbacks = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -1172,6 +1174,21 @@ class AgentExecutor(Chain):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _action_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
|
||||||
|
"""Type cast self.agent.
|
||||||
|
|
||||||
|
The .agent attribute type includes Runnable, but is converted to one of
|
||||||
|
RunnableAgentType in the validate_runnable_agent root_validator.
|
||||||
|
|
||||||
|
To support instantiating with a Runnable, here we explicitly cast the type
|
||||||
|
to reflect the changes made in the root_validator.
|
||||||
|
"""
|
||||||
|
if isinstance(self.agent, Runnable):
|
||||||
|
return cast(RunnableAgentType, self.agent)
|
||||||
|
else:
|
||||||
|
return self.agent
|
||||||
|
|
||||||
def save(self, file_path: Union[Path, str]) -> None:
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
"""Raise error - saving not supported for Agent Executors.
|
"""Raise error - saving not supported for Agent Executors.
|
||||||
|
|
||||||
@ -1193,7 +1210,7 @@ class AgentExecutor(Chain):
|
|||||||
Args:
|
Args:
|
||||||
file_path: Path to save to.
|
file_path: Path to save to.
|
||||||
"""
|
"""
|
||||||
return self.agent.save(file_path)
|
return self._action_agent.save(file_path)
|
||||||
|
|
||||||
def iter(
|
def iter(
|
||||||
self,
|
self,
|
||||||
@ -1228,7 +1245,7 @@ class AgentExecutor(Chain):
|
|||||||
|
|
||||||
:meta private:
|
:meta private:
|
||||||
"""
|
"""
|
||||||
return self.agent.input_keys
|
return self._action_agent.input_keys
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
@ -1237,9 +1254,9 @@ class AgentExecutor(Chain):
|
|||||||
:meta private:
|
:meta private:
|
||||||
"""
|
"""
|
||||||
if self.return_intermediate_steps:
|
if self.return_intermediate_steps:
|
||||||
return self.agent.return_values + ["intermediate_steps"]
|
return self._action_agent.return_values + ["intermediate_steps"]
|
||||||
else:
|
else:
|
||||||
return self.agent.return_values
|
return self._action_agent.return_values
|
||||||
|
|
||||||
def lookup_tool(self, name: str) -> BaseTool:
|
def lookup_tool(self, name: str) -> BaseTool:
|
||||||
"""Lookup tool by name.
|
"""Lookup tool by name.
|
||||||
@ -1339,7 +1356,7 @@ class AgentExecutor(Chain):
|
|||||||
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
|
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
|
||||||
|
|
||||||
# Call the LLM to see what to do.
|
# Call the LLM to see what to do.
|
||||||
output = self.agent.plan(
|
output = self._action_agent.plan(
|
||||||
intermediate_steps,
|
intermediate_steps,
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
callbacks=run_manager.get_child() if run_manager else None,
|
||||||
**inputs,
|
**inputs,
|
||||||
@ -1372,7 +1389,7 @@ class AgentExecutor(Chain):
|
|||||||
output = AgentAction("_Exception", observation, text)
|
output = AgentAction("_Exception", observation, text)
|
||||||
if run_manager:
|
if run_manager:
|
||||||
run_manager.on_agent_action(output, color="green")
|
run_manager.on_agent_action(output, color="green")
|
||||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||||
observation = ExceptionTool().run(
|
observation = ExceptionTool().run(
|
||||||
output.tool_input,
|
output.tool_input,
|
||||||
verbose=self.verbose,
|
verbose=self.verbose,
|
||||||
@ -1414,7 +1431,7 @@ class AgentExecutor(Chain):
|
|||||||
tool = name_to_tool_map[agent_action.tool]
|
tool = name_to_tool_map[agent_action.tool]
|
||||||
return_direct = tool.return_direct
|
return_direct = tool.return_direct
|
||||||
color = color_mapping[agent_action.tool]
|
color = color_mapping[agent_action.tool]
|
||||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||||
if return_direct:
|
if return_direct:
|
||||||
tool_run_kwargs["llm_prefix"] = ""
|
tool_run_kwargs["llm_prefix"] = ""
|
||||||
# We then call the tool on the tool input to get an observation
|
# We then call the tool on the tool input to get an observation
|
||||||
@ -1426,7 +1443,7 @@ class AgentExecutor(Chain):
|
|||||||
**tool_run_kwargs,
|
**tool_run_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||||
observation = InvalidTool().run(
|
observation = InvalidTool().run(
|
||||||
{
|
{
|
||||||
"requested_tool_name": agent_action.tool,
|
"requested_tool_name": agent_action.tool,
|
||||||
@ -1476,7 +1493,7 @@ class AgentExecutor(Chain):
|
|||||||
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
|
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
|
||||||
|
|
||||||
# Call the LLM to see what to do.
|
# Call the LLM to see what to do.
|
||||||
output = await self.agent.aplan(
|
output = await self._action_agent.aplan(
|
||||||
intermediate_steps,
|
intermediate_steps,
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
callbacks=run_manager.get_child() if run_manager else None,
|
||||||
**inputs,
|
**inputs,
|
||||||
@ -1507,7 +1524,7 @@ class AgentExecutor(Chain):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Got unexpected type of `handle_parsing_errors`")
|
raise ValueError("Got unexpected type of `handle_parsing_errors`")
|
||||||
output = AgentAction("_Exception", observation, text)
|
output = AgentAction("_Exception", observation, text)
|
||||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||||
observation = await ExceptionTool().arun(
|
observation = await ExceptionTool().arun(
|
||||||
output.tool_input,
|
output.tool_input,
|
||||||
verbose=self.verbose,
|
verbose=self.verbose,
|
||||||
@ -1561,7 +1578,7 @@ class AgentExecutor(Chain):
|
|||||||
tool = name_to_tool_map[agent_action.tool]
|
tool = name_to_tool_map[agent_action.tool]
|
||||||
return_direct = tool.return_direct
|
return_direct = tool.return_direct
|
||||||
color = color_mapping[agent_action.tool]
|
color = color_mapping[agent_action.tool]
|
||||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||||
if return_direct:
|
if return_direct:
|
||||||
tool_run_kwargs["llm_prefix"] = ""
|
tool_run_kwargs["llm_prefix"] = ""
|
||||||
# We then call the tool on the tool input to get an observation
|
# We then call the tool on the tool input to get an observation
|
||||||
@ -1573,7 +1590,7 @@ class AgentExecutor(Chain):
|
|||||||
**tool_run_kwargs,
|
**tool_run_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
tool_run_kwargs = self._action_agent.tool_run_logging_kwargs()
|
||||||
observation = await InvalidTool().arun(
|
observation = await InvalidTool().arun(
|
||||||
{
|
{
|
||||||
"requested_tool_name": agent_action.tool,
|
"requested_tool_name": agent_action.tool,
|
||||||
@ -1628,7 +1645,7 @@ class AgentExecutor(Chain):
|
|||||||
)
|
)
|
||||||
iterations += 1
|
iterations += 1
|
||||||
time_elapsed = time.time() - start_time
|
time_elapsed = time.time() - start_time
|
||||||
output = self.agent.return_stopped_response(
|
output = self._action_agent.return_stopped_response(
|
||||||
self.early_stopping_method, intermediate_steps, **inputs
|
self.early_stopping_method, intermediate_steps, **inputs
|
||||||
)
|
)
|
||||||
return self._return(output, intermediate_steps, run_manager=run_manager)
|
return self._return(output, intermediate_steps, run_manager=run_manager)
|
||||||
@ -1680,7 +1697,7 @@ class AgentExecutor(Chain):
|
|||||||
|
|
||||||
iterations += 1
|
iterations += 1
|
||||||
time_elapsed = time.time() - start_time
|
time_elapsed = time.time() - start_time
|
||||||
output = self.agent.return_stopped_response(
|
output = self._action_agent.return_stopped_response(
|
||||||
self.early_stopping_method, intermediate_steps, **inputs
|
self.early_stopping_method, intermediate_steps, **inputs
|
||||||
)
|
)
|
||||||
return await self._areturn(
|
return await self._areturn(
|
||||||
@ -1688,7 +1705,7 @@ class AgentExecutor(Chain):
|
|||||||
)
|
)
|
||||||
except (TimeoutError, asyncio.TimeoutError):
|
except (TimeoutError, asyncio.TimeoutError):
|
||||||
# stop early when interrupted by the async timeout
|
# stop early when interrupted by the async timeout
|
||||||
output = self.agent.return_stopped_response(
|
output = self._action_agent.return_stopped_response(
|
||||||
self.early_stopping_method, intermediate_steps, **inputs
|
self.early_stopping_method, intermediate_steps, **inputs
|
||||||
)
|
)
|
||||||
return await self._areturn(
|
return await self._areturn(
|
||||||
@ -1702,8 +1719,8 @@ class AgentExecutor(Chain):
|
|||||||
agent_action, observation = next_step_output
|
agent_action, observation = next_step_output
|
||||||
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||||
return_value_key = "output"
|
return_value_key = "output"
|
||||||
if len(self.agent.return_values) > 0:
|
if len(self._action_agent.return_values) > 0:
|
||||||
return_value_key = self.agent.return_values[0]
|
return_value_key = self._action_agent.return_values[0]
|
||||||
# Invalid tools won't be in the map, so we return False.
|
# Invalid tools won't be in the map, so we return False.
|
||||||
if agent_action.tool in name_to_tool_map:
|
if agent_action.tool in name_to_tool_map:
|
||||||
if name_to_tool_map[agent_action.tool].return_direct:
|
if name_to_tool_map[agent_action.tool].return_direct:
|
||||||
|
@ -371,7 +371,7 @@ class AgentExecutorIterator:
|
|||||||
"""
|
"""
|
||||||
logger.warning("Stopping agent prematurely due to triggering stop condition")
|
logger.warning("Stopping agent prematurely due to triggering stop condition")
|
||||||
# this manually constructs agent finish with output key
|
# this manually constructs agent finish with output key
|
||||||
output = self.agent_executor.agent.return_stopped_response(
|
output = self.agent_executor._action_agent.return_stopped_response(
|
||||||
self.agent_executor.early_stopping_method,
|
self.agent_executor.early_stopping_method,
|
||||||
self.intermediate_steps,
|
self.intermediate_steps,
|
||||||
**self.inputs,
|
**self.inputs,
|
||||||
@ -384,7 +384,7 @@ class AgentExecutorIterator:
|
|||||||
the stopped response.
|
the stopped response.
|
||||||
"""
|
"""
|
||||||
logger.warning("Stopping agent prematurely due to triggering stop condition")
|
logger.warning("Stopping agent prematurely due to triggering stop condition")
|
||||||
output = self.agent_executor.agent.return_stopped_response(
|
output = self.agent_executor._action_agent.return_stopped_response(
|
||||||
self.agent_executor.early_stopping_method,
|
self.agent_executor.early_stopping_method,
|
||||||
self.intermediate_steps,
|
self.intermediate_steps,
|
||||||
**self.inputs,
|
**self.inputs,
|
||||||
|
@ -21,6 +21,9 @@ def test_initialize_agent_with_str_agent_type() -> None:
|
|||||||
fake_llm,
|
fake_llm,
|
||||||
"zero-shot-react-description", # type: ignore[arg-type]
|
"zero-shot-react-description", # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
assert agent_executor.agent._agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION
|
assert (
|
||||||
|
agent_executor._action_agent._agent_type
|
||||||
|
== AgentType.ZERO_SHOT_REACT_DESCRIPTION
|
||||||
|
)
|
||||||
assert isinstance(agent_executor.tags, list)
|
assert isinstance(agent_executor.tags, list)
|
||||||
assert "zero-shot-react-description" in agent_executor.tags
|
assert "zero-shot-react-description" in agent_executor.tags
|
||||||
|
Loading…
Reference in New Issue
Block a user