Feature: AgentExecutor execution time limit (#2399)

`AgentExecutor` already supports limiting the number of iterations, but
the time taken by each iteration can vary quite a bit, so an iteration
cap alone makes it difficult to bound the total execution time.
This PR adds a new field `max_execution_time` (in seconds) to the
`AgentExecutor` model. When called asynchronously, the agent loop is
wrapped in an `asyncio.timeout()` context (the `async_timeout` backport
on Python < 3.11), which triggers the early stopping response if the
time limit is reached. When called synchronously, the agent loop checks
both the `max_iterations` limit and the time limit after each iteration.

When used asynchronously, `max_execution_time` gives tight control over
the maximum time for an execution chain. When used synchronously, the
limit is only checked between iterations, so a slow step is not
interrupted and the chain can exceed `max_execution_time`; it still
gives more control than trying to estimate the `max_iterations` needed
to cap the execution time.

---------

Co-authored-by: Zachary Jones <zjones@zetaglobal.com>
This commit is contained in:
Zach Jones 2023-04-06 15:54:32 -04:00 committed by GitHub
parent 5b34931948
commit 13d1df2140
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 30 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import json import json
import logging import logging
import time
from abc import abstractmethod from abc import abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@ -26,6 +27,7 @@ from langchain.schema import (
BaseOutputParser, BaseOutputParser,
) )
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.utilities.asyncio import asyncio_timeout
logger = logging.getLogger() logger = logging.getLogger()
@ -88,7 +90,9 @@ class BaseSingleActionAgent(BaseModel):
"""Return response when agent has been stopped due to max iterations.""" """Return response when agent has been stopped due to max iterations."""
if early_stopping_method == "force": if early_stopping_method == "force":
# `force` just returns a constant string # `force` just returns a constant string
return AgentFinish({"output": "Agent stopped due to max iterations."}, "") return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, ""
)
else: else:
raise ValueError( raise ValueError(
f"Got unsupported early_stopping_method `{early_stopping_method}`" f"Got unsupported early_stopping_method `{early_stopping_method}`"
@ -506,7 +510,9 @@ class Agent(BaseSingleActionAgent):
"""Return response when agent has been stopped due to max iterations.""" """Return response when agent has been stopped due to max iterations."""
if early_stopping_method == "force": if early_stopping_method == "force":
# `force` just returns a constant string # `force` just returns a constant string
return AgentFinish({"output": "Agent stopped due to max iterations."}, "") return AgentFinish(
{"output": "Agent stopped due to iteration limit or time limit."}, ""
)
elif early_stopping_method == "generate": elif early_stopping_method == "generate":
# Generate does one final forward pass # Generate does one final forward pass
thoughts = "" thoughts = ""
@ -555,6 +561,7 @@ class AgentExecutor(Chain):
tools: Sequence[BaseTool] tools: Sequence[BaseTool]
return_intermediate_steps: bool = False return_intermediate_steps: bool = False
max_iterations: Optional[int] = 15 max_iterations: Optional[int] = 15
max_execution_time: Optional[float] = None
early_stopping_method: str = "force" early_stopping_method: str = "force"
@classmethod @classmethod
@ -633,11 +640,16 @@ class AgentExecutor(Chain):
"""Lookup tool by name.""" """Lookup tool by name."""
return {tool.name: tool for tool in self.tools}[name] return {tool.name: tool for tool in self.tools}[name]
def _should_continue(self, iterations: int) -> bool: def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
if self.max_iterations is None: if self.max_iterations is not None and iterations >= self.max_iterations:
return True return False
else: if (
return iterations < self.max_iterations self.max_execution_time is not None
and time_elapsed >= self.max_execution_time
):
return False
return True
def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]: def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
self.callback_manager.on_agent_finish( self.callback_manager.on_agent_finish(
@ -783,10 +795,12 @@ class AgentExecutor(Chain):
[tool.name for tool in self.tools], excluded_colors=["green"] [tool.name for tool in self.tools], excluded_colors=["green"]
) )
intermediate_steps: List[Tuple[AgentAction, str]] = [] intermediate_steps: List[Tuple[AgentAction, str]] = []
# Let's start tracking the iterations the agent has gone through # Let's start tracking the number of iterations and time elapsed
iterations = 0 iterations = 0
time_elapsed = 0.0
start_time = time.time()
# We now enter the agent loop (until it returns something). # We now enter the agent loop (until it returns something).
while self._should_continue(iterations): while self._should_continue(iterations, time_elapsed):
next_step_output = self._take_next_step( next_step_output = self._take_next_step(
name_to_tool_map, color_mapping, inputs, intermediate_steps name_to_tool_map, color_mapping, inputs, intermediate_steps
) )
@ -801,6 +815,7 @@ class AgentExecutor(Chain):
if tool_return is not None: if tool_return is not None:
return self._return(tool_return, intermediate_steps) return self._return(tool_return, intermediate_steps)
iterations += 1 iterations += 1
time_elapsed = time.time() - start_time
output = self.agent.return_stopped_response( output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs self.early_stopping_method, intermediate_steps, **inputs
) )
@ -815,29 +830,40 @@ class AgentExecutor(Chain):
[tool.name for tool in self.tools], excluded_colors=["green"] [tool.name for tool in self.tools], excluded_colors=["green"]
) )
intermediate_steps: List[Tuple[AgentAction, str]] = [] intermediate_steps: List[Tuple[AgentAction, str]] = []
# Let's start tracking the iterations the agent has gone through # Let's start tracking the number of iterations and time elapsed
iterations = 0 iterations = 0
time_elapsed = 0.0
start_time = time.time()
# We now enter the agent loop (until it returns something). # We now enter the agent loop (until it returns something).
while self._should_continue(iterations): async with asyncio_timeout(self.max_execution_time):
next_step_output = await self._atake_next_step( try:
name_to_tool_map, color_mapping, inputs, intermediate_steps while self._should_continue(iterations, time_elapsed):
) next_step_output = await self._atake_next_step(
if isinstance(next_step_output, AgentFinish): name_to_tool_map, color_mapping, inputs, intermediate_steps
return await self._areturn(next_step_output, intermediate_steps) )
if isinstance(next_step_output, AgentFinish):
return await self._areturn(next_step_output, intermediate_steps)
intermediate_steps.extend(next_step_output) intermediate_steps.extend(next_step_output)
if len(next_step_output) == 1: if len(next_step_output) == 1:
next_step_action = next_step_output[0] next_step_action = next_step_output[0]
# See if tool should return directly # See if tool should return directly
tool_return = self._get_tool_return(next_step_action) tool_return = self._get_tool_return(next_step_action)
if tool_return is not None: if tool_return is not None:
return await self._areturn(tool_return, intermediate_steps) return await self._areturn(tool_return, intermediate_steps)
iterations += 1 iterations += 1
output = self.agent.return_stopped_response( time_elapsed = time.time() - start_time
self.early_stopping_method, intermediate_steps, **inputs output = self.agent.return_stopped_response(
) self.early_stopping_method, intermediate_steps, **inputs
return await self._areturn(output, intermediate_steps) )
return await self._areturn(output, intermediate_steps)
except TimeoutError:
# stop early when interrupted by the async timeout
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return await self._areturn(output, intermediate_steps)
def _get_tool_return( def _get_tool_return(
self, next_step_output: Tuple[AgentAction, str] self, next_step_output: Tuple[AgentAction, str]

View File

@ -0,0 +1,11 @@
"""Shims for asyncio features that may be missing from older Python versions.

Exports ``asyncio_timeout``, an async context manager that bounds the time
spent inside its ``async with`` body. On Python 3.11+ it is the standard
library's ``asyncio.timeout``; on earlier versions it is ``timeout`` from the
``async_timeout`` backport package.

NOTE: before Python 3.11 ``asyncio.TimeoutError`` is a separate class from the
builtin ``TimeoutError``, so callers that must run on older interpreters
should catch ``asyncio.TimeoutError`` as well. The timeout error is raised
when the ``async with`` block exits, so the handler belongs around the
``async with`` statement rather than inside its body.
"""
import sys

if sys.version_info[:2] < (3, 11):
    # asyncio.timeout() does not exist yet; use the backport package.
    from async_timeout import timeout as asyncio_timeout
else:
    from asyncio import timeout as asyncio_timeout


__all__ = ["asyncio_timeout"]

View File

@ -57,6 +57,7 @@ pgvector = {version = "^0.1.6", optional = true}
psycopg2-binary = {version = "^2.9.5", optional = true} psycopg2-binary = {version = "^2.9.5", optional = true}
boto3 = {version = "^1.26.96", optional = true} boto3 = {version = "^1.26.96", optional = true}
pyowm = {version = "^3.3.0", optional = true} pyowm = {version = "^3.3.0", optional = true}
async-timeout = {version = "^4.0.0", python = "<3.11"}
[tool.poetry.group.docs.dependencies] [tool.poetry.group.docs.dependencies]
autodoc_pydantic = "^1.8.0" autodoc_pydantic = "^1.8.0"

View File

@ -70,10 +70,16 @@ def test_agent_bad_action() -> None:
def test_agent_stopped_early() -> None: def test_agent_stopped_early() -> None:
"""Test react chain when bad action given.""" """Test react chain when max iterations or max execution time is exceeded."""
# iteration limit
agent = _get_agent(max_iterations=0) agent = _get_agent(max_iterations=0)
output = agent.run("when was langchain made") output = agent.run("when was langchain made")
assert output == "Agent stopped due to max iterations." assert output == "Agent stopped due to iteration limit or time limit."
# execution time limit
agent = _get_agent(max_execution_time=0.0)
output = agent.run("when was langchain made")
assert output == "Agent stopped due to iteration limit or time limit."
def test_agent_with_callbacks_global() -> None: def test_agent_with_callbacks_global() -> None: