Use a non-inheritable tag

This commit is contained in:
Nuno Campos 2023-08-25 17:14:40 +02:00
parent 85088dc5df
commit af2e4ce2cd
3 changed files with 54 additions and 15 deletions

View File

@ -1555,6 +1555,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config={**self.config, **(config or {}), **kwargs}, config={**self.config, **(config or {}), **kwargs},
) )
def with_retry(self, retry: BaseRetrying) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound.with_retry(retry),
kwargs=self.kwargs,
config=self.config,
)
def invoke( def invoke(
self, self,
input: Input, input: Input,

View File

@ -98,7 +98,6 @@ def patch_config(
recursion_limit: Optional[int] = None, recursion_limit: Optional[int] = None,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
tags: Optional[List[str]] = None,
) -> RunnableConfig: ) -> RunnableConfig:
config = ensure_config(config) config = ensure_config(config)
if deep_copy_locals: if deep_copy_locals:
@ -115,8 +114,6 @@ def patch_config(
config["max_concurrency"] = max_concurrency config["max_concurrency"] = max_concurrency
if run_name is not None: if run_name is not None:
config["run_name"] = run_name config["run_name"] = run_name
if tags is not None:
config["tags"] = tags
return config return config

View File

@ -1,10 +1,20 @@
from typing import Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
from langchain.schema.runnable.base import Input, Output, RunnableBinding from langchain.schema.runnable.base import Input, Output, RunnableBinding
from langchain.schema.runnable.config import RunnableConfig, patch_config from langchain.schema.runnable.config import RunnableConfig, patch_config
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManager as AsyncCallbackManagerT,
CallbackManager as CallbackManagerT,
)
T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT)
else:
T = TypeVar("T")
class RunnableRetry(RunnableBinding[Input, Output]): class RunnableRetry(RunnableBinding[Input, Output]):
"""Retry a Runnable if it fails.""" """Retry a Runnable if it fails."""
@ -45,32 +55,43 @@ class RunnableRetry(RunnableBinding[Input, Output]):
self, self,
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
retry_state: RetryCallState, retry_state: RetryCallState,
cm_cls: Type[T],
) -> RunnableConfig: ) -> RunnableConfig:
config = config or {} config = config or {}
original_tags = config.get("tags") or [] return (
return patch_config( patch_config(
config, config,
tags=original_tags callbacks=cm_cls.configure(
+ ["retry:attempt:{}".format(retry_state.attempt_number)], inheritable_callbacks=config.get("callbacks"),
local_tags=["retry:attempt:{}".format(retry_state.attempt_number)],
),
)
if retry_state.attempt_number > 1
else config
) )
def _patch_config_list( def _patch_config_list(
self, self,
config: Optional[Union[RunnableConfig, List[RunnableConfig]]], config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
retry_state: RetryCallState, retry_state: RetryCallState,
cm_cls: Type[T],
) -> Union[RunnableConfig, List[RunnableConfig]]: ) -> Union[RunnableConfig, List[RunnableConfig]]:
if isinstance(config, list): if isinstance(config, list):
return [self._patch_config(c, retry_state) for c in config] return [self._patch_config(c, retry_state, cm_cls) for c in config]
return self._patch_config(config, retry_state) return self._patch_config(config, retry_state, cm_cls)
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
from langchain.callbacks.manager import CallbackManager
for attempt in self._sync_retrying(): for attempt in self._sync_retrying():
with attempt: with attempt:
result = super().invoke( result = super().invoke(
input, self._patch_config(config, attempt.retry_state), **kwargs input,
self._patch_config(config, attempt.retry_state, CallbackManager),
**kwargs
) )
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result) attempt.retry_state.set_result(result)
@ -79,10 +100,16 @@ class RunnableRetry(RunnableBinding[Input, Output]):
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
async for attempt in self._async_retrying(): async for attempt in self._async_retrying():
with attempt: with attempt:
result = await super().ainvoke( result = await super().ainvoke(
input, self._patch_config(config, attempt.retry_state), **kwargs input,
self._patch_config(
config, attempt.retry_state, AsyncCallbackManager
),
**kwargs
) )
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result) attempt.retry_state.set_result(result)
@ -94,11 +121,15 @@ class RunnableRetry(RunnableBinding[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Any **kwargs: Any
) -> List[Output]: ) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
for attempt in self._sync_retrying(): for attempt in self._sync_retrying():
with attempt: with attempt:
result = super().batch( result = super().batch(
inputs, inputs,
self._patch_config_list(config, attempt.retry_state), self._patch_config_list(
config, attempt.retry_state, CallbackManager
),
**kwargs **kwargs
) )
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
@ -111,11 +142,15 @@ class RunnableRetry(RunnableBinding[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Any **kwargs: Any
) -> List[Output]: ) -> List[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
async for attempt in self._async_retrying(): async for attempt in self._async_retrying():
with attempt: with attempt:
result = await super().abatch( result = await super().abatch(
inputs, inputs,
self._patch_config_list(config, attempt.retry_state), self._patch_config_list(
config, attempt.retry_state, AsyncCallbackManager
),
**kwargs **kwargs
) )
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: