mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 06:14:37 +00:00
langchain[patch]: Extract _aperform_agent_action from _aiter_next_step from AgentExecutor (#15707)
- **Description:** extreact the _aperform_agent_action in the AgentExecutor class to allow for easier overriding. Extracted logic from _iter_next_step into a new method _perform_agent_action for consistency and easier overriding. - **Issue:** #15706 Closes #15706
This commit is contained in:
parent
95ee69a301
commit
c69f599594
@ -1179,37 +1179,48 @@ class AgentExecutor(Chain):
|
||||
for agent_action in actions:
|
||||
yield agent_action
|
||||
for agent_action in actions:
|
||||
if run_manager:
|
||||
run_manager.on_agent_action(agent_action, color="green")
|
||||
# Otherwise we lookup the tool
|
||||
if agent_action.tool in name_to_tool_map:
|
||||
tool = name_to_tool_map[agent_action.tool]
|
||||
return_direct = tool.return_direct
|
||||
color = color_mapping[agent_action.tool]
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
if return_direct:
|
||||
tool_run_kwargs["llm_prefix"] = ""
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = tool.run(
|
||||
agent_action.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=color,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
else:
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
observation = InvalidTool().run(
|
||||
{
|
||||
"requested_tool_name": agent_action.tool,
|
||||
"available_tool_names": list(name_to_tool_map.keys()),
|
||||
},
|
||||
verbose=self.verbose,
|
||||
color=None,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
yield AgentStep(action=agent_action, observation=observation)
|
||||
yield self._perform_agent_action(
|
||||
name_to_tool_map, color_mapping, agent_action, run_manager
|
||||
)
|
||||
|
||||
def _perform_agent_action(
|
||||
self,
|
||||
name_to_tool_map: Dict[str, BaseTool],
|
||||
color_mapping: Dict[str, str],
|
||||
agent_action: AgentAction,
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> AgentStep:
|
||||
if run_manager:
|
||||
run_manager.on_agent_action(agent_action, color="green")
|
||||
# Otherwise we lookup the tool
|
||||
if agent_action.tool in name_to_tool_map:
|
||||
tool = name_to_tool_map[agent_action.tool]
|
||||
return_direct = tool.return_direct
|
||||
color = color_mapping[agent_action.tool]
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
if return_direct:
|
||||
tool_run_kwargs["llm_prefix"] = ""
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = tool.run(
|
||||
agent_action.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=color,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
else:
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
observation = InvalidTool().run(
|
||||
{
|
||||
"requested_tool_name": agent_action.tool,
|
||||
"available_tool_names": list(name_to_tool_map.keys()),
|
||||
},
|
||||
verbose=self.verbose,
|
||||
color=None,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
return AgentStep(action=agent_action, observation=observation)
|
||||
|
||||
async def _atake_next_step(
|
||||
self,
|
||||
@ -1303,52 +1314,61 @@ class AgentExecutor(Chain):
|
||||
for agent_action in actions:
|
||||
yield agent_action
|
||||
|
||||
async def _aperform_agent_action(
|
||||
agent_action: AgentAction,
|
||||
) -> AgentStep:
|
||||
if run_manager:
|
||||
await run_manager.on_agent_action(
|
||||
agent_action, verbose=self.verbose, color="green"
|
||||
)
|
||||
# Otherwise we lookup the tool
|
||||
if agent_action.tool in name_to_tool_map:
|
||||
tool = name_to_tool_map[agent_action.tool]
|
||||
return_direct = tool.return_direct
|
||||
color = color_mapping[agent_action.tool]
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
if return_direct:
|
||||
tool_run_kwargs["llm_prefix"] = ""
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = await tool.arun(
|
||||
agent_action.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=color,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
else:
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
observation = await InvalidTool().arun(
|
||||
{
|
||||
"requested_tool_name": agent_action.tool,
|
||||
"available_tool_names": list(name_to_tool_map.keys()),
|
||||
},
|
||||
verbose=self.verbose,
|
||||
color=None,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
return AgentStep(action=agent_action, observation=observation)
|
||||
|
||||
# Use asyncio.gather to run multiple tool.arun() calls concurrently
|
||||
result = await asyncio.gather(
|
||||
*[_aperform_agent_action(agent_action) for agent_action in actions]
|
||||
*[
|
||||
self._aperform_agent_action(
|
||||
name_to_tool_map, color_mapping, agent_action, run_manager
|
||||
)
|
||||
for agent_action in actions
|
||||
],
|
||||
)
|
||||
|
||||
# TODO This could yield each result as it becomes available
|
||||
for chunk in result:
|
||||
yield chunk
|
||||
|
||||
async def _aperform_agent_action(
|
||||
self,
|
||||
name_to_tool_map: Dict[str, BaseTool],
|
||||
color_mapping: Dict[str, str],
|
||||
agent_action: AgentAction,
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> AgentStep:
|
||||
if run_manager:
|
||||
await run_manager.on_agent_action(
|
||||
agent_action, verbose=self.verbose, color="green"
|
||||
)
|
||||
# Otherwise we lookup the tool
|
||||
if agent_action.tool in name_to_tool_map:
|
||||
tool = name_to_tool_map[agent_action.tool]
|
||||
return_direct = tool.return_direct
|
||||
color = color_mapping[agent_action.tool]
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
if return_direct:
|
||||
tool_run_kwargs["llm_prefix"] = ""
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = await tool.arun(
|
||||
agent_action.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=color,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
else:
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
observation = await InvalidTool().arun(
|
||||
{
|
||||
"requested_tool_name": agent_action.tool,
|
||||
"available_tool_names": list(name_to_tool_map.keys()),
|
||||
},
|
||||
verbose=self.verbose,
|
||||
color=None,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
return AgentStep(action=agent_action, observation=observation)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
|
Loading…
Reference in New Issue
Block a user