mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 14:05:16 +00:00
Use a non-inheritable tag
This commit is contained in:
parent
85088dc5df
commit
af2e4ce2cd
@ -1555,6 +1555,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config={**self.config, **(config or {}), **kwargs},
|
config={**self.config, **(config or {}), **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def with_retry(self, retry: BaseRetrying) -> Runnable[Input, Output]:
|
||||||
|
return self.__class__(
|
||||||
|
bound=self.bound.with_retry(retry),
|
||||||
|
kwargs=self.kwargs,
|
||||||
|
config=self.config,
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Input,
|
||||||
|
@ -98,7 +98,6 @@ def patch_config(
|
|||||||
recursion_limit: Optional[int] = None,
|
recursion_limit: Optional[int] = None,
|
||||||
max_concurrency: Optional[int] = None,
|
max_concurrency: Optional[int] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
tags: Optional[List[str]] = None,
|
|
||||||
) -> RunnableConfig:
|
) -> RunnableConfig:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
if deep_copy_locals:
|
if deep_copy_locals:
|
||||||
@ -115,8 +114,6 @@ def patch_config(
|
|||||||
config["max_concurrency"] = max_concurrency
|
config["max_concurrency"] = max_concurrency
|
||||||
if run_name is not None:
|
if run_name is not None:
|
||||||
config["run_name"] = run_name
|
config["run_name"] = run_name
|
||||||
if tags is not None:
|
|
||||||
config["tags"] = tags
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,10 +1,20 @@
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
|
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
|
||||||
|
|
||||||
from langchain.schema.runnable.base import Input, Output, RunnableBinding
|
from langchain.schema.runnable.base import Input, Output, RunnableBinding
|
||||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManager as AsyncCallbackManagerT,
|
||||||
|
CallbackManager as CallbackManagerT,
|
||||||
|
)
|
||||||
|
|
||||||
|
T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT)
|
||||||
|
else:
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class RunnableRetry(RunnableBinding[Input, Output]):
|
class RunnableRetry(RunnableBinding[Input, Output]):
|
||||||
"""Retry a Runnable if it fails."""
|
"""Retry a Runnable if it fails."""
|
||||||
@ -45,32 +55,43 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
|||||||
self,
|
self,
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
retry_state: RetryCallState,
|
retry_state: RetryCallState,
|
||||||
|
cm_cls: Type[T],
|
||||||
) -> RunnableConfig:
|
) -> RunnableConfig:
|
||||||
config = config or {}
|
config = config or {}
|
||||||
original_tags = config.get("tags") or []
|
return (
|
||||||
return patch_config(
|
patch_config(
|
||||||
config,
|
config,
|
||||||
tags=original_tags
|
callbacks=cm_cls.configure(
|
||||||
+ ["retry:attempt:{}".format(retry_state.attempt_number)],
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
local_tags=["retry:attempt:{}".format(retry_state.attempt_number)],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if retry_state.attempt_number > 1
|
||||||
|
else config
|
||||||
)
|
)
|
||||||
|
|
||||||
def _patch_config_list(
|
def _patch_config_list(
|
||||||
self,
|
self,
|
||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
|
||||||
retry_state: RetryCallState,
|
retry_state: RetryCallState,
|
||||||
|
cm_cls: Type[T],
|
||||||
) -> Union[RunnableConfig, List[RunnableConfig]]:
|
) -> Union[RunnableConfig, List[RunnableConfig]]:
|
||||||
if isinstance(config, list):
|
if isinstance(config, list):
|
||||||
return [self._patch_config(c, retry_state) for c in config]
|
return [self._patch_config(c, retry_state, cm_cls) for c in config]
|
||||||
|
|
||||||
return self._patch_config(config, retry_state)
|
return self._patch_config(config, retry_state, cm_cls)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
for attempt in self._sync_retrying():
|
for attempt in self._sync_retrying():
|
||||||
with attempt:
|
with attempt:
|
||||||
result = super().invoke(
|
result = super().invoke(
|
||||||
input, self._patch_config(config, attempt.retry_state), **kwargs
|
input,
|
||||||
|
self._patch_config(config, attempt.retry_state, CallbackManager),
|
||||||
|
**kwargs
|
||||||
)
|
)
|
||||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||||
attempt.retry_state.set_result(result)
|
attempt.retry_state.set_result(result)
|
||||||
@ -79,10 +100,16 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
|||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
|
from langchain.callbacks.manager import AsyncCallbackManager
|
||||||
|
|
||||||
async for attempt in self._async_retrying():
|
async for attempt in self._async_retrying():
|
||||||
with attempt:
|
with attempt:
|
||||||
result = await super().ainvoke(
|
result = await super().ainvoke(
|
||||||
input, self._patch_config(config, attempt.retry_state), **kwargs
|
input,
|
||||||
|
self._patch_config(
|
||||||
|
config, attempt.retry_state, AsyncCallbackManager
|
||||||
|
),
|
||||||
|
**kwargs
|
||||||
)
|
)
|
||||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||||
attempt.retry_state.set_result(result)
|
attempt.retry_state.set_result(result)
|
||||||
@ -94,11 +121,15 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
|||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
for attempt in self._sync_retrying():
|
for attempt in self._sync_retrying():
|
||||||
with attempt:
|
with attempt:
|
||||||
result = super().batch(
|
result = super().batch(
|
||||||
inputs,
|
inputs,
|
||||||
self._patch_config_list(config, attempt.retry_state),
|
self._patch_config_list(
|
||||||
|
config, attempt.retry_state, CallbackManager
|
||||||
|
),
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||||
@ -111,11 +142,15 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
|||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
|
from langchain.callbacks.manager import AsyncCallbackManager
|
||||||
|
|
||||||
async for attempt in self._async_retrying():
|
async for attempt in self._async_retrying():
|
||||||
with attempt:
|
with attempt:
|
||||||
result = await super().abatch(
|
result = await super().abatch(
|
||||||
inputs,
|
inputs,
|
||||||
self._patch_config_list(config, attempt.retry_state),
|
self._patch_config_list(
|
||||||
|
config, attempt.retry_state, AsyncCallbackManager
|
||||||
|
),
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||||
|
Loading…
Reference in New Issue
Block a user