mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-02 13:08:57 +00:00
Feature: AgentExecutor execution time limit (#2399)
`AgentExecutor` already has support for limiting the number of iterations. But the amount of time taken for each iteration can vary quite a bit, so it is difficult to place limits on the execution time. This PR adds a new field `max_execution_time` to the `AgentExecutor` model. When called asynchronously, the agent loop is wrapped in an `asyncio.timeout()` context which triggers the early stopping response if the time limit is reached. When called synchronously, the agent loop checks for both the max_iteration limit and the time limit after each iteration. When used asynchronously `max_execution_time` gives really tight control over the max time for an execution chain. When used synchronously, the chain can unfortunately exceed max_execution_time, but it still gives more control than trying to estimate the number of max_iterations needed to cap the execution time. --------- Co-authored-by: Zachary Jones <zjones@zetaglobal.com>
This commit is contained in:
parent
5b34931948
commit
13d1df2140
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
@ -26,6 +27,7 @@ from langchain.schema import (
|
||||
BaseOutputParser,
|
||||
)
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities.asyncio import asyncio_timeout
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
@ -88,7 +90,9 @@ class BaseSingleActionAgent(BaseModel):
|
||||
"""Return response when agent has been stopped due to max iterations."""
|
||||
if early_stopping_method == "force":
|
||||
# `force` just returns a constant string
|
||||
return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
|
||||
return AgentFinish(
|
||||
{"output": "Agent stopped due to iteration limit or time limit."}, ""
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unsupported early_stopping_method `{early_stopping_method}`"
|
||||
@ -506,7 +510,9 @@ class Agent(BaseSingleActionAgent):
|
||||
"""Return response when agent has been stopped due to max iterations."""
|
||||
if early_stopping_method == "force":
|
||||
# `force` just returns a constant string
|
||||
return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
|
||||
return AgentFinish(
|
||||
{"output": "Agent stopped due to iteration limit or time limit."}, ""
|
||||
)
|
||||
elif early_stopping_method == "generate":
|
||||
# Generate does one final forward pass
|
||||
thoughts = ""
|
||||
@ -555,6 +561,7 @@ class AgentExecutor(Chain):
|
||||
tools: Sequence[BaseTool]
|
||||
return_intermediate_steps: bool = False
|
||||
max_iterations: Optional[int] = 15
|
||||
max_execution_time: Optional[float] = None
|
||||
early_stopping_method: str = "force"
|
||||
|
||||
@classmethod
|
||||
@ -633,11 +640,16 @@ class AgentExecutor(Chain):
|
||||
"""Lookup tool by name."""
|
||||
return {tool.name: tool for tool in self.tools}[name]
|
||||
|
||||
def _should_continue(self, iterations: int) -> bool:
|
||||
if self.max_iterations is None:
|
||||
return True
|
||||
else:
|
||||
return iterations < self.max_iterations
|
||||
def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
|
||||
if self.max_iterations is not None and iterations >= self.max_iterations:
|
||||
return False
|
||||
if (
|
||||
self.max_execution_time is not None
|
||||
and time_elapsed >= self.max_execution_time
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
|
||||
self.callback_manager.on_agent_finish(
|
||||
@ -783,10 +795,12 @@ class AgentExecutor(Chain):
|
||||
[tool.name for tool in self.tools], excluded_colors=["green"]
|
||||
)
|
||||
intermediate_steps: List[Tuple[AgentAction, str]] = []
|
||||
# Let's start tracking the iterations the agent has gone through
|
||||
# Let's start tracking the number of iterations and time elapsed
|
||||
iterations = 0
|
||||
time_elapsed = 0.0
|
||||
start_time = time.time()
|
||||
# We now enter the agent loop (until it returns something).
|
||||
while self._should_continue(iterations):
|
||||
while self._should_continue(iterations, time_elapsed):
|
||||
next_step_output = self._take_next_step(
|
||||
name_to_tool_map, color_mapping, inputs, intermediate_steps
|
||||
)
|
||||
@ -801,6 +815,7 @@ class AgentExecutor(Chain):
|
||||
if tool_return is not None:
|
||||
return self._return(tool_return, intermediate_steps)
|
||||
iterations += 1
|
||||
time_elapsed = time.time() - start_time
|
||||
output = self.agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
@ -815,29 +830,40 @@ class AgentExecutor(Chain):
|
||||
[tool.name for tool in self.tools], excluded_colors=["green"]
|
||||
)
|
||||
intermediate_steps: List[Tuple[AgentAction, str]] = []
|
||||
# Let's start tracking the iterations the agent has gone through
|
||||
# Let's start tracking the number of iterations and time elapsed
|
||||
iterations = 0
|
||||
time_elapsed = 0.0
|
||||
start_time = time.time()
|
||||
# We now enter the agent loop (until it returns something).
|
||||
while self._should_continue(iterations):
|
||||
next_step_output = await self._atake_next_step(
|
||||
name_to_tool_map, color_mapping, inputs, intermediate_steps
|
||||
)
|
||||
if isinstance(next_step_output, AgentFinish):
|
||||
return await self._areturn(next_step_output, intermediate_steps)
|
||||
async with asyncio_timeout(self.max_execution_time):
|
||||
try:
|
||||
while self._should_continue(iterations, time_elapsed):
|
||||
next_step_output = await self._atake_next_step(
|
||||
name_to_tool_map, color_mapping, inputs, intermediate_steps
|
||||
)
|
||||
if isinstance(next_step_output, AgentFinish):
|
||||
return await self._areturn(next_step_output, intermediate_steps)
|
||||
|
||||
intermediate_steps.extend(next_step_output)
|
||||
if len(next_step_output) == 1:
|
||||
next_step_action = next_step_output[0]
|
||||
# See if tool should return directly
|
||||
tool_return = self._get_tool_return(next_step_action)
|
||||
if tool_return is not None:
|
||||
return await self._areturn(tool_return, intermediate_steps)
|
||||
intermediate_steps.extend(next_step_output)
|
||||
if len(next_step_output) == 1:
|
||||
next_step_action = next_step_output[0]
|
||||
# See if tool should return directly
|
||||
tool_return = self._get_tool_return(next_step_action)
|
||||
if tool_return is not None:
|
||||
return await self._areturn(tool_return, intermediate_steps)
|
||||
|
||||
iterations += 1
|
||||
output = self.agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
return await self._areturn(output, intermediate_steps)
|
||||
iterations += 1
|
||||
time_elapsed = time.time() - start_time
|
||||
output = self.agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
return await self._areturn(output, intermediate_steps)
|
||||
except TimeoutError:
|
||||
# stop early when interrupted by the async timeout
|
||||
output = self.agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
return await self._areturn(output, intermediate_steps)
|
||||
|
||||
def _get_tool_return(
|
||||
self, next_step_output: Tuple[AgentAction, str]
|
||||
|
11
langchain/utilities/asyncio.py
Normal file
11
langchain/utilities/asyncio.py
Normal file
@ -0,0 +1,11 @@
|
||||
"""Shims for asyncio features that may be missing from older python versions"""
|
||||
|
||||
import sys
|
||||
|
||||
if sys.version_info[:2] < (3, 11):
|
||||
from async_timeout import timeout as asyncio_timeout
|
||||
else:
|
||||
from asyncio import timeout as asyncio_timeout
|
||||
|
||||
|
||||
__all__ = ["asyncio_timeout"]
|
@ -57,6 +57,7 @@ pgvector = {version = "^0.1.6", optional = true}
|
||||
psycopg2-binary = {version = "^2.9.5", optional = true}
|
||||
boto3 = {version = "^1.26.96", optional = true}
|
||||
pyowm = {version = "^3.3.0", optional = true}
|
||||
async-timeout = {version = "^4.0.0", python = "<3.11"}
|
||||
|
||||
[tool.poetry.group.docs.dependencies]
|
||||
autodoc_pydantic = "^1.8.0"
|
||||
|
@ -70,10 +70,16 @@ def test_agent_bad_action() -> None:
|
||||
|
||||
|
||||
def test_agent_stopped_early() -> None:
|
||||
"""Test react chain when bad action given."""
|
||||
"""Test react chain when max iterations or max execution time is exceeded."""
|
||||
# iteration limit
|
||||
agent = _get_agent(max_iterations=0)
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "Agent stopped due to max iterations."
|
||||
assert output == "Agent stopped due to iteration limit or time limit."
|
||||
|
||||
# execution time limit
|
||||
agent = _get_agent(max_execution_time=0.0)
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "Agent stopped due to iteration limit or time limit."
|
||||
|
||||
|
||||
def test_agent_with_callbacks_global() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user