mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
Feature: AgentExecutor execution time limit (#2399)
`AgentExecutor` already has support for limiting the number of iterations. But the amount of time taken for each iteration can vary quite a bit, so it is difficult to place limits on the execution time. This PR adds a new field `max_execution_time` to the `AgentExecutor` model. When called asynchronously, the agent loop is wrapped in an `asyncio.timeout()` context which triggers the early stopping response if the time limit is reached. When called synchronously, the agent loop checks for both the max_iteration limit and the time limit after each iteration. When used asynchronously `max_execution_time` gives really tight control over the max time for an execution chain. When used synchronously, the chain can unfortunately exceed max_execution_time, but it still gives more control than trying to estimate the number of max_iterations needed to cap the execution time. --------- Co-authored-by: Zachary Jones <zjones@zetaglobal.com>
This commit is contained in:
parent
5b34931948
commit
13d1df2140
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||||
@ -26,6 +27,7 @@ from langchain.schema import (
|
|||||||
BaseOutputParser,
|
BaseOutputParser,
|
||||||
)
|
)
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
|
from langchain.utilities.asyncio import asyncio_timeout
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
@ -88,7 +90,9 @@ class BaseSingleActionAgent(BaseModel):
|
|||||||
"""Return response when agent has been stopped due to max iterations."""
|
"""Return response when agent has been stopped due to max iterations."""
|
||||||
if early_stopping_method == "force":
|
if early_stopping_method == "force":
|
||||||
# `force` just returns a constant string
|
# `force` just returns a constant string
|
||||||
return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
|
return AgentFinish(
|
||||||
|
{"output": "Agent stopped due to iteration limit or time limit."}, ""
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Got unsupported early_stopping_method `{early_stopping_method}`"
|
f"Got unsupported early_stopping_method `{early_stopping_method}`"
|
||||||
@ -506,7 +510,9 @@ class Agent(BaseSingleActionAgent):
|
|||||||
"""Return response when agent has been stopped due to max iterations."""
|
"""Return response when agent has been stopped due to max iterations."""
|
||||||
if early_stopping_method == "force":
|
if early_stopping_method == "force":
|
||||||
# `force` just returns a constant string
|
# `force` just returns a constant string
|
||||||
return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
|
return AgentFinish(
|
||||||
|
{"output": "Agent stopped due to iteration limit or time limit."}, ""
|
||||||
|
)
|
||||||
elif early_stopping_method == "generate":
|
elif early_stopping_method == "generate":
|
||||||
# Generate does one final forward pass
|
# Generate does one final forward pass
|
||||||
thoughts = ""
|
thoughts = ""
|
||||||
@ -555,6 +561,7 @@ class AgentExecutor(Chain):
|
|||||||
tools: Sequence[BaseTool]
|
tools: Sequence[BaseTool]
|
||||||
return_intermediate_steps: bool = False
|
return_intermediate_steps: bool = False
|
||||||
max_iterations: Optional[int] = 15
|
max_iterations: Optional[int] = 15
|
||||||
|
max_execution_time: Optional[float] = None
|
||||||
early_stopping_method: str = "force"
|
early_stopping_method: str = "force"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -633,11 +640,16 @@ class AgentExecutor(Chain):
|
|||||||
"""Lookup tool by name."""
|
"""Lookup tool by name."""
|
||||||
return {tool.name: tool for tool in self.tools}[name]
|
return {tool.name: tool for tool in self.tools}[name]
|
||||||
|
|
||||||
def _should_continue(self, iterations: int) -> bool:
|
def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
|
||||||
if self.max_iterations is None:
|
if self.max_iterations is not None and iterations >= self.max_iterations:
|
||||||
return True
|
return False
|
||||||
else:
|
if (
|
||||||
return iterations < self.max_iterations
|
self.max_execution_time is not None
|
||||||
|
and time_elapsed >= self.max_execution_time
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
|
def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
|
||||||
self.callback_manager.on_agent_finish(
|
self.callback_manager.on_agent_finish(
|
||||||
@ -783,10 +795,12 @@ class AgentExecutor(Chain):
|
|||||||
[tool.name for tool in self.tools], excluded_colors=["green"]
|
[tool.name for tool in self.tools], excluded_colors=["green"]
|
||||||
)
|
)
|
||||||
intermediate_steps: List[Tuple[AgentAction, str]] = []
|
intermediate_steps: List[Tuple[AgentAction, str]] = []
|
||||||
# Let's start tracking the iterations the agent has gone through
|
# Let's start tracking the number of iterations and time elapsed
|
||||||
iterations = 0
|
iterations = 0
|
||||||
|
time_elapsed = 0.0
|
||||||
|
start_time = time.time()
|
||||||
# We now enter the agent loop (until it returns something).
|
# We now enter the agent loop (until it returns something).
|
||||||
while self._should_continue(iterations):
|
while self._should_continue(iterations, time_elapsed):
|
||||||
next_step_output = self._take_next_step(
|
next_step_output = self._take_next_step(
|
||||||
name_to_tool_map, color_mapping, inputs, intermediate_steps
|
name_to_tool_map, color_mapping, inputs, intermediate_steps
|
||||||
)
|
)
|
||||||
@ -801,6 +815,7 @@ class AgentExecutor(Chain):
|
|||||||
if tool_return is not None:
|
if tool_return is not None:
|
||||||
return self._return(tool_return, intermediate_steps)
|
return self._return(tool_return, intermediate_steps)
|
||||||
iterations += 1
|
iterations += 1
|
||||||
|
time_elapsed = time.time() - start_time
|
||||||
output = self.agent.return_stopped_response(
|
output = self.agent.return_stopped_response(
|
||||||
self.early_stopping_method, intermediate_steps, **inputs
|
self.early_stopping_method, intermediate_steps, **inputs
|
||||||
)
|
)
|
||||||
@ -815,29 +830,40 @@ class AgentExecutor(Chain):
|
|||||||
[tool.name for tool in self.tools], excluded_colors=["green"]
|
[tool.name for tool in self.tools], excluded_colors=["green"]
|
||||||
)
|
)
|
||||||
intermediate_steps: List[Tuple[AgentAction, str]] = []
|
intermediate_steps: List[Tuple[AgentAction, str]] = []
|
||||||
# Let's start tracking the iterations the agent has gone through
|
# Let's start tracking the number of iterations and time elapsed
|
||||||
iterations = 0
|
iterations = 0
|
||||||
|
time_elapsed = 0.0
|
||||||
|
start_time = time.time()
|
||||||
# We now enter the agent loop (until it returns something).
|
# We now enter the agent loop (until it returns something).
|
||||||
while self._should_continue(iterations):
|
async with asyncio_timeout(self.max_execution_time):
|
||||||
next_step_output = await self._atake_next_step(
|
try:
|
||||||
name_to_tool_map, color_mapping, inputs, intermediate_steps
|
while self._should_continue(iterations, time_elapsed):
|
||||||
)
|
next_step_output = await self._atake_next_step(
|
||||||
if isinstance(next_step_output, AgentFinish):
|
name_to_tool_map, color_mapping, inputs, intermediate_steps
|
||||||
return await self._areturn(next_step_output, intermediate_steps)
|
)
|
||||||
|
if isinstance(next_step_output, AgentFinish):
|
||||||
|
return await self._areturn(next_step_output, intermediate_steps)
|
||||||
|
|
||||||
intermediate_steps.extend(next_step_output)
|
intermediate_steps.extend(next_step_output)
|
||||||
if len(next_step_output) == 1:
|
if len(next_step_output) == 1:
|
||||||
next_step_action = next_step_output[0]
|
next_step_action = next_step_output[0]
|
||||||
# See if tool should return directly
|
# See if tool should return directly
|
||||||
tool_return = self._get_tool_return(next_step_action)
|
tool_return = self._get_tool_return(next_step_action)
|
||||||
if tool_return is not None:
|
if tool_return is not None:
|
||||||
return await self._areturn(tool_return, intermediate_steps)
|
return await self._areturn(tool_return, intermediate_steps)
|
||||||
|
|
||||||
iterations += 1
|
iterations += 1
|
||||||
output = self.agent.return_stopped_response(
|
time_elapsed = time.time() - start_time
|
||||||
self.early_stopping_method, intermediate_steps, **inputs
|
output = self.agent.return_stopped_response(
|
||||||
)
|
self.early_stopping_method, intermediate_steps, **inputs
|
||||||
return await self._areturn(output, intermediate_steps)
|
)
|
||||||
|
return await self._areturn(output, intermediate_steps)
|
||||||
|
except TimeoutError:
|
||||||
|
# stop early when interrupted by the async timeout
|
||||||
|
output = self.agent.return_stopped_response(
|
||||||
|
self.early_stopping_method, intermediate_steps, **inputs
|
||||||
|
)
|
||||||
|
return await self._areturn(output, intermediate_steps)
|
||||||
|
|
||||||
def _get_tool_return(
|
def _get_tool_return(
|
||||||
self, next_step_output: Tuple[AgentAction, str]
|
self, next_step_output: Tuple[AgentAction, str]
|
||||||
|
11
langchain/utilities/asyncio.py
Normal file
11
langchain/utilities/asyncio.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
"""Shims for asyncio features that may be missing from older python versions"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if sys.version_info[:2] < (3, 11):
|
||||||
|
from async_timeout import timeout as asyncio_timeout
|
||||||
|
else:
|
||||||
|
from asyncio import timeout as asyncio_timeout
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["asyncio_timeout"]
|
@ -57,6 +57,7 @@ pgvector = {version = "^0.1.6", optional = true}
|
|||||||
psycopg2-binary = {version = "^2.9.5", optional = true}
|
psycopg2-binary = {version = "^2.9.5", optional = true}
|
||||||
boto3 = {version = "^1.26.96", optional = true}
|
boto3 = {version = "^1.26.96", optional = true}
|
||||||
pyowm = {version = "^3.3.0", optional = true}
|
pyowm = {version = "^3.3.0", optional = true}
|
||||||
|
async-timeout = {version = "^4.0.0", python = "<3.11"}
|
||||||
|
|
||||||
[tool.poetry.group.docs.dependencies]
|
[tool.poetry.group.docs.dependencies]
|
||||||
autodoc_pydantic = "^1.8.0"
|
autodoc_pydantic = "^1.8.0"
|
||||||
|
@ -70,10 +70,16 @@ def test_agent_bad_action() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_agent_stopped_early() -> None:
|
def test_agent_stopped_early() -> None:
|
||||||
"""Test react chain when bad action given."""
|
"""Test react chain when max iterations or max execution time is exceeded."""
|
||||||
|
# iteration limit
|
||||||
agent = _get_agent(max_iterations=0)
|
agent = _get_agent(max_iterations=0)
|
||||||
output = agent.run("when was langchain made")
|
output = agent.run("when was langchain made")
|
||||||
assert output == "Agent stopped due to max iterations."
|
assert output == "Agent stopped due to iteration limit or time limit."
|
||||||
|
|
||||||
|
# execution time limit
|
||||||
|
agent = _get_agent(max_execution_time=0.0)
|
||||||
|
output = agent.run("when was langchain made")
|
||||||
|
assert output == "Agent stopped due to iteration limit or time limit."
|
||||||
|
|
||||||
|
|
||||||
def test_agent_with_callbacks_global() -> None:
|
def test_agent_with_callbacks_global() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user