From af2e4ce2cd063c1e8fc485a978dbd2e1ecbfe0c9 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 25 Aug 2023 17:14:40 +0200 Subject: [PATCH] Use a non-inheritable tag --- .../langchain/schema/runnable/base.py | 7 +++ .../langchain/schema/runnable/config.py | 3 - .../langchain/schema/runnable/retry.py | 59 +++++++++++++++---- 3 files changed, 54 insertions(+), 15 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 7b180628465..ad5c8cfe847 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1555,6 +1555,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config={**self.config, **(config or {}), **kwargs}, ) + def with_retry(self, retry: BaseRetrying) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound.with_retry(retry), + kwargs=self.kwargs, + config=self.config, + ) + def invoke( self, input: Input, diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 5752b09bf21..3f87f044039 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -98,7 +98,6 @@ def patch_config( recursion_limit: Optional[int] = None, max_concurrency: Optional[int] = None, run_name: Optional[str] = None, - tags: Optional[List[str]] = None, ) -> RunnableConfig: config = ensure_config(config) if deep_copy_locals: @@ -115,8 +114,6 @@ def patch_config( config["max_concurrency"] = max_concurrency if run_name is not None: config["run_name"] = run_name - if tags is not None: - config["tags"] = tags return config diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index c53621904fd..b746d874286 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -1,10 +1,20 @@ -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying from langchain.schema.runnable.base import Input, Output, RunnableBinding from langchain.schema.runnable.config import RunnableConfig, patch_config +if TYPE_CHECKING: + from langchain.callbacks.manager import ( + AsyncCallbackManager as AsyncCallbackManagerT, + CallbackManager as CallbackManagerT, + ) + + T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT) +else: + T = TypeVar("T") + class RunnableRetry(RunnableBinding[Input, Output]): """Retry a Runnable if it fails.""" @@ -45,32 +55,43 @@ class RunnableRetry(RunnableBinding[Input, Output]): self, config: Optional[RunnableConfig], retry_state: RetryCallState, + cm_cls: Type[T], ) -> RunnableConfig: config = config or {} - original_tags = config.get("tags") or [] - return patch_config( - config, - tags=original_tags - + ["retry:attempt:{}".format(retry_state.attempt_number)], + return ( + patch_config( + config, + callbacks=cm_cls.configure( + inheritable_callbacks=config.get("callbacks"), + local_tags=["retry:attempt:{}".format(retry_state.attempt_number)], + ), + ) + if retry_state.attempt_number > 1 + else config ) def _patch_config_list( self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], retry_state: RetryCallState, + cm_cls: Type[T], ) -> Union[RunnableConfig, List[RunnableConfig]]: if isinstance(config, list): - return [self._patch_config(c, retry_state) for c in config] + return [self._patch_config(c, retry_state, cm_cls) for c in config] - return self._patch_config(config, retry_state) + return self._patch_config(config, retry_state, cm_cls) def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: + from langchain.callbacks.manager import CallbackManager + for attempt in self._sync_retrying(): with attempt: result = super().invoke( - input, self._patch_config(config, attempt.retry_state), **kwargs + input, + self._patch_config(config, attempt.retry_state, CallbackManager), + **kwargs ) if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) @@ -79,10 +100,16 @@ class RunnableRetry(RunnableBinding[Input, Output]): async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: + from langchain.callbacks.manager import AsyncCallbackManager + async for attempt in self._async_retrying(): with attempt: result = await super().ainvoke( - input, self._patch_config(config, attempt.retry_state), **kwargs + input, + self._patch_config( + config, attempt.retry_state, AsyncCallbackManager + ), + **kwargs ) if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) @@ -94,11 +121,15 @@ class RunnableRetry(RunnableBinding[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Any ) -> List[Output]: + from langchain.callbacks.manager import CallbackManager + for attempt in self._sync_retrying(): with attempt: result = super().batch( inputs, - self._patch_config_list(config, attempt.retry_state), + self._patch_config_list( + config, attempt.retry_state, CallbackManager + ), **kwargs ) if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: @@ -111,11 +142,15 @@ class RunnableRetry(RunnableBinding[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Any ) -> List[Output]: + from langchain.callbacks.manager import AsyncCallbackManager + async for attempt in self._async_retrying(): with attempt: result = await super().abatch( inputs, - self._patch_config_list(config, attempt.retry_state), + self._patch_config_list( + config, attempt.retry_state, AsyncCallbackManager + ), **kwargs ) if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: