make runnable dir (#9016)

Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
Bagatur 2023-08-10 00:56:37 -07:00 committed by GitHub
parent c7a489ae0d
commit 434a96415b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 590 additions and 235 deletions

View File

@ -0,0 +1,24 @@
from langchain.schema.runnable.base import (
Runnable,
RunnableBinding,
RunnableLambda,
RunnableMap,
RunnableSequence,
RunnableWithFallbacks,
)
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable
__all__ = [
"RouterInput",
"RouterRunnable",
"Runnable",
"RunnableBinding",
"RunnableConfig",
"RunnableMap",
"RunnableLambda",
"RunnablePassthrough",
"RunnableSequence",
"RunnableWithFallbacks",
]

View File

@ -9,7 +9,6 @@ from typing import (
AsyncIterator, AsyncIterator,
Awaitable, Awaitable,
Callable, Callable,
Coroutine,
Dict, Dict,
Generic, Generic,
Iterator, Iterator,
@ -19,7 +18,6 @@ from typing import (
Sequence, Sequence,
Tuple, Tuple,
Type, Type,
TypedDict,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -27,48 +25,15 @@ from typing import (
from pydantic import Field from pydantic import Field
from langchain.callbacks.base import BaseCallbackManager, Callbacks from langchain.callbacks.base import BaseCallbackManager
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.utils import (
gather_with_concurrency,
)
from langchain.utils.aiter import atee, py_anext from langchain.utils.aiter import atee, py_anext
async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
async with semaphore:
return await coro
async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
if n is None:
return await asyncio.gather(*coros)
semaphore = asyncio.Semaphore(n)
return await asyncio.gather(*(_gated_coro(semaphore, c) for c in coros))
class RunnableConfig(TypedDict, total=False):
"""Configuration for a Runnable."""
tags: List[str]
"""
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
You can use these to filter calls.
"""
metadata: Dict[str, Any]
"""
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
Keys should be strings, values should be JSON-serializable.
"""
callbacks: Callbacks
"""
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
"""
Input = TypeVar("Input") Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do # Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output") Output = TypeVar("Output")
@ -87,7 +52,7 @@ class Runnable(Generic[Input, Output], ABC):
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
], ],
) -> RunnableSequence[Input, Other]: ) -> RunnableSequence[Input, Other]:
return RunnableSequence(first=self, last=_coerce_to_runnable(other)) return RunnableSequence(first=self, last=coerce_to_runnable(other))
def __ror__( def __ror__(
self, self,
@ -97,7 +62,7 @@ class Runnable(Generic[Input, Output], ABC):
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
], ],
) -> RunnableSequence[Other, Output]: ) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=_coerce_to_runnable(other), last=self) return RunnableSequence(first=coerce_to_runnable(other), last=self)
""" --- Public API --- """ """ --- Public API --- """
@ -150,7 +115,7 @@ class Runnable(Generic[Input, Output], ABC):
configs = self._get_config_list(config, len(inputs)) configs = self._get_config_list(config, len(inputs))
coros = map(self.ainvoke, inputs, configs) coros = map(self.ainvoke, inputs, configs)
return await _gather_with_concurrency(max_concurrency, *coros) return await gather_with_concurrency(max_concurrency, *coros)
def stream( def stream(
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
@ -478,6 +443,14 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
@property @property
def runnables(self) -> Iterator[Runnable[Input, Output]]: def runnables(self) -> Iterator[Runnable[Input, Output]]:
yield self.runnable yield self.runnable
@ -506,7 +479,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
try: try:
output = runnable.invoke( output = runnable.invoke(
input, input,
_patch_config(config, run_manager.get_child()), patch_config(config, run_manager.get_child()),
) )
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
if first_error is None: if first_error is None:
@ -550,7 +523,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
try: try:
output = await runnable.ainvoke( output = await runnable.ainvoke(
input, input,
_patch_config(config, run_manager.get_child()), patch_config(config, run_manager.get_child()),
) )
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
if first_error is None: if first_error is None:
@ -606,7 +579,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
inputs, inputs,
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
_patch_config(config, rm.get_child()) patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs) for rm, config in zip(run_managers, configs)
], ],
max_concurrency=max_concurrency, max_concurrency=max_concurrency,
@ -673,7 +646,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
inputs, inputs,
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
_patch_config(config, rm.get_child()) patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs) for rm, config in zip(run_managers, configs)
], ],
max_concurrency=max_concurrency, max_concurrency=max_concurrency,
@ -716,6 +689,10 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
def lc_serializable(self) -> bool: def lc_serializable(self) -> bool:
return True return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -737,7 +714,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
return RunnableSequence( return RunnableSequence(
first=self.first, first=self.first,
middle=self.middle + [self.last], middle=self.middle + [self.last],
last=_coerce_to_runnable(other), last=coerce_to_runnable(other),
) )
def __ror__( def __ror__(
@ -756,7 +733,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
) )
else: else:
return RunnableSequence( return RunnableSequence(
first=_coerce_to_runnable(other), first=coerce_to_runnable(other),
middle=[self.first] + self.middle, middle=[self.first] + self.middle,
last=self.last, last=self.last,
) )
@ -786,7 +763,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = step.invoke( input = step.invoke(
input, input,
# mark each step as a child run # mark each step as a child run
_patch_config(config, run_manager.get_child()), patch_config(config, run_manager.get_child()),
) )
# finish the root run # finish the root run
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
@ -825,7 +802,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = await step.ainvoke( input = await step.ainvoke(
input, input,
# mark each step as a child run # mark each step as a child run
_patch_config(config, run_manager.get_child()), patch_config(config, run_manager.get_child()),
) )
# finish the root run # finish the root run
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
@ -875,7 +852,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
inputs, inputs,
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
_patch_config(config, rm.get_child()) patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs) for rm, config in zip(run_managers, configs)
], ],
max_concurrency=max_concurrency, max_concurrency=max_concurrency,
@ -934,7 +911,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
inputs, inputs,
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
_patch_config(config, rm.get_child()) patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs) for rm, config in zip(run_managers, configs)
], ],
max_concurrency=max_concurrency, max_concurrency=max_concurrency,
@ -990,7 +967,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = step.invoke( input = step.invoke(
input, input,
# mark each step as a child run # mark each step as a child run
_patch_config(config, run_manager.get_child()), patch_config(config, run_manager.get_child()),
) )
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
@ -1002,12 +979,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try: try:
# stream the first of the last steps with non-streaming input # stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].stream( final_pipeline = steps[streaming_start_index].stream(
input, _patch_config(config, run_manager.get_child()) input, patch_config(config, run_manager.get_child())
) )
# stream the rest of the last steps with streaming input # stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]: for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.transform( final_pipeline = step.transform(
final_pipeline, _patch_config(config, run_manager.get_child()) final_pipeline, patch_config(config, run_manager.get_child())
) )
for output in final_pipeline: for output in final_pipeline:
yield output yield output
@ -1067,7 +1044,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = await step.ainvoke( input = await step.ainvoke(
input, input,
# mark each step as a child run # mark each step as a child run
_patch_config(config, run_manager.get_child()), patch_config(config, run_manager.get_child()),
) )
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
@ -1079,12 +1056,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try: try:
# stream the first of the last steps with non-streaming input # stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].astream( final_pipeline = steps[streaming_start_index].astream(
input, _patch_config(config, run_manager.get_child()) input, patch_config(config, run_manager.get_child())
) )
# stream the rest of the last steps with streaming input # stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]: for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.atransform( final_pipeline = step.atransform(
final_pipeline, _patch_config(config, run_manager.get_child()) final_pipeline, patch_config(config, run_manager.get_child())
) )
async for output in final_pipeline: async for output in final_pipeline:
yield output yield output
@ -1128,14 +1105,16 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
], ],
], ],
) -> None: ) -> None:
super().__init__( super().__init__(steps={key: coerce_to_runnable(r) for key, r in steps.items()})
steps={key: _coerce_to_runnable(r) for key, r in steps.items()}
)
@property @property
def lc_serializable(self) -> bool: def lc_serializable(self) -> bool:
return True return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -1168,7 +1147,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
step.invoke, step.invoke,
input, input,
# mark each step as a child run # mark each step as a child run
_patch_config(config, run_manager.get_child()), patch_config(config, run_manager.get_child()),
) )
for step in steps.values() for step in steps.values()
] ]
@ -1211,7 +1190,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
step.ainvoke( step.ainvoke(
input, input,
# mark each step as a child run # mark each step as a child run
_patch_config(config, run_manager.get_child()), patch_config(config, run_manager.get_child()),
) )
for step in steps.values() for step in steps.values()
) )
@ -1250,19 +1229,6 @@ class RunnableLambda(Runnable[Input, Output]):
return self._call_with_config(self.func, input, config) return self._call_with_config(self.func, input, config)
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
"""
A runnable that passes through the input.
"""
@property
def lc_serializable(self) -> bool:
return True
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
return self._call_with_config(lambda x: x, input, config)
class RunnableBinding(Serializable, Runnable[Input, Output]): class RunnableBinding(Serializable, Runnable[Input, Output]):
""" """
A runnable that delegates calls to another runnable with a set of kwargs. A runnable that delegates calls to another runnable with a set of kwargs.
@ -1279,6 +1245,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
def lc_serializable(self) -> bool: def lc_serializable(self) -> bool:
return True return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
def bind(self, **kwargs: Any) -> Runnable[Input, Output]: def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs}) return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs})
@ -1335,160 +1305,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
yield item yield item
class RouterInput(TypedDict): def patch_config(
"""A Router input.
Attributes:
key: The key to route on.
input: The input to pass to the selected runnable.
"""
key: str
input: Any
class RouterRunnable(
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
):
"""
A runnable that routes to a set of runnables based on Input['key'].
Returns the output of the selected runnable.
"""
runnables: Mapping[str, Runnable[Input, Output]]
def __init__(
self,
runnables: Mapping[
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
],
) -> None:
super().__init__(
runnables={key: _coerce_to_runnable(r) for key, r in runnables.items()}
)
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
return True
def __or__(
self,
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
Mapping[str, Any],
],
) -> RunnableSequence[RouterInput, Other]:
return RunnableSequence(first=self, last=_coerce_to_runnable(other))
def __ror__(
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
Mapping[str, Any],
],
) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
def invoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Output:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
return runnable.invoke(actual_input, config)
async def ainvoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Output:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
return await runnable.ainvoke(actual_input, config)
def batch(
self,
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs]
if any(key not in self.runnables for key in keys):
raise ValueError("One or more keys do not have a corresponding runnable")
runnables = [self.runnables[key] for key in keys]
configs = self._get_config_list(config, len(inputs))
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(
executor.map(
lambda runnable, input, config: runnable.invoke(input, config),
runnables,
actual_inputs,
configs,
)
)
async def abatch(
self,
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs]
if any(key not in self.runnables for key in keys):
raise ValueError("One or more keys do not have a corresponding runnable")
runnables = [self.runnables[key] for key in keys]
configs = self._get_config_list(config, len(inputs))
return await _gather_with_concurrency(
max_concurrency,
*(
runnable.ainvoke(input, config)
for runnable, input, config in zip(runnables, actual_inputs, configs)
),
)
def stream(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
yield from runnable.stream(actual_input, config)
async def astream(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
async for output in runnable.astream(actual_input, config):
yield output
def _patch_config(
config: RunnableConfig, callback_manager: BaseCallbackManager config: RunnableConfig, callback_manager: BaseCallbackManager
) -> RunnableConfig: ) -> RunnableConfig:
config = config.copy() config = config.copy()
@ -1496,7 +1313,7 @@ def _patch_config(
return config return config
def _coerce_to_runnable( def coerce_to_runnable(
thing: Union[ thing: Union[
Runnable[Input, Output], Runnable[Input, Output],
Callable[[Input], Output], Callable[[Input], Output],
@ -1508,7 +1325,7 @@ def _coerce_to_runnable(
elif callable(thing): elif callable(thing):
return RunnableLambda(thing) return RunnableLambda(thing)
elif isinstance(thing, dict): elif isinstance(thing, dict):
runnables = {key: _coerce_to_runnable(r) for key, r in thing.items()} runnables = {key: coerce_to_runnable(r) for key, r in thing.items()}
return cast(Runnable[Input, Output], RunnableMap(steps=runnables)) return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
else: else:
raise TypeError( raise TypeError(

View File

@ -0,0 +1,27 @@
from __future__ import annotations
from typing import Any, Dict, List, TypedDict
from langchain.callbacks.base import Callbacks
class RunnableConfig(TypedDict, total=False):
"""Configuration for a Runnable."""
tags: List[str]
"""
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
You can use these to filter calls.
"""
metadata: Dict[str, Any]
"""
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
Keys should be strings, values should be JSON-serializable.
"""
callbacks: Callbacks
"""
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
"""

View File

@ -0,0 +1,23 @@
from __future__ import annotations
from typing import List, Optional
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Runnable, RunnableConfig
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
"""
A runnable that passes through the input.
"""
@property
def lc_serializable(self) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
return self._call_with_config(lambda x: x, input, config)

View File

@ -0,0 +1,184 @@
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncIterator,
Callable,
Generic,
Iterator,
List,
Mapping,
Optional,
TypedDict,
Union,
)
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import (
Input,
Other,
Output,
Runnable,
RunnableSequence,
coerce_to_runnable,
)
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.utils import gather_with_concurrency
class RouterInput(TypedDict):
"""A Router input.
Attributes:
key: The key to route on.
input: The input to pass to the selected runnable.
"""
key: str
input: Any
class RouterRunnable(
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
):
"""
A runnable that routes to a set of runnables based on Input['key'].
Returns the output of the selected runnable.
"""
runnables: Mapping[str, Runnable[Input, Output]]
def __init__(
self,
runnables: Mapping[
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
],
) -> None:
super().__init__(
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
)
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
def __or__(
self,
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
Mapping[str, Any],
],
) -> RunnableSequence[RouterInput, Other]:
return RunnableSequence(first=self, last=coerce_to_runnable(other))
def __ror__(
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
Mapping[str, Any],
],
) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=coerce_to_runnable(other), last=self)
def invoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Output:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
return runnable.invoke(actual_input, config)
async def ainvoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Output:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
return await runnable.ainvoke(actual_input, config)
def batch(
self,
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs]
if any(key not in self.runnables for key in keys):
raise ValueError("One or more keys do not have a corresponding runnable")
runnables = [self.runnables[key] for key in keys]
configs = self._get_config_list(config, len(inputs))
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(
executor.map(
lambda runnable, input, config: runnable.invoke(input, config),
runnables,
actual_inputs,
configs,
)
)
async def abatch(
self,
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs]
if any(key not in self.runnables for key in keys):
raise ValueError("One or more keys do not have a corresponding runnable")
runnables = [self.runnables[key] for key in keys]
configs = self._get_config_list(config, len(inputs))
return await gather_with_concurrency(
max_concurrency,
*(
runnable.ainvoke(input, config)
for runnable, input, config in zip(runnables, actual_inputs, configs)
),
)
def stream(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
yield from runnable.stream(actual_input, config)
async def astream(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
async for output in runnable.astream(actual_input, config):
yield output

View File

@ -0,0 +1,18 @@
from __future__ import annotations
import asyncio
from typing import Any, Coroutine, Union
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
async with semaphore:
return await coro
async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
if n is None:
return await asyncio.gather(*coros)
semaphore = asyncio.Semaphore(n)
return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))

File diff suppressed because one or more lines are too long

View File

@ -839,7 +839,7 @@ def llm_chain_with_fallbacks() -> RunnableSequence:
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_llm_with_fallbacks( async def test_llm_with_fallbacks(
runnable: RunnableWithFallbacks, request: Any runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
) -> None: ) -> None:
runnable = request.getfixturevalue(runnable) runnable = request.getfixturevalue(runnable)
assert runnable.invoke("hello") == "bar" assert runnable.invoke("hello") == "bar"
@ -848,3 +848,4 @@ async def test_llm_with_fallbacks(
assert await runnable.ainvoke("hello") == "bar" assert await runnable.ainvoke("hello") == "bar"
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(await runnable.ainvoke("hello")) == list("bar") assert list(await runnable.ainvoke("hello")) == list("bar")
assert dumps(runnable, pretty=True) == snapshot