mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-30 11:39:03 +00:00
Add .with_retry() to Runnables
This commit is contained in:
parent
50a5c5bcf8
commit
b2ac835466
@ -27,6 +27,8 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from tenacity import BaseRetrying
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
@ -226,6 +228,14 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
bound=self, config={**(config or {}), **kwargs}, kwargs={}
|
||||
)
|
||||
|
||||
def with_retry(
|
||||
self,
|
||||
retry: BaseRetrying,
|
||||
) -> Runnable[Input, Output]:
|
||||
from langchain.schema.runnable.retry import RunnableRetry
|
||||
|
||||
return RunnableRetry(bound=self, retry=retry, kwargs={}, config={})
|
||||
|
||||
def map(self) -> Runnable[List[Input], List[Output]]:
|
||||
"""
|
||||
Return a new Runnable that maps a list of inputs to a list of outputs,
|
||||
|
@ -98,6 +98,7 @@ def patch_config(
|
||||
recursion_limit: Optional[int] = None,
|
||||
max_concurrency: Optional[int] = None,
|
||||
run_name: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> RunnableConfig:
|
||||
config = ensure_config(config)
|
||||
if deep_copy_locals:
|
||||
@ -114,6 +115,8 @@ def patch_config(
|
||||
config["max_concurrency"] = max_concurrency
|
||||
if run_name is not None:
|
||||
config["run_name"] = run_name
|
||||
if tags is not None:
|
||||
config["tags"] = tags
|
||||
return config
|
||||
|
||||
|
||||
|
111
libs/langchain/langchain/schema/runnable/retry.py
Normal file
111
libs/langchain/langchain/schema/runnable/retry.py
Normal file
@ -0,0 +1,111 @@
|
||||
from typing import Any, List, Optional, Union
|
||||
from langchain.schema.runnable.base import Input, Output, Runnable, RunnableBinding
|
||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
|
||||
|
||||
|
||||
class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
"""Retry a Runnable if it fails."""
|
||||
|
||||
retry: BaseRetrying
|
||||
|
||||
def _sync_retrying(self) -> Retrying:
|
||||
return Retrying(
|
||||
sleep=self.retry.sleep,
|
||||
stop=self.retry.stop,
|
||||
wait=self.retry.wait,
|
||||
retry=self.retry.retry,
|
||||
before=self.retry.before,
|
||||
after=self.retry.after,
|
||||
before_sleep=self.retry.before_sleep,
|
||||
reraise=self.retry.reraise,
|
||||
retry_error_cls=self.retry.retry_error_cls,
|
||||
retry_error_callback=self.retry.retry_error_callback,
|
||||
)
|
||||
|
||||
def _async_retrying(self) -> AsyncRetrying:
|
||||
return AsyncRetrying(
|
||||
sleep=self.retry.sleep,
|
||||
stop=self.retry.stop,
|
||||
wait=self.retry.wait,
|
||||
retry=self.retry.retry,
|
||||
before=self.retry.before,
|
||||
after=self.retry.after,
|
||||
before_sleep=self.retry.before_sleep,
|
||||
reraise=self.retry.reraise,
|
||||
retry_error_cls=self.retry.retry_error_cls,
|
||||
retry_error_callback=self.retry.retry_error_callback,
|
||||
)
|
||||
|
||||
def _patch_config(
|
||||
self,
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
|
||||
retry_state: RetryCallState,
|
||||
) -> RunnableConfig:
|
||||
if isinstance(config, list):
|
||||
return [self._patch_config(c, retry_state) for c in config]
|
||||
|
||||
config = config or {}
|
||||
original_tags = config.get("tags") or []
|
||||
return patch_config(
|
||||
config,
|
||||
tags=original_tags
|
||||
+ ["retry:attempt:{}".format(retry_state.attempt_number)],
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any | None
|
||||
) -> Output:
|
||||
for attempt in self._sync_retrying():
|
||||
with attempt:
|
||||
result = super().invoke(
|
||||
input, self._patch_config(config, attempt.retry_state), **kwargs
|
||||
)
|
||||
if not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: RunnableConfig | None = None, **kwargs: Any | None
|
||||
) -> Output:
|
||||
async for attempt in self._async_retrying():
|
||||
with attempt:
|
||||
result = await super().ainvoke(
|
||||
input, self._patch_config(config, attempt.retry_state), **kwargs
|
||||
)
|
||||
if not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
**kwargs: Any
|
||||
) -> List[Output]:
|
||||
for attempt in self._sync_retrying():
|
||||
with attempt:
|
||||
result = super().batch(
|
||||
inputs, self._patch_config(config, attempt.retry_state), **kwargs
|
||||
)
|
||||
if not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
**kwargs: Any
|
||||
) -> List[Output]:
|
||||
async for attempt in self._async_retrying():
|
||||
with attempt:
|
||||
result = await super().abatch(
|
||||
inputs, self._patch_config(config, attempt.retry_state), **kwargs
|
||||
)
|
||||
if not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
@ -41,6 +41,7 @@ from langchain.schema.runnable import (
|
||||
RunnableSequence,
|
||||
RunnableWithFallbacks,
|
||||
)
|
||||
from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt
|
||||
|
||||
|
||||
class FakeTracer(BaseTracer):
|
||||
@ -141,7 +142,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
else:
|
||||
assert call.args[2].get("tags") == ["b-tag"]
|
||||
assert call.args[2].get("max_concurrency") == 5
|
||||
spy_seq_step.reset_mock()
|
||||
mocker.stop(spy_seq_step)
|
||||
|
||||
assert [
|
||||
*fake.with_config(tags=["a-tag"]).stream(
|
||||
@ -1423,3 +1424,42 @@ def test_recursive_lambda() -> None:
|
||||
|
||||
with pytest.raises(RecursionError):
|
||||
runnable.invoke(0, {"recursion_limit": 9})
|
||||
|
||||
|
||||
def test_retrying(mocker: MockerFixture) -> None:
|
||||
def _lambda(x: int) -> Union[int, Runnable]:
|
||||
if x == 1:
|
||||
raise ValueError("x is 1")
|
||||
elif x == 2:
|
||||
raise RuntimeError("x is 2")
|
||||
else:
|
||||
return x
|
||||
|
||||
_lambda_mock = mocker.Mock(side_effect=_lambda)
|
||||
runnable = RunnableLambda(_lambda_mock)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runnable.invoke(1)
|
||||
|
||||
assert _lambda_mock.call_count == 1
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
with pytest.raises(RetryError):
|
||||
runnable.with_retry(
|
||||
Retrying(
|
||||
stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,))
|
||||
)
|
||||
).invoke(1)
|
||||
|
||||
assert _lambda_mock.call_count == 2
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
runnable.with_retry(
|
||||
Retrying(
|
||||
stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,))
|
||||
)
|
||||
).invoke(2)
|
||||
|
||||
assert _lambda_mock.call_count == 1
|
||||
_lambda_mock.reset_mock()
|
||||
|
Loading…
Reference in New Issue
Block a user