Add RunnableGenerator

This commit is contained in:
Nuno Campos 2023-09-29 11:44:07 +01:00
parent ca5293bf54
commit b67db8deaa
2 changed files with 149 additions and 8 deletions

View File

@ -453,6 +453,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Input, input: Input,
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
@ -465,7 +466,9 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"), name=config.get("run_name"),
) )
try: try:
output = call_func_with_variable_args(func, input, run_manager, config) output = call_func_with_variable_args(
func, input, run_manager, config, **kwargs
)
except BaseException as e: except BaseException as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise raise
@ -486,6 +489,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Input, input: Input,
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses.""" with callbacks. Use this method to implement ainvoke() in subclasses."""
@ -499,7 +503,7 @@ class Runnable(Generic[Input, Output], ABC):
) )
try: try:
output = await acall_func_with_variable_args( output = await acall_func_with_variable_args(
func, input, run_manager, config func, input, run_manager, config, **kwargs
) )
except BaseException as e: except BaseException as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
@ -526,6 +530,7 @@ class Runnable(Generic[Input, Output], ABC):
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
@ -546,7 +551,6 @@ class Runnable(Generic[Input, Output], ABC):
) )
] ]
try: try:
kwargs: Dict[str, Any] = {}
if accepts_config(func): if accepts_config(func):
kwargs["config"] = [ kwargs["config"] = [
patch_config(c, callbacks=rm.get_child()) patch_config(c, callbacks=rm.get_child())
@ -597,6 +601,7 @@ class Runnable(Generic[Input, Output], ABC):
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
@ -619,7 +624,6 @@ class Runnable(Generic[Input, Output], ABC):
) )
) )
try: try:
kwargs: Dict[str, Any] = {}
if accepts_config(func): if accepts_config(func):
kwargs["config"] = [ kwargs["config"] = [
patch_config(c, callbacks=rm.get_child()) patch_config(c, callbacks=rm.get_child())
@ -668,6 +672,7 @@ class Runnable(Generic[Input, Output], ABC):
], ],
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
"""Helper method to transform an Iterator of Input values into an Iterator of """Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks. Output values, with callbacks.
@ -689,7 +694,6 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"), name=config.get("run_name"),
) )
try: try:
kwargs: Dict[str, Any] = {}
if accepts_config(transformer): if accepts_config(transformer):
kwargs["config"] = patch_config( kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child() config, callbacks=run_manager.get_child()
@ -746,6 +750,7 @@ class Runnable(Generic[Input, Output], ABC):
], ],
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
"""Helper method to transform an Async Iterator of Input values into an Async """Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks. Iterator of Output values, with callbacks.
@ -767,7 +772,6 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"), name=config.get("run_name"),
) )
try: try:
kwargs: Dict[str, Any] = {}
if accepts_config(transformer): if accepts_config(transformer):
kwargs["config"] = patch_config( kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child() config, callbacks=run_manager.get_child()
@ -2061,6 +2065,139 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
yield chunk yield chunk
class RunnableGenerator(Runnable[Input, Output]):
"""
A runnable that runs a generator function.
"""
def __init__(
self,
transform: Union[
Callable[[Iterator[Input]], Iterator[Output]],
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
],
atransform: Optional[
Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
] = None,
) -> None:
if atransform is not None:
self._atransform = atransform
if inspect.isasyncgenfunction(transform):
self._atransform = transform
elif inspect.isgeneratorfunction(transform):
self._transform = transform
else:
raise TypeError(
"Expected a generator function type for `transform`."
f"Instead got an unsupported type: {type(transform)}"
)
@property
def InputType(self) -> Any:
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
try:
params = inspect.signature(func).parameters
first_param = next(iter(params.values()), None)
if first_param and first_param.annotation != inspect.Parameter.empty:
return first_param.annotation
else:
return Any
except ValueError:
return Any
@property
def OutputType(self) -> Type[Output]:
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
try:
sig = inspect.signature(func)
return (
sig.return_annotation
if sig.return_annotation != inspect.Signature.empty
else Any
)
except ValueError:
return Any
def __eq__(self, other: Any) -> bool:
if isinstance(other, RunnableGenerator):
if hasattr(self, "_transform") and hasattr(other, "_transform"):
return self._transform == other._transform
elif hasattr(self, "_atransform") and hasattr(other, "_atransform"):
return self._atransform == other._atransform
else:
return False
else:
return False
def __repr__(self) -> str:
return "RunnableGenerator(...)"
def transform(
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> Iterator[Output]:
return self._transform_stream_with_config(
input, self._transform, config, **kwargs
)
def stream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> Iterator[Output]:
return self.transform(iter([input]), config, **kwargs)
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
final = None
for output in self.stream(input, config, **kwargs):
if final is None:
final = output
else:
final += output
return final
async def atransform(
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> AsyncIterator[Output]:
if not hasattr(self, "_atransform"):
raise NotImplementedError("This runnable does not support async methods.")
return self._atransform_stream_with_config(
input, self._atransform, config, **kwargs
)
async def astream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> AsyncIterator[Output]:
async def input_aiter() -> AsyncIterator[Input]:
yield input
return self.atransform(input_aiter(), config, **kwargs)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
final = None
async for output in self.astream(input, config):
if final is None:
final = output
else:
final += output
return final
class RunnableLambda(Runnable[Input, Output]): class RunnableLambda(Runnable[Input, Output]):
""" """
A runnable that runs a callable. A runnable that runs a callable.
@ -2538,6 +2675,8 @@ RunnableLike = Union[
Runnable[Input, Output], Runnable[Input, Output],
Callable[[Input], Output], Callable[[Input], Output],
Callable[[Input], Awaitable[Output]], Callable[[Input], Awaitable[Output]],
Callable[[Iterator[Input]], Iterator[Output]],
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
Mapping[str, Any], Mapping[str, Any],
] ]
@ -2545,6 +2684,8 @@ RunnableLike = Union[
def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
if isinstance(thing, Runnable): if isinstance(thing, Runnable):
return thing return thing
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
return RunnableGenerator(thing)
elif callable(thing): elif callable(thing):
return RunnableLambda(thing) return RunnableLambda(thing)
elif isinstance(thing, dict): elif isinstance(thing, dict):

View File

@ -152,9 +152,9 @@ def call_func_with_variable_args(
input: Input, input: Input,
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any,
) -> Output: ) -> Output:
"""Call function that may optionally accept a run_manager and/or config.""" """Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func): if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func): if accepts_run_manager(func):
@ -174,9 +174,9 @@ async def acall_func_with_variable_args(
input: Input, input: Input,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any,
) -> Output: ) -> Output:
"""Call function that may optionally accept a run_manager and/or config.""" """Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func): if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func): if accepts_run_manager(func):