Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-06 21:20:33 +00:00)
Nc/runnables retry (#9711)
This commit is contained in: commit b1c87da2b0
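At a glance, the diff below does two things: it adds a with_retry() helper on Runnable, which wraps the runnable in the new RunnableRetry binding (see retry.py further down), and it adds a return_exceptions keyword to batch()/abatch(), which returns exceptions in place of outputs instead of raising on the first failure. A minimal sketch of both, assuming RunnableLambda can be imported from langchain.schema.runnable; the flaky function is illustrative only:

from langchain.schema.runnable import RunnableLambda  # assumed import path

attempts = {"n": 0}

def flaky(x: int) -> int:
    # Fails on the first call only (illustrative).
    attempts["n"] += 1
    if attempts["n"] == 1:
        raise ValueError("transient failure")
    return x * 2

retrying = RunnableLambda(flaky).with_retry(
    stop_after_attempt=2,
    retry_if_exception_type=(ValueError,),
)
print(retrying.invoke(3))  # 6, after one retried attempt

# return_exceptions=True returns exceptions in place of outputs instead of raising.
results = RunnableLambda(lambda x: 1 // x).batch([1, 0, 2], return_exceptions=True)
# results == [1, ZeroDivisionError(...), 0]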
@@ -263,20 +263,28 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         self,
         inputs: List[LanguageModelInput],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        max_concurrency: Optional[int] = None,
+        *,
+        return_exceptions: bool = False,
         **kwargs: Any,
     ) -> List[str]:
         config = get_config_list(config, len(inputs))
+        max_concurrency = config[0].get("max_concurrency")

         if max_concurrency is None:
-            llm_result = self.generate_prompt(
-                [self._convert_input(input) for input in inputs],
-                callbacks=[c.get("callbacks") for c in config],
-                tags=[c.get("tags") for c in config],
-                metadata=[c.get("metadata") for c in config],
-                **kwargs,
-            )
-            return [g[0].text for g in llm_result.generations]
+            try:
+                llm_result = self.generate_prompt(
+                    [self._convert_input(input) for input in inputs],
+                    callbacks=[c.get("callbacks") for c in config],
+                    tags=[c.get("tags") for c in config],
+                    metadata=[c.get("metadata") for c in config],
+                    **kwargs,
+                )
+                return [g[0].text for g in llm_result.generations]
+            except Exception as e:
+                if return_exceptions:
+                    return cast(List[str], [e for _ in inputs])
+                else:
+                    raise e
         else:
             batches = [
                 inputs[i : i + max_concurrency]
@@ -285,33 +293,43 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             return [
                 output
                 for batch in batches
-                for output in self.batch(batch, config=config, **kwargs)
+                for output in self.batch(
+                    batch, config=config, return_exceptions=return_exceptions, **kwargs
+                )
             ]

     async def abatch(
         self,
         inputs: List[LanguageModelInput],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
-        max_concurrency: Optional[int] = None,
+        *,
+        return_exceptions: bool = False,
         **kwargs: Any,
     ) -> List[str]:
         if type(self)._agenerate == BaseLLM._agenerate:
             # model doesn't implement async batch, so use default implementation
             return await asyncio.get_running_loop().run_in_executor(
-                None, self.batch, inputs, config, max_concurrency
+                None, partial(self.batch, **kwargs), inputs, config
             )

         config = get_config_list(config, len(inputs))
+        max_concurrency = config[0].get("max_concurrency")

         if max_concurrency is None:
-            llm_result = await self.agenerate_prompt(
-                [self._convert_input(input) for input in inputs],
-                callbacks=[c.get("callbacks") for c in config],
-                tags=[c.get("tags") for c in config],
-                metadata=[c.get("metadata") for c in config],
-                **kwargs,
-            )
-            return [g[0].text for g in llm_result.generations]
+            try:
+                llm_result = await self.agenerate_prompt(
+                    [self._convert_input(input) for input in inputs],
+                    callbacks=[c.get("callbacks") for c in config],
+                    tags=[c.get("tags") for c in config],
+                    metadata=[c.get("metadata") for c in config],
+                    **kwargs,
+                )
+                return [g[0].text for g in llm_result.generations]
+            except Exception as e:
+                if return_exceptions:
+                    return cast(List[str], [e for _ in inputs])
+                else:
+                    raise e
         else:
             batches = [
                 inputs[i : i + max_concurrency]
@@ -105,6 +105,8 @@ class Runnable(Generic[Input, Output], ABC):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         """
@@ -113,17 +115,28 @@ class Runnable(Generic[Input, Output], ABC):
         """
         configs = get_config_list(config, len(inputs))

+        def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
+            if return_exceptions:
+                try:
+                    return self.invoke(input, config, **kwargs)
+                except Exception as e:
+                    return e
+            else:
+                return self.invoke(input, config, **kwargs)
+
         # If there's only one input, don't bother with the executor
         if len(inputs) == 1:
-            return [self.invoke(inputs[0], configs[0], **kwargs)]
+            return cast(List[Output], [invoke(inputs[0], configs[0])])

         with get_executor_for_config(configs[0]) as executor:
-            return list(executor.map(partial(self.invoke, **kwargs), inputs, configs))
+            return cast(List[Output], list(executor.map(invoke, inputs, configs)))

     async def abatch(
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         """
@@ -131,8 +144,19 @@ class Runnable(Generic[Input, Output], ABC):
         Subclasses should override this method if they can batch more efficiently.
         """
         configs = get_config_list(config, len(inputs))
-        coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
+
+        async def ainvoke(
+            input: Input, config: RunnableConfig
+        ) -> Union[Output, Exception]:
+            if return_exceptions:
+                try:
+                    return await self.ainvoke(input, config, **kwargs)
+                except Exception as e:
+                    return e
+            else:
+                return await self.ainvoke(input, config, **kwargs)

+        coros = map(ainvoke, inputs, configs)
         return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)

     def stream(
@@ -226,6 +250,24 @@ class Runnable(Generic[Input, Output], ABC):
             bound=self, config={**(config or {}), **kwargs}, kwargs={}
         )

+    def with_retry(
+        self,
+        *,
+        retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,),
+        wait_exponential_jitter: bool = True,
+        stop_after_attempt: int = 3,
+    ) -> Runnable[Input, Output]:
+        from langchain.schema.runnable.retry import RunnableRetry
+
+        return RunnableRetry(
+            bound=self,
+            kwargs={},
+            config={},
+            retry_exception_types=retry_if_exception_type,
+            wait_exponential_jitter=wait_exponential_jitter,
+            max_attempt_number=stop_after_attempt,
+        )
+
     def map(self) -> Runnable[List[Input], List[Output]]:
         """
         Return a new Runnable that maps a list of inputs to a list of outputs,
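For reference, the keyword arguments of with_retry() map one-to-one onto the fields of the RunnableRetry class added in retry.py further down; a short sketch of that equivalence, with a placeholder runnable (the import paths follow the diff):

from langchain.schema.runnable import RunnableLambda  # assumed public import path
from langchain.schema.runnable.retry import RunnableRetry

base = RunnableLambda(lambda x: x + 1)

via_helper = base.with_retry(
    retry_if_exception_type=(ValueError,),
    wait_exponential_jitter=True,
    stop_after_attempt=3,
)
# ...is the same as constructing the binding directly:
direct = RunnableRetry(
    bound=base,
    kwargs={},
    config={},
    retry_exception_types=(ValueError,),
    wait_exponential_jitter=True,
    max_attempt_number=3,
)
assert isinstance(via_helper, RunnableRetry)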
@@ -331,6 +373,145 @@ class Runnable(Generic[Input, Output], ABC):
         await run_manager.on_chain_end(dumpd(output))
         return output

+    def _batch_with_config(
+        self,
+        func: Union[
+            Callable[[List[Input]], List[Union[Exception, Output]]],
+            Callable[
+                [List[Input], List[CallbackManagerForChainRun]],
+                List[Union[Exception, Output]],
+            ],
+            Callable[
+                [List[Input], List[CallbackManagerForChainRun], List[RunnableConfig]],
+                List[Union[Exception, Output]],
+            ],
+        ],
+        input: List[Input],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        run_type: Optional[str] = None,
+    ) -> List[Output]:
+        """Helper method to transform an Input value to an Output value,
+        with callbacks. Use this method to implement invoke() in subclasses."""
+        configs = get_config_list(config, len(input))
+        callback_managers = [get_callback_manager_for_config(c) for c in configs]
+        run_managers = [
+            callback_manager.on_chain_start(
+                dumpd(self),
+                input,
+                run_type=run_type,
+                name=config.get("run_name"),
+            )
+            for callback_manager, input, config in zip(
+                callback_managers, input, configs
+            )
+        ]
+        try:
+            if accepts_run_manager_and_config(func):
+                output = func(
+                    input,
+                    run_manager=run_managers,
+                    config=configs,
+                )  # type: ignore[call-arg]
+            elif accepts_run_manager(func):
+                output = func(input, run_manager=run_managers)  # type: ignore[call-arg]
+            else:
+                output = func(input)  # type: ignore[call-arg]
+        except Exception as e:
+            for run_manager in run_managers:
+                run_manager.on_chain_error(e)
+            if return_exceptions:
+                return cast(List[Output], [e for _ in input])
+            else:
+                raise
+        else:
+            first_exception: Optional[Exception] = None
+            for run_manager, out in zip(run_managers, output):
+                if isinstance(out, Exception):
+                    first_exception = first_exception or out
+                    run_manager.on_chain_error(out)
+                else:
+                    run_manager.on_chain_end(dumpd(out))
+            if return_exceptions or first_exception is None:
+                return cast(List[Output], output)
+            else:
+                raise first_exception
+
+    async def _abatch_with_config(
+        self,
+        func: Union[
+            Callable[[List[Input]], Awaitable[List[Union[Exception, Output]]]],
+            Callable[
+                [List[Input], List[AsyncCallbackManagerForChainRun]],
+                Awaitable[List[Union[Exception, Output]]],
+            ],
+            Callable[
+                [
+                    List[Input],
+                    List[AsyncCallbackManagerForChainRun],
+                    List[RunnableConfig],
+                ],
+                Awaitable[List[Union[Exception, Output]]],
+            ],
+        ],
+        input: List[Input],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        run_type: Optional[str] = None,
+    ) -> List[Output]:
+        """Helper method to transform an Input value to an Output value,
+        with callbacks. Use this method to implement invoke() in subclasses."""
+        configs = get_config_list(config, len(input))
+        callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
+        run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
+            *(
+                callback_manager.on_chain_start(
+                    dumpd(self),
+                    input,
+                    run_type=run_type,
+                    name=config.get("run_name"),
+                )
+                for callback_manager, input, config in zip(
+                    callback_managers, input, configs
+                )
+            )
+        )
+        try:
+            if accepts_run_manager_and_config(func):
+                output = await func(
+                    input,
+                    run_manager=run_managers,
+                    config=configs,
+                )  # type: ignore[call-arg]
+            elif accepts_run_manager(func):
+                output = await func(input, run_manager=run_managers)  # type: ignore
+            else:
+                output = await func(input)  # type: ignore[call-arg]
+        except Exception as e:
+            await asyncio.gather(
+                *(run_manager.on_chain_error(e) for run_manager in run_managers)
+            )
+            if return_exceptions:
+                return cast(List[Output], [e for _ in input])
+            else:
+                raise
+        else:
+            first_exception: Optional[Exception] = None
+            coros: List[Awaitable[None]] = []
+            for run_manager, out in zip(run_managers, output):
+                if isinstance(out, Exception):
+                    first_exception = first_exception or out
+                    coros.append(run_manager.on_chain_error(out))
+                else:
+                    coros.append(run_manager.on_chain_end(dumpd(out)))
+            await asyncio.gather(*coros)
+            if return_exceptions or first_exception is None:
+                return cast(List[Output], output)
+            else:
+                raise first_exception
+
     def _transform_stream_with_config(
         self,
         input: Iterator[Input],
@@ -586,10 +767,15 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         from langchain.callbacks.manager import CallbackManager

+        if return_exceptions:
+            raise NotImplementedError()
+
         # setup callbacks
         configs = get_config_list(config, len(inputs))
         callback_managers = [
@@ -646,10 +832,15 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         from langchain.callbacks.manager import AsyncCallbackManager

+        if return_exceptions:
+            raise NotImplementedError()
+
         # setup callbacks
         configs = get_config_list(config, len(inputs))
         callback_managers = [
@@ -831,6 +1022,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         from langchain.callbacks.manager import CallbackManager
@@ -861,29 +1054,88 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):

         # invoke
         try:
-            for step in self.steps:
-                inputs = step.batch(
-                    inputs,
-                    [
-                        # each step a child run of the corresponding root run
-                        patch_config(config, callbacks=rm.get_child())
-                        for rm, config in zip(run_managers, configs)
-                    ],
-                )
+            if return_exceptions:
+                # Track which inputs (by index) failed so far
+                # If an input has failed it will be present in this map,
+                # and the value will be the exception that was raised.
+                failed_inputs_map: Dict[int, Exception] = {}
+                for step in self.steps:
+                    # Assemble the original indexes of the remaining inputs
+                    # (i.e. the ones that haven't failed yet)
+                    remaining_idxs = [
+                        i for i in range(len(configs)) if i not in failed_inputs_map
+                    ]
+                    # Invoke the step on the remaining inputs
+                    inputs = step.batch(
+                        [
+                            inp
+                            for i, inp in zip(remaining_idxs, inputs)
+                            if i not in failed_inputs_map
+                        ],
+                        [
+                            # each step a child run of the corresponding root run
+                            patch_config(config, callbacks=rm.get_child())
+                            for i, (rm, config) in enumerate(zip(run_managers, configs))
+                            if i not in failed_inputs_map
+                        ],
+                        return_exceptions=return_exceptions,
+                        **kwargs,
+                    )
+                    # If an input failed, add it to the map
+                    for i, inp in zip(remaining_idxs, inputs):
+                        if isinstance(inp, Exception):
+                            failed_inputs_map[i] = inp
+                    inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
+                    # If all inputs have failed, stop processing
+                    if len(failed_inputs_map) == len(configs):
+                        break
+
+                # Reassemble the outputs, inserting Exceptions for failed inputs
+                inputs_copy = inputs.copy()
+                inputs = []
+                for i in range(len(configs)):
+                    if i in failed_inputs_map:
+                        inputs.append(cast(Input, failed_inputs_map[i]))
+                    else:
+                        inputs.append(inputs_copy.pop(0))
+            else:
+                for step in self.steps:
+                    inputs = step.batch(
+                        inputs,
+                        [
+                            # each step a child run of the corresponding root run
+                            patch_config(config, callbacks=rm.get_child())
+                            for rm, config in zip(run_managers, configs)
+                        ],
+                    )

         # finish the root runs
         except (KeyboardInterrupt, Exception) as e:
             for rm in run_managers:
                 rm.on_chain_error(e)
-            raise
+            if return_exceptions:
+                return cast(List[Output], [e for _ in inputs])
+            else:
+                raise
         else:
-            for rm, input in zip(run_managers, inputs):
-                rm.on_chain_end(input)
-            return cast(List[Output], inputs)
+            first_exception: Optional[Exception] = None
+            for run_manager, out in zip(run_managers, inputs):
+                if isinstance(out, Exception):
+                    first_exception = first_exception or out
+                    run_manager.on_chain_error(out)
+                else:
+                    run_manager.on_chain_end(dumpd(out))
+            if return_exceptions or first_exception is None:
+                return cast(List[Output], inputs)
+            else:
+                raise first_exception

     async def abatch(
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         from langchain.callbacks.manager import (
@@ -919,24 +1171,81 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         # invoke .batch() on each step
         # this uses batching optimizations in Runnable subclasses, like LLM
         try:
-            for step in self.steps:
-                inputs = await step.abatch(
-                    inputs,
-                    [
-                        # each step a child run of the corresponding root run
-                        patch_config(config, callbacks=rm.get_child())
-                        for rm, config in zip(run_managers, configs)
-                    ],
-                )
+            if return_exceptions:
+                # Track which inputs (by index) failed so far
+                # If an input has failed it will be present in this map,
+                # and the value will be the exception that was raised.
+                failed_inputs_map: Dict[int, Exception] = {}
+                for step in self.steps:
+                    # Assemble the original indexes of the remaining inputs
+                    # (i.e. the ones that haven't failed yet)
+                    remaining_idxs = [
+                        i for i in range(len(configs)) if i not in failed_inputs_map
+                    ]
+                    # Invoke the step on the remaining inputs
+                    inputs = await step.abatch(
+                        [
+                            inp
+                            for i, inp in zip(remaining_idxs, inputs)
+                            if i not in failed_inputs_map
+                        ],
+                        [
+                            # each step a child run of the corresponding root run
+                            patch_config(config, callbacks=rm.get_child())
+                            for i, (rm, config) in enumerate(zip(run_managers, configs))
+                            if i not in failed_inputs_map
+                        ],
+                        return_exceptions=return_exceptions,
+                        **kwargs,
+                    )
+                    # If an input failed, add it to the map
+                    for i, inp in zip(remaining_idxs, inputs):
+                        if isinstance(inp, Exception):
+                            failed_inputs_map[i] = inp
+                    inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
+                    # If all inputs have failed, stop processing
+                    if len(failed_inputs_map) == len(configs):
+                        break
+
+                # Reassemble the outputs, inserting Exceptions for failed inputs
+                inputs_copy = inputs.copy()
+                inputs = []
+                for i in range(len(configs)):
+                    if i in failed_inputs_map:
+                        inputs.append(cast(Input, failed_inputs_map[i]))
+                    else:
+                        inputs.append(inputs_copy.pop(0))
+            else:
+                for step in self.steps:
+                    inputs = await step.abatch(
+                        inputs,
+                        [
+                            # each step a child run of the corresponding root run
+                            patch_config(config, callbacks=rm.get_child())
+                            for rm, config in zip(run_managers, configs)
+                        ],
+                    )
         # finish the root runs
         except (KeyboardInterrupt, Exception) as e:
             await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
-            raise
+            if return_exceptions:
+                return cast(List[Output], [e for _ in inputs])
+            else:
+                raise
         else:
-            await asyncio.gather(
-                *(rm.on_chain_end(input) for rm, input in zip(run_managers, inputs))
-            )
-            return cast(List[Output], inputs)
+            first_exception: Optional[Exception] = None
+            coros: List[Awaitable[None]] = []
+            for run_manager, out in zip(run_managers, inputs):
+                if isinstance(out, Exception):
+                    first_exception = first_exception or out
+                    coros.append(run_manager.on_chain_error(out))
+                else:
+                    coros.append(run_manager.on_chain_end(dumpd(out)))
+            await asyncio.gather(*coros)
+            if return_exceptions or first_exception is None:
+                return cast(List[Output], inputs)
+            else:
+                raise first_exception

     def stream(
         self,
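A hedged sketch of what the return_exceptions branch above means for a RunnableSequence: an input that fails at some step is dropped from the remaining steps, and its exception is re-inserted at its original index in the final output. The lambdas are illustrative only:

from langchain.schema.runnable import RunnableLambda  # assumed import path

def fail_on_bar(x: str) -> str:
    if x.startswith("bar"):
        raise ValueError(x)
    return x + "!"

chain = RunnableLambda(fail_on_bar) | RunnableLambda(lambda x: x.upper())

out = chain.batch(["foo", "bar", "baz"], return_exceptions=True)
# out == ["FOO!", ValueError("bar"), "BAZ!"] -- "bar" never reaches the second
# step, and its exception comes back at index 1.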
@@ -1545,6 +1854,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
             config={**self.config, **(config or {}), **kwargs},
         )

+    def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
+        return self.__class__(
+            bound=self.bound.with_retry(**kwargs),
+            kwargs=self.kwargs,
+            config=self.config,
+        )
+
     def invoke(
         self,
         input: Input,
@@ -1573,6 +1889,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         if isinstance(config, list):
@@ -1584,12 +1902,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
             patch_config(self._merge_config(config), deep_copy_locals=True)
             for _ in range(len(inputs))
         ]
-        return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
+        return self.bound.batch(
+            inputs,
+            configs,
+            return_exceptions=return_exceptions,
+            **{**self.kwargs, **kwargs},
+        )

     async def abatch(
         self,
         inputs: List[Input],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         if isinstance(config, list):
@@ -1601,7 +1926,12 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
             patch_config(self._merge_config(config), deep_copy_locals=True)
             for _ in range(len(inputs))
         ]
-        return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs})
+        return await self.bound.abatch(
+            inputs,
+            configs,
+            return_exceptions=return_exceptions,
+            **{**self.kwargs, **kwargs},
+        )

     def stream(
         self,
libs/langchain/langchain/schema/runnable/retry.py (new file, 245 lines)
@@ -0,0 +1,245 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast

from tenacity import (
    AsyncRetrying,
    RetryCallState,
    RetryError,
    Retrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.schema.runnable.base import Input, Output, RunnableBinding
from langchain.schema.runnable.config import RunnableConfig, patch_config

T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
U = TypeVar("U")


class RunnableRetry(RunnableBinding[Input, Output]):
    """Retry a Runnable if it fails."""

    retry_exception_types: Tuple[Type[BaseException]] = (Exception,)

    wait_exponential_jitter: bool = True

    max_attempt_number: int = 3

    @property
    def _kwargs_retrying(self) -> Dict[str, Any]:
        kwargs: Dict[str, Any] = dict()

        if self.max_attempt_number:
            kwargs["stop"] = stop_after_attempt(self.max_attempt_number)

        if self.wait_exponential_jitter:
            kwargs["wait"] = wait_exponential_jitter()

        if self.retry_exception_types:
            kwargs["retry"] = retry_if_exception_type(self.retry_exception_types)

        return kwargs

    def _sync_retrying(self, **kwargs: Any) -> Retrying:
        return Retrying(**self._kwargs_retrying, **kwargs)

    def _async_retrying(self, **kwargs: Any) -> AsyncRetrying:
        return AsyncRetrying(**self._kwargs_retrying, **kwargs)

    def _patch_config(
        self,
        config: RunnableConfig,
        run_manager: T,
        retry_state: RetryCallState,
    ) -> RunnableConfig:
        attempt = retry_state.attempt_number
        tag = "retry:attempt:{}".format(attempt) if attempt > 1 else None
        return patch_config(config, callbacks=run_manager.get_child(tag))

    def _patch_config_list(
        self,
        config: List[RunnableConfig],
        run_manager: List[T],
        retry_state: RetryCallState,
    ) -> List[RunnableConfig]:
        return [
            self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
        ]

    def _invoke(
        self,
        input: Input,
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
    ) -> Output:
        for attempt in self._sync_retrying(reraise=True):
            with attempt:
                result = super().invoke(
                    input,
                    self._patch_config(config, run_manager, attempt.retry_state),
                )
            if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
                attempt.retry_state.set_result(result)
        return result

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        return self._call_with_config(self._invoke, input, config, **kwargs)

    async def _ainvoke(
        self,
        input: Input,
        run_manager: AsyncCallbackManagerForChainRun,
        config: RunnableConfig,
    ) -> Output:
        async for attempt in self._async_retrying(reraise=True):
            with attempt:
                result = await super().ainvoke(
                    input,
                    self._patch_config(config, run_manager, attempt.retry_state),
                )
            if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
                attempt.retry_state.set_result(result)
        return result

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        return await self._acall_with_config(self._ainvoke, input, config, **kwargs)

    def _batch(
        self,
        inputs: List[Input],
        run_manager: List[CallbackManagerForChainRun],
        config: List[RunnableConfig],
    ) -> List[Union[Output, Exception]]:
        results_map: Dict[int, Output] = {}

        def pending(iterable: List[U]) -> List[U]:
            return [item for idx, item in enumerate(iterable) if idx not in results_map]

        try:
            for attempt in self._sync_retrying():
                with attempt:
                    # Get the results of the inputs that have not succeeded yet.
                    result = super().batch(
                        pending(inputs),
                        self._patch_config_list(
                            pending(config), pending(run_manager), attempt.retry_state
                        ),
                        return_exceptions=True,
                    )
                    # Register the results of the inputs that have succeeded.
                    first_exception = None
                    for i, r in enumerate(result):
                        if isinstance(r, Exception):
                            if not first_exception:
                                first_exception = r
                            continue
                        results_map[i] = r
                    # If any exception occurred, raise it, to retry the failed ones
                    if first_exception:
                        raise first_exception
                if (
                    attempt.retry_state.outcome
                    and not attempt.retry_state.outcome.failed
                ):
                    attempt.retry_state.set_result(result)
        except RetryError as e:
            try:
                result
            except UnboundLocalError:
                result = cast(List[Output], [e] * len(inputs))

        outputs: List[Union[Output, Exception]] = []
        for idx, _ in enumerate(inputs):
            if idx in results_map:
                outputs.append(results_map[idx])
            else:
                outputs.append(result.pop(0))
        return outputs

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any
    ) -> List[Output]:
        return self._batch_with_config(
            self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
        )

    async def _abatch(
        self,
        inputs: List[Input],
        run_manager: List[AsyncCallbackManagerForChainRun],
        config: List[RunnableConfig],
    ) -> List[Union[Output, Exception]]:
        results_map: Dict[int, Output] = {}

        def pending(iterable: List[U]) -> List[U]:
            return [item for idx, item in enumerate(iterable) if idx not in results_map]

        try:
            async for attempt in self._async_retrying():
                with attempt:
                    # Get the results of the inputs that have not succeeded yet.
                    result = await super().abatch(
                        pending(inputs),
                        self._patch_config_list(
                            pending(config), pending(run_manager), attempt.retry_state
                        ),
                        return_exceptions=True,
                    )
                    # Register the results of the inputs that have succeeded.
                    first_exception = None
                    for i, r in enumerate(result):
                        if isinstance(r, Exception):
                            if not first_exception:
                                first_exception = r
                            continue
                        results_map[i] = r
                    # If any exception occurred, raise it, to retry the failed ones
                    if first_exception:
                        raise first_exception
                if (
                    attempt.retry_state.outcome
                    and not attempt.retry_state.outcome.failed
                ):
                    attempt.retry_state.set_result(result)
        except RetryError as e:
            try:
                result
            except UnboundLocalError:
                result = cast(List[Output], [e] * len(inputs))

        outputs: List[Union[Output, Exception]] = []
        for idx, _ in enumerate(inputs):
            if idx in results_map:
                outputs.append(results_map[idx])
            else:
                outputs.append(result.pop(0))
        return outputs

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any
    ) -> List[Output]:
        return await self._abatch_with_config(
            self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
        )

    # stream() and transform() are not retried because retrying a stream
    # is not very intuitive.
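Worth noting about _batch()/_abatch() above: across retry attempts only the inputs that have not yet succeeded are re-run, because successes are kept in results_map. A small sketch of that behaviour, assuming RunnableLambda can be imported from langchain.schema.runnable; the flaky function is illustrative:

from langchain.schema.runnable import RunnableLambda  # assumed import path

calls = []

def flaky(x: int) -> int:
    calls.append(x)
    if x == 1:
        raise ValueError("x is 1")  # fails on every attempt
    return x

out = (
    RunnableLambda(flaky)
    .with_retry(stop_after_attempt=2, retry_if_exception_type=(ValueError,))
    .batch([1, 0, 2], return_exceptions=True)
)

# Inputs 0 and 2 succeed on the first attempt and are not re-run, so the
# underlying function runs 3 + 1 times: sorted(calls) == [0, 1, 1, 2].
# out == [ValueError("x is 1"), 0, 2] -- the exhausted retry surfaces its exception.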
@@ -1,6 +1,5 @@
 from __future__ import annotations

-from concurrent.futures import ThreadPoolExecutor
 from typing import (
     Any,
     AsyncIterator,
@@ -12,6 +11,7 @@ from typing import (
     Optional,
     TypedDict,
     Union,
+    cast,
 )

 from langchain.load.serializable import Serializable
@@ -23,7 +23,11 @@ from langchain.schema.runnable.base import (
     RunnableSequence,
     coerce_to_runnable,
 )
-from langchain.schema.runnable.config import RunnableConfig, get_config_list
+from langchain.schema.runnable.config import (
+    RunnableConfig,
+    get_config_list,
+    get_executor_for_config,
+)
 from langchain.schema.runnable.utils import gather_with_concurrency


@@ -122,7 +126,7 @@ class RouterRunnable(
         inputs: List[RouterInput],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
         *,
-        max_concurrency: Optional[int] = None,
+        return_exceptions: bool = False,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         keys = [input["key"] for input in inputs]
@@ -130,16 +134,23 @@ class RouterRunnable(
         if any(key not in self.runnables for key in keys):
             raise ValueError("One or more keys do not have a corresponding runnable")

+        def invoke(
+            runnable: Runnable, input: Input, config: RunnableConfig
+        ) -> Union[Output, Exception]:
+            if return_exceptions:
+                try:
+                    return runnable.invoke(input, config, **kwargs)
+                except Exception as e:
+                    return e
+            else:
+                return runnable.invoke(input, config, **kwargs)
+
         runnables = [self.runnables[key] for key in keys]
         configs = get_config_list(config, len(inputs))
-        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
-            return list(
-                executor.map(
-                    lambda runnable, input, config: runnable.invoke(input, config),
-                    runnables,
-                    actual_inputs,
-                    configs,
-                )
-            )
+        with get_executor_for_config(configs[0]) as executor:
+            return cast(
+                List[Output],
+                list(executor.map(invoke, runnables, actual_inputs, configs)),
+            )

     async def abatch(
@@ -147,7 +158,7 @@ class RouterRunnable(
         inputs: List[RouterInput],
         config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
         *,
-        max_concurrency: Optional[int] = None,
+        return_exceptions: bool = False,
         **kwargs: Optional[Any],
     ) -> List[Output]:
         keys = [input["key"] for input in inputs]
@@ -155,12 +166,23 @@ class RouterRunnable(
         if any(key not in self.runnables for key in keys):
             raise ValueError("One or more keys do not have a corresponding runnable")

+        async def ainvoke(
+            runnable: Runnable, input: Input, config: RunnableConfig
+        ) -> Union[Output, Exception]:
+            if return_exceptions:
+                try:
+                    return await runnable.ainvoke(input, config, **kwargs)
+                except Exception as e:
+                    return e
+            else:
+                return await runnable.ainvoke(input, config, **kwargs)
+
         runnables = [self.runnables[key] for key in keys]
         configs = get_config_list(config, len(inputs))
         return await gather_with_concurrency(
-            max_concurrency,
+            configs[0].get("max_concurrency"),
             *(
-                runnable.ainvoke(input, config)
+                ainvoke(runnable, input, config)
                 for runnable, input, config in zip(runnables, actual_inputs, configs)
             ),
         )
@@ -141,7 +141,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
         else:
             assert call.args[2].get("tags") == ["b-tag"]
             assert call.args[2].get("max_concurrency") == 5
-    spy_seq_step.reset_mock()
+    mocker.stop(spy_seq_step)

     assert [
         *fake.with_config(tags=["a-tag"]).stream(
@@ -1423,3 +1423,365 @@ def test_recursive_lambda() -> None:

     with pytest.raises(RecursionError):
         runnable.invoke(0, {"recursion_limit": 9})
+
+
+def test_retrying(mocker: MockerFixture) -> None:
+    def _lambda(x: int) -> Union[int, Runnable]:
+        if x == 1:
+            raise ValueError("x is 1")
+        elif x == 2:
+            raise RuntimeError("x is 2")
+        else:
+            return x
+
+    _lambda_mock = mocker.Mock(side_effect=_lambda)
+    runnable = RunnableLambda(_lambda_mock)
+
+    with pytest.raises(ValueError):
+        runnable.invoke(1)
+
+    assert _lambda_mock.call_count == 1
+    _lambda_mock.reset_mock()
+
+    with pytest.raises(ValueError):
+        runnable.with_retry(
+            stop_after_attempt=2,
+            retry_if_exception_type=(ValueError,),
+        ).invoke(1)
+
+    assert _lambda_mock.call_count == 2  # retried
+    _lambda_mock.reset_mock()
+
+    with pytest.raises(RuntimeError):
+        runnable.with_retry(
+            stop_after_attempt=2,
+            retry_if_exception_type=(ValueError,),
+        ).invoke(2)
+
+    assert _lambda_mock.call_count == 1  # did not retry
+    _lambda_mock.reset_mock()
+
+    with pytest.raises(ValueError):
+        runnable.with_retry(
+            stop_after_attempt=2,
+            retry_if_exception_type=(ValueError,),
+        ).batch([1, 2, 0])
+
+    # 3rd input isn't retried because it succeeded
+    assert _lambda_mock.call_count == 3 + 2
+    _lambda_mock.reset_mock()
+
+    output = runnable.with_retry(
+        stop_after_attempt=2,
+        retry_if_exception_type=(ValueError,),
+    ).batch([1, 2, 0], return_exceptions=True)
+
+    # 3rd input isn't retried because it succeeded
+    assert _lambda_mock.call_count == 3 + 2
+    assert len(output) == 3
+    assert isinstance(output[0], ValueError)
+    assert isinstance(output[1], RuntimeError)
+    assert output[2] == 0
+    _lambda_mock.reset_mock()
+
+
+@pytest.mark.asyncio
+async def test_async_retrying(mocker: MockerFixture) -> None:
+    def _lambda(x: int) -> Union[int, Runnable]:
+        if x == 1:
+            raise ValueError("x is 1")
+        elif x == 2:
+            raise RuntimeError("x is 2")
+        else:
+            return x
+
+    _lambda_mock = mocker.Mock(side_effect=_lambda)
+    runnable = RunnableLambda(_lambda_mock)
+
+    with pytest.raises(ValueError):
+        await runnable.ainvoke(1)
+
+    assert _lambda_mock.call_count == 1
+    _lambda_mock.reset_mock()
+
+    with pytest.raises(ValueError):
+        await runnable.with_retry(
+            stop_after_attempt=2,
+            retry_if_exception_type=(ValueError,),
+        ).ainvoke(1)
+
+    assert _lambda_mock.call_count == 2  # retried
+    _lambda_mock.reset_mock()
+
+    with pytest.raises(RuntimeError):
+        await runnable.with_retry(
+            stop_after_attempt=2,
+            retry_if_exception_type=(ValueError,),
+        ).ainvoke(2)
+
+    assert _lambda_mock.call_count == 1  # did not retry
+    _lambda_mock.reset_mock()
+
+    with pytest.raises(ValueError):
+        await runnable.with_retry(
+            stop_after_attempt=2,
+            retry_if_exception_type=(ValueError,),
+        ).abatch([1, 2, 0])
+
+    # 3rd input isn't retried because it succeeded
+    assert _lambda_mock.call_count == 3 + 2
+    _lambda_mock.reset_mock()
+
+    output = await runnable.with_retry(
+        stop_after_attempt=2,
+        retry_if_exception_type=(ValueError,),
+    ).abatch([1, 2, 0], return_exceptions=True)
+
+    # 3rd input isn't retried because it succeeded
+    assert _lambda_mock.call_count == 3 + 2
+    assert len(output) == 3
+    assert isinstance(output[0], ValueError)
+    assert isinstance(output[1], RuntimeError)
+    assert output[2] == 0
+    _lambda_mock.reset_mock()
+
+
+@freeze_time("2023-01-01")
+def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
+    class ControlledExceptionRunnable(Runnable[str, str]):
+        def __init__(self, fail_starts_with: str) -> None:
+            self.fail_starts_with = fail_starts_with
+
+        def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
+            raise NotImplementedError()
+
+        def _batch(
+            self,
+            inputs: List[str],
+        ) -> List:
+            outputs: List[Any] = []
+            for input in inputs:
+                if input.startswith(self.fail_starts_with):
+                    outputs.append(ValueError())
+                else:
+                    outputs.append(input + "a")
+            return outputs
+
+        def batch(
+            self,
+            inputs: List[str],
+            config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+            *,
+            return_exceptions: bool = False,
+            **kwargs: Any,
+        ) -> List[str]:
+            return self._batch_with_config(
+                self._batch,
+                inputs,
+                config,
+                return_exceptions=return_exceptions,
+                **kwargs,
+            )
+
+    chain = (
+        ControlledExceptionRunnable("bux")
+        | ControlledExceptionRunnable("bar")
+        | ControlledExceptionRunnable("baz")
+        | ControlledExceptionRunnable("foo")
+    )
+
+    assert isinstance(chain, RunnableSequence)
+
+    # Test batch
+    with pytest.raises(ValueError):
+        chain.batch(["foo", "bar", "baz", "qux"])
+
+    spy = mocker.spy(ControlledExceptionRunnable, "batch")
+    tracer = FakeTracer()
+    inputs = ["foo", "bar", "baz", "qux"]
+    outputs = chain.batch(inputs, dict(callbacks=[tracer]), return_exceptions=True)
+    assert len(outputs) == 4
+    assert isinstance(outputs[0], ValueError)
+    assert isinstance(outputs[1], ValueError)
+    assert isinstance(outputs[2], ValueError)
+    assert outputs[3] == "quxaaaa"
+    assert spy.call_count == 4
+    inputs_to_batch = [c[0][1] for c in spy.call_args_list]
+    assert inputs_to_batch == [
+        # inputs to sequence step 0
+        # same as inputs to sequence.batch()
+        ["foo", "bar", "baz", "qux"],
+        # inputs to sequence step 1
+        # == outputs of sequence step 0 as no exceptions were raised
+        ["fooa", "bara", "baza", "quxa"],
+        # inputs to sequence step 2
+        # 'bar' was dropped as it raised an exception in step 1
+        ["fooaa", "bazaa", "quxaa"],
+        # inputs to sequence step 3
+        # 'baz' was dropped as it raised an exception in step 2
+        ["fooaaa", "quxaaa"],
+    ]
+    parent_runs = sorted(
+        (r for r in tracer.runs if r.parent_run_id is None),
+        key=lambda run: inputs.index(run.inputs["input"]),
+    )
+    assert len(parent_runs) == 4
+
+    parent_run_foo = parent_runs[0]
+    assert parent_run_foo.inputs["input"] == "foo"
+    assert parent_run_foo.error == repr(ValueError())
+    assert len(parent_run_foo.child_runs) == 4
+    assert [r.error for r in parent_run_foo.child_runs] == [
+        None,
+        None,
+        None,
+        repr(ValueError()),
+    ]
+
+    parent_run_bar = parent_runs[1]
+    assert parent_run_bar.inputs["input"] == "bar"
+    assert parent_run_bar.error == repr(ValueError())
+    assert len(parent_run_bar.child_runs) == 2
+    assert [r.error for r in parent_run_bar.child_runs] == [
+        None,
+        repr(ValueError()),
+    ]
+
+    parent_run_baz = parent_runs[2]
+    assert parent_run_baz.inputs["input"] == "baz"
+    assert parent_run_baz.error == repr(ValueError())
+    assert len(parent_run_baz.child_runs) == 3
+    assert [r.error for r in parent_run_baz.child_runs] == [
+        None,
+        None,
+        repr(ValueError()),
+    ]
+
+    parent_run_qux = parent_runs[3]
+    assert parent_run_qux.inputs["input"] == "qux"
+    assert parent_run_qux.error is None
+    assert parent_run_qux.outputs["output"] == "quxaaaa"
+    assert len(parent_run_qux.child_runs) == 4
+    assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]
+
+
+@pytest.mark.asyncio
+@freeze_time("2023-01-01")
+async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
+    class ControlledExceptionRunnable(Runnable[str, str]):
+        def __init__(self, fail_starts_with: str) -> None:
+            self.fail_starts_with = fail_starts_with
+
+        def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
+            raise NotImplementedError()
+
+        async def _abatch(
+            self,
+            inputs: List[str],
+        ) -> List:
+            outputs: List[Any] = []
+            for input in inputs:
+                if input.startswith(self.fail_starts_with):
+                    outputs.append(ValueError())
+                else:
+                    outputs.append(input + "a")
+            return outputs
+
+        async def abatch(
+            self,
+            inputs: List[str],
+            config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+            *,
+            return_exceptions: bool = False,
+            **kwargs: Any,
+        ) -> List[str]:
+            return await self._abatch_with_config(
+                self._abatch,
+                inputs,
+                config,
+                return_exceptions=return_exceptions,
+                **kwargs,
+            )
+
+    chain = (
+        ControlledExceptionRunnable("bux")
+        | ControlledExceptionRunnable("bar")
+        | ControlledExceptionRunnable("baz")
+        | ControlledExceptionRunnable("foo")
+    )
+
+    assert isinstance(chain, RunnableSequence)
+
+    # Test abatch
+    with pytest.raises(ValueError):
+        await chain.abatch(["foo", "bar", "baz", "qux"])
+
+    spy = mocker.spy(ControlledExceptionRunnable, "abatch")
+    tracer = FakeTracer()
+    inputs = ["foo", "bar", "baz", "qux"]
+    outputs = await chain.abatch(
+        inputs, dict(callbacks=[tracer]), return_exceptions=True
+    )
+    assert len(outputs) == 4
+    assert isinstance(outputs[0], ValueError)
+    assert isinstance(outputs[1], ValueError)
+    assert isinstance(outputs[2], ValueError)
+    assert outputs[3] == "quxaaaa"
+    assert spy.call_count == 4
+    inputs_to_batch = [c[0][1] for c in spy.call_args_list]
+    assert inputs_to_batch == [
+        # inputs to sequence step 0
+        # same as inputs to sequence.batch()
+        ["foo", "bar", "baz", "qux"],
+        # inputs to sequence step 1
+        # == outputs of sequence step 0 as no exceptions were raised
+        ["fooa", "bara", "baza", "quxa"],
+        # inputs to sequence step 2
+        # 'bar' was dropped as it raised an exception in step 1
+        ["fooaa", "bazaa", "quxaa"],
+        # inputs to sequence step 3
+        # 'baz' was dropped as it raised an exception in step 2
+        ["fooaaa", "quxaaa"],
+    ]
+    parent_runs = sorted(
+        (r for r in tracer.runs if r.parent_run_id is None),
+        key=lambda run: inputs.index(run.inputs["input"]),
+    )
+    assert len(parent_runs) == 4
+
+    parent_run_foo = parent_runs[0]
+    assert parent_run_foo.inputs["input"] == "foo"
+    assert parent_run_foo.error == repr(ValueError())
+    assert len(parent_run_foo.child_runs) == 4
+    assert [r.error for r in parent_run_foo.child_runs] == [
+        None,
+        None,
+        None,
+        repr(ValueError()),
+    ]
+
+    parent_run_bar = parent_runs[1]
+    assert parent_run_bar.inputs["input"] == "bar"
+    assert parent_run_bar.error == repr(ValueError())
+    assert len(parent_run_bar.child_runs) == 2
+    assert [r.error for r in parent_run_bar.child_runs] == [
+        None,
+        repr(ValueError()),
+    ]
+
+    parent_run_baz = parent_runs[2]
+    assert parent_run_baz.inputs["input"] == "baz"
+    assert parent_run_baz.error == repr(ValueError())
+    assert len(parent_run_baz.child_runs) == 3
+    assert [r.error for r in parent_run_baz.child_runs] == [
+        None,
+        None,
+        repr(ValueError()),
+    ]
+
+    parent_run_qux = parent_runs[3]
+    assert parent_run_qux.inputs["input"] == "qux"
+    assert parent_run_qux.error is None
+    assert parent_run_qux.outputs["output"] == "quxaaaa"
+    assert len(parent_run_qux.child_runs) == 4
+    assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]