mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 17:08:47 +00:00
core[patch]: pass exceptions to fallbacks (#16048)
This commit is contained in:
parent
770f57196e
commit
c5656a4905
@ -923,12 +923,17 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
fallbacks: Sequence[Runnable[Input, Output]],
|
fallbacks: Sequence[Runnable[Input, Output]],
|
||||||
*,
|
*,
|
||||||
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
|
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
|
||||||
|
exception_key: Optional[str] = None,
|
||||||
) -> RunnableWithFallbacksT[Input, Output]:
|
) -> RunnableWithFallbacksT[Input, Output]:
|
||||||
"""Add fallbacks to a runnable, returning a new Runnable.
|
"""Add fallbacks to a runnable, returning a new Runnable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fallbacks: A sequence of runnables to try if the original runnable fails.
|
fallbacks: A sequence of runnables to try if the original runnable fails.
|
||||||
exceptions_to_handle: A tuple of exception types to handle.
|
exceptions_to_handle: A tuple of exception types to handle.
|
||||||
|
exception_key: If string is specified then handled exceptions will be passed
|
||||||
|
to fallbacks as part of the input under the specified key. If None,
|
||||||
|
exceptions will not be passed to fallbacks. If used, the base runnable
|
||||||
|
and its fallbacks must accept a dictionary as input.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A new Runnable that will try the original runnable, and then each
|
A new Runnable that will try the original runnable, and then each
|
||||||
@ -940,6 +945,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
runnable=self,
|
runnable=self,
|
||||||
fallbacks=fallbacks,
|
fallbacks=fallbacks,
|
||||||
exceptions_to_handle=exceptions_to_handle,
|
exceptions_to_handle=exceptions_to_handle,
|
||||||
|
exception_key=exception_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
""" --- Helper methods for Subclasses --- """
|
""" --- Helper methods for Subclasses --- """
|
||||||
|
@ -2,6 +2,7 @@ import asyncio
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
Dict,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
@ -9,6 +10,7 @@ from typing import (
|
|||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_core.load.dump import dumpd
|
from langchain_core.load.dump import dumpd
|
||||||
@ -89,6 +91,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
Any exception that is not a subclass of these exceptions will be raised immediately.
|
Any exception that is not a subclass of these exceptions will be raised immediately.
|
||||||
"""
|
"""
|
||||||
|
exception_key: Optional[str] = None
|
||||||
|
"""If string is specified then handled exceptions will be passed to fallbacks as
|
||||||
|
part of the input under the specified key. If None, exceptions
|
||||||
|
will not be passed to fallbacks. If used, the base runnable and its fallbacks
|
||||||
|
must accept a dictionary as input."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
@ -136,6 +143,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
) -> Output:
|
) -> Output:
|
||||||
|
if self.exception_key is not None and not isinstance(input, dict):
|
||||||
|
raise ValueError(
|
||||||
|
"If 'exception_key' is specified then input must be a dictionary."
|
||||||
|
f"However found a type of {type(input)} for input"
|
||||||
|
)
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
@ -144,8 +156,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
dumpd(self), input, name=config.get("run_name")
|
dumpd(self), input, name=config.get("run_name")
|
||||||
)
|
)
|
||||||
first_error = None
|
first_error = None
|
||||||
|
last_error = None
|
||||||
for runnable in self.runnables:
|
for runnable in self.runnables:
|
||||||
try:
|
try:
|
||||||
|
if self.exception_key and last_error is not None:
|
||||||
|
input[self.exception_key] = last_error
|
||||||
output = runnable.invoke(
|
output = runnable.invoke(
|
||||||
input,
|
input,
|
||||||
patch_config(config, callbacks=run_manager.get_child()),
|
patch_config(config, callbacks=run_manager.get_child()),
|
||||||
@ -154,6 +169,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
except self.exceptions_to_handle as e:
|
except self.exceptions_to_handle as e:
|
||||||
if first_error is None:
|
if first_error is None:
|
||||||
first_error = e
|
first_error = e
|
||||||
|
last_error = e
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise e
|
raise e
|
||||||
@ -171,6 +187,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Output:
|
) -> Output:
|
||||||
|
if self.exception_key is not None and not isinstance(input, dict):
|
||||||
|
raise ValueError(
|
||||||
|
"If 'exception_key' is specified then input must be a dictionary."
|
||||||
|
f"However found a type of {type(input)} for input"
|
||||||
|
)
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_async_callback_manager_for_config(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
@ -180,8 +201,11 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
first_error = None
|
first_error = None
|
||||||
|
last_error = None
|
||||||
for runnable in self.runnables:
|
for runnable in self.runnables:
|
||||||
try:
|
try:
|
||||||
|
if self.exception_key and last_error is not None:
|
||||||
|
input[self.exception_key] = last_error
|
||||||
output = await runnable.ainvoke(
|
output = await runnable.ainvoke(
|
||||||
input,
|
input,
|
||||||
patch_config(config, callbacks=run_manager.get_child()),
|
patch_config(config, callbacks=run_manager.get_child()),
|
||||||
@ -190,6 +214,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
except self.exceptions_to_handle as e:
|
except self.exceptions_to_handle as e:
|
||||||
if first_error is None:
|
if first_error is None:
|
||||||
first_error = e
|
first_error = e
|
||||||
|
last_error = e
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise e
|
raise e
|
||||||
@ -211,8 +236,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
from langchain_core.callbacks.manager import CallbackManager
|
from langchain_core.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
if return_exceptions:
|
if self.exception_key is not None and not all(
|
||||||
raise NotImplementedError()
|
isinstance(input, dict) for input in inputs
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"If 'exception_key' is specified then inputs must be dictionaries."
|
||||||
|
f"However found a type of {type(inputs[0])} for input"
|
||||||
|
)
|
||||||
|
|
||||||
if not inputs:
|
if not inputs:
|
||||||
return []
|
return []
|
||||||
@ -241,35 +271,51 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||||
]
|
]
|
||||||
|
|
||||||
first_error = None
|
to_return: Dict[int, Any] = {}
|
||||||
|
run_again = {i: input for i, input in enumerate(inputs)}
|
||||||
|
handled_exceptions: Dict[int, BaseException] = {}
|
||||||
|
first_to_raise = None
|
||||||
for runnable in self.runnables:
|
for runnable in self.runnables:
|
||||||
try:
|
outputs = runnable.batch(
|
||||||
outputs = runnable.batch(
|
[input for _, input in sorted(run_again.items())],
|
||||||
inputs,
|
[
|
||||||
[
|
# each step a child run of the corresponding root run
|
||||||
# each step a child run of the corresponding root run
|
patch_config(configs[i], callbacks=run_managers[i].get_child())
|
||||||
patch_config(config, callbacks=rm.get_child())
|
for i in sorted(run_again)
|
||||||
for rm, config in zip(run_managers, configs)
|
],
|
||||||
],
|
return_exceptions=True,
|
||||||
return_exceptions=return_exceptions,
|
**kwargs,
|
||||||
**kwargs,
|
)
|
||||||
)
|
for (i, input), output in zip(sorted(run_again.copy().items()), outputs):
|
||||||
except self.exceptions_to_handle as e:
|
if isinstance(output, BaseException) and not isinstance(
|
||||||
if first_error is None:
|
output, self.exceptions_to_handle
|
||||||
first_error = e
|
):
|
||||||
except BaseException as e:
|
if not return_exceptions:
|
||||||
for rm in run_managers:
|
first_to_raise = first_to_raise or output
|
||||||
rm.on_chain_error(e)
|
else:
|
||||||
raise e
|
handled_exceptions[i] = cast(BaseException, output)
|
||||||
else:
|
run_again.pop(i)
|
||||||
for rm, output in zip(run_managers, outputs):
|
elif isinstance(output, self.exceptions_to_handle):
|
||||||
rm.on_chain_end(output)
|
if self.exception_key:
|
||||||
return outputs
|
input[self.exception_key] = output # type: ignore
|
||||||
if first_error is None:
|
handled_exceptions[i] = cast(BaseException, output)
|
||||||
raise ValueError("No error stored at end of fallbacks.")
|
else:
|
||||||
for rm in run_managers:
|
run_managers[i].on_chain_end(output)
|
||||||
rm.on_chain_error(first_error)
|
to_return[i] = output
|
||||||
raise first_error
|
run_again.pop(i)
|
||||||
|
handled_exceptions.pop(i, None)
|
||||||
|
if first_to_raise:
|
||||||
|
raise first_to_raise
|
||||||
|
if not run_again:
|
||||||
|
break
|
||||||
|
|
||||||
|
sorted_handled_exceptions = sorted(handled_exceptions.items())
|
||||||
|
for i, error in sorted_handled_exceptions:
|
||||||
|
run_managers[i].on_chain_error(error)
|
||||||
|
if not return_exceptions and sorted_handled_exceptions:
|
||||||
|
raise sorted_handled_exceptions[0][1]
|
||||||
|
to_return.update(handled_exceptions)
|
||||||
|
return [output for _, output in sorted(to_return.items())]
|
||||||
|
|
||||||
async def abatch(
|
async def abatch(
|
||||||
self,
|
self,
|
||||||
@ -281,8 +327,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||||
|
|
||||||
if return_exceptions:
|
if self.exception_key is not None and not all(
|
||||||
raise NotImplementedError()
|
isinstance(input, dict) for input in inputs
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"If 'exception_key' is specified then inputs must be dictionaries."
|
||||||
|
f"However found a type of {type(inputs[0])} for input"
|
||||||
|
)
|
||||||
|
|
||||||
if not inputs:
|
if not inputs:
|
||||||
return []
|
return []
|
||||||
@ -313,33 +364,54 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
first_error = None
|
to_return = {}
|
||||||
|
run_again = {i: input for i, input in enumerate(inputs)}
|
||||||
|
handled_exceptions: Dict[int, BaseException] = {}
|
||||||
|
first_to_raise = None
|
||||||
for runnable in self.runnables:
|
for runnable in self.runnables:
|
||||||
try:
|
outputs = await runnable.abatch(
|
||||||
outputs = await runnable.abatch(
|
[input for _, input in sorted(run_again.items())],
|
||||||
inputs,
|
[
|
||||||
[
|
# each step a child run of the corresponding root run
|
||||||
# each step a child run of the corresponding root run
|
patch_config(configs[i], callbacks=run_managers[i].get_child())
|
||||||
patch_config(config, callbacks=rm.get_child())
|
for i in sorted(run_again)
|
||||||
for rm, config in zip(run_managers, configs)
|
],
|
||||||
],
|
return_exceptions=True,
|
||||||
return_exceptions=return_exceptions,
|
**kwargs,
|
||||||
**kwargs,
|
)
|
||||||
)
|
|
||||||
except self.exceptions_to_handle as e:
|
for (i, input), output in zip(sorted(run_again.copy().items()), outputs):
|
||||||
if first_error is None:
|
if isinstance(output, BaseException) and not isinstance(
|
||||||
first_error = e
|
output, self.exceptions_to_handle
|
||||||
except BaseException as e:
|
):
|
||||||
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
if not return_exceptions:
|
||||||
else:
|
first_to_raise = first_to_raise or output
|
||||||
await asyncio.gather(
|
else:
|
||||||
*(
|
handled_exceptions[i] = cast(BaseException, output)
|
||||||
rm.on_chain_end(output)
|
run_again.pop(i)
|
||||||
for rm, output in zip(run_managers, outputs)
|
elif isinstance(output, self.exceptions_to_handle):
|
||||||
)
|
if self.exception_key:
|
||||||
)
|
input[self.exception_key] = output # type: ignore
|
||||||
return outputs
|
handled_exceptions[i] = cast(BaseException, output)
|
||||||
if first_error is None:
|
else:
|
||||||
raise ValueError("No error stored at end of fallbacks.")
|
to_return[i] = output
|
||||||
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
|
await run_managers[i].on_chain_end(output)
|
||||||
raise first_error
|
run_again.pop(i)
|
||||||
|
handled_exceptions.pop(i, None)
|
||||||
|
|
||||||
|
if first_to_raise:
|
||||||
|
raise first_to_raise
|
||||||
|
if not run_again:
|
||||||
|
break
|
||||||
|
|
||||||
|
sorted_handled_exceptions = sorted(handled_exceptions.items())
|
||||||
|
await asyncio.gather(
|
||||||
|
*(
|
||||||
|
run_managers[i].on_chain_error(error)
|
||||||
|
for i, error in sorted_handled_exceptions
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not return_exceptions and sorted_handled_exceptions:
|
||||||
|
raise sorted_handled_exceptions[0][1]
|
||||||
|
to_return.update(handled_exceptions)
|
||||||
|
return [output for _, output in sorted(to_return.items())] # type: ignore
|
||||||
|
@ -0,0 +1,373 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_fallbacks[chain]
|
||||||
|
'''
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"runnable",
|
||||||
|
"RunnableSequence"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"first": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"runnable",
|
||||||
|
"RunnableParallel"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"steps": {
|
||||||
|
"buz": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"langchain_core",
|
||||||
|
"runnables",
|
||||||
|
"base",
|
||||||
|
"RunnableLambda"
|
||||||
|
],
|
||||||
|
"repr": "RunnableLambda(lambda x: x)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"middle": [],
|
||||||
|
"last": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"runnable",
|
||||||
|
"RunnableWithFallbacks"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"runnable": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"runnable",
|
||||||
|
"RunnableSequence"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"first": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"prompts",
|
||||||
|
"prompt",
|
||||||
|
"PromptTemplate"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"input_variables": [
|
||||||
|
"buz"
|
||||||
|
],
|
||||||
|
"template": "what did baz say to {buz}",
|
||||||
|
"template_format": "f-string",
|
||||||
|
"partial_variables": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"middle": [],
|
||||||
|
"last": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"tests",
|
||||||
|
"unit_tests",
|
||||||
|
"fake",
|
||||||
|
"llm",
|
||||||
|
"FakeListLLM"
|
||||||
|
],
|
||||||
|
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
||||||
|
},
|
||||||
|
"name": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fallbacks": [
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"runnable",
|
||||||
|
"RunnableSequence"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"first": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"prompts",
|
||||||
|
"prompt",
|
||||||
|
"PromptTemplate"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"input_variables": [
|
||||||
|
"buz"
|
||||||
|
],
|
||||||
|
"template": "what did baz say to {buz}",
|
||||||
|
"template_format": "f-string",
|
||||||
|
"partial_variables": {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"middle": [],
|
||||||
|
"last": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"tests",
|
||||||
|
"unit_tests",
|
||||||
|
"fake",
|
||||||
|
"llm",
|
||||||
|
"FakeListLLM"
|
||||||
|
],
|
||||||
|
"repr": "FakeListLLM(responses=['bar'])"
|
||||||
|
},
|
||||||
|
"name": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"exceptions_to_handle": [
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"builtins",
|
||||||
|
"Exception"
|
||||||
|
],
|
||||||
|
"repr": "<class 'Exception'>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"exception_key": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"name": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
# ---
|
||||||
|
# name: test_fallbacks[chain_pass_exceptions]
|
||||||
|
'''
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"runnable",
|
||||||
|
"RunnableSequence"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"first": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"runnable",
|
||||||
|
"RunnableParallel"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"steps": {
|
||||||
|
"text": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"runnable",
|
||||||
|
"RunnablePassthrough"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"func": null,
|
||||||
|
"afunc": null,
|
||||||
|
"input_type": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"middle": [],
|
||||||
|
"last": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"runnable",
|
||||||
|
"RunnableWithFallbacks"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"runnable": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"langchain_core",
|
||||||
|
"runnables",
|
||||||
|
"base",
|
||||||
|
"RunnableLambda"
|
||||||
|
],
|
||||||
|
"repr": "RunnableLambda(_raise_error)"
|
||||||
|
},
|
||||||
|
"fallbacks": [
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"langchain_core",
|
||||||
|
"runnables",
|
||||||
|
"base",
|
||||||
|
"RunnableLambda"
|
||||||
|
],
|
||||||
|
"repr": "RunnableLambda(_dont_raise_error)"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"exceptions_to_handle": [
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"builtins",
|
||||||
|
"Exception"
|
||||||
|
],
|
||||||
|
"repr": "<class 'Exception'>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"exception_key": "exception"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"name": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
# ---
|
||||||
|
# name: test_fallbacks[llm]
|
||||||
|
'''
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"runnable",
|
||||||
|
"RunnableWithFallbacks"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"runnable": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"tests",
|
||||||
|
"unit_tests",
|
||||||
|
"fake",
|
||||||
|
"llm",
|
||||||
|
"FakeListLLM"
|
||||||
|
],
|
||||||
|
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
||||||
|
},
|
||||||
|
"fallbacks": [
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"tests",
|
||||||
|
"unit_tests",
|
||||||
|
"fake",
|
||||||
|
"llm",
|
||||||
|
"FakeListLLM"
|
||||||
|
],
|
||||||
|
"repr": "FakeListLLM(responses=['bar'])"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"exceptions_to_handle": [
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"builtins",
|
||||||
|
"Exception"
|
||||||
|
],
|
||||||
|
"repr": "<class 'Exception'>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"exception_key": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
# ---
|
||||||
|
# name: test_fallbacks[llm_multi]
|
||||||
|
'''
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": [
|
||||||
|
"langchain",
|
||||||
|
"schema",
|
||||||
|
"runnable",
|
||||||
|
"RunnableWithFallbacks"
|
||||||
|
],
|
||||||
|
"kwargs": {
|
||||||
|
"runnable": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"tests",
|
||||||
|
"unit_tests",
|
||||||
|
"fake",
|
||||||
|
"llm",
|
||||||
|
"FakeListLLM"
|
||||||
|
],
|
||||||
|
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
||||||
|
},
|
||||||
|
"fallbacks": [
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"tests",
|
||||||
|
"unit_tests",
|
||||||
|
"fake",
|
||||||
|
"llm",
|
||||||
|
"FakeListLLM"
|
||||||
|
],
|
||||||
|
"repr": "FakeListLLM(responses=['baz'], i=1)"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"tests",
|
||||||
|
"unit_tests",
|
||||||
|
"fake",
|
||||||
|
"llm",
|
||||||
|
"FakeListLLM"
|
||||||
|
],
|
||||||
|
"repr": "FakeListLLM(responses=['bar'])"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"exceptions_to_handle": [
|
||||||
|
{
|
||||||
|
"lc": 1,
|
||||||
|
"type": "not_implemented",
|
||||||
|
"id": [
|
||||||
|
"builtins",
|
||||||
|
"Exception"
|
||||||
|
],
|
||||||
|
"repr": "<class 'Exception'>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"exception_key": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
# ---
|
@ -696,280 +696,6 @@
|
|||||||
}
|
}
|
||||||
'''
|
'''
|
||||||
# ---
|
# ---
|
||||||
# name: test_llm_with_fallbacks[llm_chain_with_fallbacks]
|
|
||||||
'''
|
|
||||||
{
|
|
||||||
"lc": 1,
|
|
||||||
"type": "constructor",
|
|
||||||
"id": [
|
|
||||||
"langchain",
|
|
||||||
"schema",
|
|
||||||
"runnable",
|
|
||||||
"RunnableSequence"
|
|
||||||
],
|
|
||||||
"kwargs": {
|
|
||||||
"first": {
|
|
||||||
"lc": 1,
|
|
||||||
"type": "constructor",
|
|
||||||
"id": [
|
|
||||||
"langchain",
|
|
||||||
"schema",
|
|
||||||
"runnable",
|
|
||||||
"RunnableParallel"
|
|
||||||
],
|
|
||||||
"kwargs": {
|
|
||||||
"steps": {
|
|
||||||
"buz": {
|
|
||||||
"lc": 1,
|
|
||||||
"type": "not_implemented",
|
|
||||||
"id": [
|
|
||||||
"langchain_core",
|
|
||||||
"runnables",
|
|
||||||
"base",
|
|
||||||
"RunnableLambda"
|
|
||||||
],
|
|
||||||
"repr": "RunnableLambda(lambda x: x)"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"middle": [],
|
|
||||||
"last": {
|
|
||||||
"lc": 1,
|
|
||||||
"type": "constructor",
|
|
||||||
"id": [
|
|
||||||
"langchain",
|
|
||||||
"schema",
|
|
||||||
"runnable",
|
|
||||||
"RunnableWithFallbacks"
|
|
||||||
],
|
|
||||||
"kwargs": {
|
|
||||||
"runnable": {
|
|
||||||
"lc": 1,
|
|
||||||
"type": "constructor",
|
|
||||||
"id": [
|
|
||||||
"langchain",
|
|
||||||
"schema",
|
|
||||||
"runnable",
|
|
||||||
"RunnableSequence"
|
|
||||||
],
|
|
||||||
"kwargs": {
|
|
||||||
"first": {
|
|
||||||
"lc": 1,
|
|
||||||
"type": "constructor",
|
|
||||||
"id": [
|
|
||||||
"langchain",
|
|
||||||
"prompts",
|
|
||||||
"prompt",
|
|
||||||
"PromptTemplate"
|
|
||||||
],
|
|
||||||
"kwargs": {
|
|
||||||
"input_variables": [
|
|
||||||
"buz"
|
|
||||||
],
|
|
||||||
"template": "what did baz say to {buz}",
|
|
||||||
"template_format": "f-string",
|
|
||||||
"partial_variables": {}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"middle": [],
|
|
||||||
"last": {
|
|
||||||
"lc": 1,
|
|
||||||
"type": "not_implemented",
|
|
||||||
"id": [
|
|
||||||
"tests",
|
|
||||||
"unit_tests",
|
|
||||||
"fake",
|
|
||||||
"llm",
|
|
||||||
"FakeListLLM"
|
|
||||||
],
|
|
||||||
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
|
||||||
},
|
|
||||||
"name": null
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fallbacks": [
|
|
||||||
{
|
|
||||||
"lc": 1,
|
|
||||||
"type": "constructor",
|
|
||||||
"id": [
|
|
||||||
"langchain",
|
|
||||||
"schema",
|
|
||||||
"runnable",
|
|
||||||
"RunnableSequence"
|
|
||||||
],
|
|
||||||
"kwargs": {
|
|
||||||
"first": {
|
|
||||||
"lc": 1,
|
|
||||||
"type": "constructor",
|
|
||||||
"id": [
|
|
||||||
"langchain",
|
|
||||||
"prompts",
|
|
||||||
"prompt",
|
|
||||||
"PromptTemplate"
|
|
||||||
],
|
|
||||||
"kwargs": {
|
|
||||||
"input_variables": [
|
|
||||||
"buz"
|
|
||||||
],
|
|
||||||
"template": "what did baz say to {buz}",
|
|
||||||
"template_format": "f-string",
|
|
||||||
"partial_variables": {}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"middle": [],
|
|
||||||
"last": {
|
|
||||||
"lc": 1,
|
|
||||||
"type": "not_implemented",
|
|
||||||
"id": [
|
|
||||||
"tests",
|
|
||||||
"unit_tests",
|
|
||||||
"fake",
|
|
||||||
"llm",
|
|
||||||
"FakeListLLM"
|
|
||||||
],
|
|
||||||
"repr": "FakeListLLM(responses=['bar'])"
|
|
||||||
},
|
|
||||||
"name": null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"exceptions_to_handle": [
|
|
||||||
{
|
|
||||||
"lc": 1,
|
|
||||||
"type": "not_implemented",
|
|
||||||
"id": [
|
|
||||||
"builtins",
|
|
||||||
"Exception"
|
|
||||||
],
|
|
||||||
"repr": "<class 'Exception'>"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"name": null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
'''
|
|
||||||
# ---
|
|
||||||
# name: test_llm_with_fallbacks[llm_with_fallbacks]
|
|
||||||
'''
|
|
||||||
{
|
|
||||||
"lc": 1,
|
|
||||||
"type": "constructor",
|
|
||||||
"id": [
|
|
||||||
"langchain",
|
|
||||||
"schema",
|
|
||||||
"runnable",
|
|
||||||
"RunnableWithFallbacks"
|
|
||||||
],
|
|
||||||
"kwargs": {
|
|
||||||
"runnable": {
|
|
||||||
"lc": 1,
|
|
||||||
"type": "not_implemented",
|
|
||||||
"id": [
|
|
||||||
"tests",
|
|
||||||
"unit_tests",
|
|
||||||
"fake",
|
|
||||||
"llm",
|
|
||||||
"FakeListLLM"
|
|
||||||
],
|
|
||||||
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
|
||||||
},
|
|
||||||
"fallbacks": [
|
|
||||||
{
|
|
||||||
"lc": 1,
|
|
||||||
"type": "not_implemented",
|
|
||||||
"id": [
|
|
||||||
"tests",
|
|
||||||
"unit_tests",
|
|
||||||
"fake",
|
|
||||||
"llm",
|
|
||||||
"FakeListLLM"
|
|
||||||
],
|
|
||||||
"repr": "FakeListLLM(responses=['bar'])"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"exceptions_to_handle": [
|
|
||||||
{
|
|
||||||
"lc": 1,
|
|
||||||
"type": "not_implemented",
|
|
||||||
"id": [
|
|
||||||
"builtins",
|
|
||||||
"Exception"
|
|
||||||
],
|
|
||||||
"repr": "<class 'Exception'>"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
'''
|
|
||||||
# ---
|
|
||||||
# name: test_llm_with_fallbacks[llm_with_multi_fallbacks]
|
|
||||||
'''
|
|
||||||
{
|
|
||||||
"lc": 1,
|
|
||||||
"type": "constructor",
|
|
||||||
"id": [
|
|
||||||
"langchain",
|
|
||||||
"schema",
|
|
||||||
"runnable",
|
|
||||||
"RunnableWithFallbacks"
|
|
||||||
],
|
|
||||||
"kwargs": {
|
|
||||||
"runnable": {
|
|
||||||
"lc": 1,
|
|
||||||
"type": "not_implemented",
|
|
||||||
"id": [
|
|
||||||
"tests",
|
|
||||||
"unit_tests",
|
|
||||||
"fake",
|
|
||||||
"llm",
|
|
||||||
"FakeListLLM"
|
|
||||||
],
|
|
||||||
"repr": "FakeListLLM(responses=['foo'], i=1)"
|
|
||||||
},
|
|
||||||
"fallbacks": [
|
|
||||||
{
|
|
||||||
"lc": 1,
|
|
||||||
"type": "not_implemented",
|
|
||||||
"id": [
|
|
||||||
"tests",
|
|
||||||
"unit_tests",
|
|
||||||
"fake",
|
|
||||||
"llm",
|
|
||||||
"FakeListLLM"
|
|
||||||
],
|
|
||||||
"repr": "FakeListLLM(responses=['baz'], i=1)"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"lc": 1,
|
|
||||||
"type": "not_implemented",
|
|
||||||
"id": [
|
|
||||||
"tests",
|
|
||||||
"unit_tests",
|
|
||||||
"fake",
|
|
||||||
"llm",
|
|
||||||
"FakeListLLM"
|
|
||||||
],
|
|
||||||
"repr": "FakeListLLM(responses=['bar'])"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"exceptions_to_handle": [
|
|
||||||
{
|
|
||||||
"lc": 1,
|
|
||||||
"type": "not_implemented",
|
|
||||||
"id": [
|
|
||||||
"builtins",
|
|
||||||
"Exception"
|
|
||||||
],
|
|
||||||
"repr": "<class 'Exception'>"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
'''
|
|
||||||
# ---
|
|
||||||
# name: test_prompt_with_chat_model
|
# name: test_prompt_with_chat_model
|
||||||
'''
|
'''
|
||||||
ChatPromptTemplate(input_variables=['question'], messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a nice assistant.')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], template='{question}'))])
|
ChatPromptTemplate(input_variables=['question'], messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a nice assistant.')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], template='{question}'))])
|
||||||
|
231
libs/core/tests/unit_tests/runnables/test_fallbacks.py
Normal file
231
libs/core/tests/unit_tests/runnables/test_fallbacks.py
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from syrupy import SnapshotAssertion
|
||||||
|
|
||||||
|
from langchain_core.load import dumps
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
from langchain_core.runnables import (
|
||||||
|
Runnable,
|
||||||
|
RunnableLambda,
|
||||||
|
RunnableParallel,
|
||||||
|
RunnablePassthrough,
|
||||||
|
RunnableWithFallbacks,
|
||||||
|
)
|
||||||
|
from tests.unit_tests.fake.llm import FakeListLLM
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def llm() -> RunnableWithFallbacks:
|
||||||
|
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||||
|
pass_llm = FakeListLLM(responses=["bar"])
|
||||||
|
|
||||||
|
return error_llm.with_fallbacks([pass_llm])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def llm_multi() -> RunnableWithFallbacks:
|
||||||
|
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||||
|
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
|
||||||
|
pass_llm = FakeListLLM(responses=["bar"])
|
||||||
|
|
||||||
|
return error_llm.with_fallbacks([error_llm_2, pass_llm])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def chain() -> Runnable:
|
||||||
|
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||||
|
pass_llm = FakeListLLM(responses=["bar"])
|
||||||
|
|
||||||
|
prompt = PromptTemplate.from_template("what did baz say to {buz}")
|
||||||
|
return RunnableParallel({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
|
||||||
|
[prompt | pass_llm]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_error(inputs: dict) -> str:
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
|
||||||
|
def _dont_raise_error(inputs: dict) -> str:
|
||||||
|
if "exception" in inputs:
|
||||||
|
return "bar"
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def chain_pass_exceptions() -> Runnable:
|
||||||
|
fallback = RunnableLambda(_dont_raise_error)
|
||||||
|
return {"text": RunnablePassthrough()} | RunnableLambda(
|
||||||
|
_raise_error
|
||||||
|
).with_fallbacks([fallback], exception_key="exception")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"runnable",
|
||||||
|
["llm", "llm_multi", "chain", "chain_pass_exceptions"],
|
||||||
|
)
|
||||||
|
async def test_fallbacks(
|
||||||
|
runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
|
||||||
|
) -> None:
|
||||||
|
runnable = request.getfixturevalue(runnable)
|
||||||
|
assert runnable.invoke("hello") == "bar"
|
||||||
|
assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||||
|
assert list(runnable.stream("hello")) == ["bar"]
|
||||||
|
assert await runnable.ainvoke("hello") == "bar"
|
||||||
|
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||||
|
assert list(await runnable.ainvoke("hello")) == list("bar")
|
||||||
|
if sys.version_info >= (3, 9):
|
||||||
|
assert dumps(runnable, pretty=True) == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def _runnable(inputs: dict) -> str:
|
||||||
|
if inputs["text"] == "foo":
|
||||||
|
return "first"
|
||||||
|
if "exception" not in inputs:
|
||||||
|
raise ValueError()
|
||||||
|
if inputs["text"] == "bar":
|
||||||
|
return "second"
|
||||||
|
if isinstance(inputs["exception"], ValueError):
|
||||||
|
raise RuntimeError()
|
||||||
|
return "third"
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_potential_error(actual: list, expected: list) -> None:
|
||||||
|
for x, y in zip(actual, expected):
|
||||||
|
if isinstance(x, Exception):
|
||||||
|
assert isinstance(y, type(x))
|
||||||
|
else:
|
||||||
|
assert x == y
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_with_exception_key() -> None:
|
||||||
|
runnable = RunnableLambda(_runnable)
|
||||||
|
runnable_with_single = runnable.with_fallbacks(
|
||||||
|
[runnable], exception_key="exception"
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
runnable_with_single.invoke({"text": "baz"})
|
||||||
|
|
||||||
|
actual = runnable_with_single.invoke({"text": "bar"})
|
||||||
|
expected = "second"
|
||||||
|
_assert_potential_error([actual], [expected])
|
||||||
|
|
||||||
|
runnable_with_double = runnable.with_fallbacks(
|
||||||
|
[runnable, runnable], exception_key="exception"
|
||||||
|
)
|
||||||
|
actual = runnable_with_double.invoke({"text": "baz"})
|
||||||
|
|
||||||
|
expected = "third"
|
||||||
|
_assert_potential_error([actual], [expected])
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ainvoke_with_exception_key() -> None:
|
||||||
|
runnable = RunnableLambda(_runnable)
|
||||||
|
runnable_with_single = runnable.with_fallbacks(
|
||||||
|
[runnable], exception_key="exception"
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await runnable_with_single.ainvoke({"text": "baz"})
|
||||||
|
|
||||||
|
actual = await runnable_with_single.ainvoke({"text": "bar"})
|
||||||
|
expected = "second"
|
||||||
|
_assert_potential_error([actual], [expected])
|
||||||
|
|
||||||
|
runnable_with_double = runnable.with_fallbacks(
|
||||||
|
[runnable, runnable], exception_key="exception"
|
||||||
|
)
|
||||||
|
actual = await runnable_with_double.ainvoke({"text": "baz"})
|
||||||
|
expected = "third"
|
||||||
|
_assert_potential_error([actual], [expected])
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch() -> None:
|
||||||
|
runnable = RunnableLambda(_runnable)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
runnable.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||||
|
actual = runnable.batch(
|
||||||
|
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||||
|
)
|
||||||
|
expected = ["first", ValueError(), ValueError()]
|
||||||
|
_assert_potential_error(actual, expected)
|
||||||
|
|
||||||
|
runnable_with_single = runnable.with_fallbacks(
|
||||||
|
[runnable], exception_key="exception"
|
||||||
|
)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
runnable_with_single.batch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||||
|
actual = runnable_with_single.batch(
|
||||||
|
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||||
|
)
|
||||||
|
expected = ["first", "second", RuntimeError()]
|
||||||
|
_assert_potential_error(actual, expected)
|
||||||
|
|
||||||
|
runnable_with_double = runnable.with_fallbacks(
|
||||||
|
[runnable, runnable], exception_key="exception"
|
||||||
|
)
|
||||||
|
actual = runnable_with_double.batch(
|
||||||
|
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = ["first", "second", "third"]
|
||||||
|
_assert_potential_error(actual, expected)
|
||||||
|
|
||||||
|
runnable_with_double = runnable.with_fallbacks(
|
||||||
|
[runnable, runnable],
|
||||||
|
exception_key="exception",
|
||||||
|
exceptions_to_handle=(ValueError,),
|
||||||
|
)
|
||||||
|
actual = runnable_with_double.batch(
|
||||||
|
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = ["first", "second", RuntimeError()]
|
||||||
|
_assert_potential_error(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abatch() -> None:
|
||||||
|
runnable = RunnableLambda(_runnable)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await runnable.abatch([{"text": "foo"}, {"text": "bar"}, {"text": "baz"}])
|
||||||
|
actual = await runnable.abatch(
|
||||||
|
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||||
|
)
|
||||||
|
expected = ["first", ValueError(), ValueError()]
|
||||||
|
_assert_potential_error(actual, expected)
|
||||||
|
|
||||||
|
runnable_with_single = runnable.with_fallbacks(
|
||||||
|
[runnable], exception_key="exception"
|
||||||
|
)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await runnable_with_single.abatch(
|
||||||
|
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}]
|
||||||
|
)
|
||||||
|
actual = await runnable_with_single.abatch(
|
||||||
|
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||||
|
)
|
||||||
|
expected = ["first", "second", RuntimeError()]
|
||||||
|
_assert_potential_error(actual, expected)
|
||||||
|
|
||||||
|
runnable_with_double = runnable.with_fallbacks(
|
||||||
|
[runnable, runnable], exception_key="exception"
|
||||||
|
)
|
||||||
|
actual = await runnable_with_double.abatch(
|
||||||
|
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = ["first", "second", "third"]
|
||||||
|
_assert_potential_error(actual, expected)
|
||||||
|
|
||||||
|
runnable_with_double = runnable.with_fallbacks(
|
||||||
|
[runnable, runnable],
|
||||||
|
exception_key="exception",
|
||||||
|
exceptions_to_handle=(ValueError,),
|
||||||
|
)
|
||||||
|
actual = await runnable_with_double.abatch(
|
||||||
|
[{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = ["first", "second", RuntimeError()]
|
||||||
|
_assert_potential_error(actual, expected)
|
@ -66,7 +66,6 @@ from langchain_core.runnables import (
|
|||||||
RunnablePassthrough,
|
RunnablePassthrough,
|
||||||
RunnablePick,
|
RunnablePick,
|
||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
RunnableWithFallbacks,
|
|
||||||
add,
|
add,
|
||||||
chain,
|
chain,
|
||||||
)
|
)
|
||||||
@ -3683,52 +3682,6 @@ async def test_runnable_sequence_atransform() -> None:
|
|||||||
assert "".join(chunks) == "foo-lish"
|
assert "".join(chunks) == "foo-lish"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def llm_with_fallbacks() -> RunnableWithFallbacks:
|
|
||||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
|
||||||
pass_llm = FakeListLLM(responses=["bar"])
|
|
||||||
|
|
||||||
return error_llm.with_fallbacks([pass_llm])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def llm_with_multi_fallbacks() -> RunnableWithFallbacks:
|
|
||||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
|
||||||
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
|
|
||||||
pass_llm = FakeListLLM(responses=["bar"])
|
|
||||||
|
|
||||||
return error_llm.with_fallbacks([error_llm_2, pass_llm])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def llm_chain_with_fallbacks() -> Runnable:
|
|
||||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
|
||||||
pass_llm = FakeListLLM(responses=["bar"])
|
|
||||||
|
|
||||||
prompt = PromptTemplate.from_template("what did baz say to {buz}")
|
|
||||||
return RunnableParallel({"buz": lambda x: x}) | (prompt | error_llm).with_fallbacks(
|
|
||||||
[prompt | pass_llm]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"runnable",
|
|
||||||
["llm_with_fallbacks", "llm_with_multi_fallbacks", "llm_chain_with_fallbacks"],
|
|
||||||
)
|
|
||||||
async def test_llm_with_fallbacks(
|
|
||||||
runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
|
|
||||||
) -> None:
|
|
||||||
runnable = request.getfixturevalue(runnable)
|
|
||||||
assert runnable.invoke("hello") == "bar"
|
|
||||||
assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
|
|
||||||
assert list(runnable.stream("hello")) == ["bar"]
|
|
||||||
assert await runnable.ainvoke("hello") == "bar"
|
|
||||||
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
|
||||||
assert list(await runnable.ainvoke("hello")) == list("bar")
|
|
||||||
if sys.version_info >= (3, 9):
|
|
||||||
assert dumps(runnable, pretty=True) == snapshot
|
|
||||||
|
|
||||||
|
|
||||||
class FakeSplitIntoListParser(BaseOutputParser[List[str]]):
|
class FakeSplitIntoListParser(BaseOutputParser[List[str]]):
|
||||||
"""Parse the output of an LLM call to a comma-separated list."""
|
"""Parse the output of an LLM call to a comma-separated list."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user