mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 19:18:53 +00:00
Add RunnableGenerator
This commit is contained in:
parent
ca5293bf54
commit
b67db8deaa
@ -453,6 +453,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
@ -465,7 +466,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
output = call_func_with_variable_args(func, input, run_manager, config)
|
||||
output = call_func_with_variable_args(
|
||||
func, input, run_manager, config, **kwargs
|
||||
)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
@ -486,6 +489,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
||||
@ -499,7 +503,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
try:
|
||||
output = await acall_func_with_variable_args(
|
||||
func, input, run_manager, config
|
||||
func, input, run_manager, config, **kwargs
|
||||
)
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
@ -526,6 +530,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
run_type: Optional[str] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
@ -546,7 +551,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
]
|
||||
try:
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if accepts_config(func):
|
||||
kwargs["config"] = [
|
||||
patch_config(c, callbacks=rm.get_child())
|
||||
@ -597,6 +601,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
run_type: Optional[str] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
@ -619,7 +624,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
)
|
||||
try:
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if accepts_config(func):
|
||||
kwargs["config"] = [
|
||||
patch_config(c, callbacks=rm.get_child())
|
||||
@ -668,6 +672,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
],
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
"""Helper method to transform an Iterator of Input values into an Iterator of
|
||||
Output values, with callbacks.
|
||||
@ -689,7 +694,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if accepts_config(transformer):
|
||||
kwargs["config"] = patch_config(
|
||||
config, callbacks=run_manager.get_child()
|
||||
@ -746,6 +750,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
],
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
"""Helper method to transform an Async Iterator of Input values into an Async
|
||||
Iterator of Output values, with callbacks.
|
||||
@ -767,7 +772,6 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if accepts_config(transformer):
|
||||
kwargs["config"] = patch_config(
|
||||
config, callbacks=run_manager.get_child()
|
||||
@ -2061,6 +2065,139 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
yield chunk
|
||||
|
||||
|
||||
class RunnableGenerator(Runnable[Input, Output]):
|
||||
"""
|
||||
A runnable that runs a generator function.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transform: Union[
|
||||
Callable[[Iterator[Input]], Iterator[Output]],
|
||||
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
|
||||
],
|
||||
atransform: Optional[
|
||||
Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
|
||||
] = None,
|
||||
) -> None:
|
||||
if atransform is not None:
|
||||
self._atransform = atransform
|
||||
|
||||
if inspect.isasyncgenfunction(transform):
|
||||
self._atransform = transform
|
||||
elif inspect.isgeneratorfunction(transform):
|
||||
self._transform = transform
|
||||
else:
|
||||
raise TypeError(
|
||||
"Expected a generator function type for `transform`."
|
||||
f"Instead got an unsupported type: {type(transform)}"
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
|
||||
try:
|
||||
params = inspect.signature(func).parameters
|
||||
first_param = next(iter(params.values()), None)
|
||||
if first_param and first_param.annotation != inspect.Parameter.empty:
|
||||
return first_param.annotation
|
||||
else:
|
||||
return Any
|
||||
except ValueError:
|
||||
return Any
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
return (
|
||||
sig.return_annotation
|
||||
if sig.return_annotation != inspect.Signature.empty
|
||||
else Any
|
||||
)
|
||||
except ValueError:
|
||||
return Any
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, RunnableGenerator):
|
||||
if hasattr(self, "_transform") and hasattr(other, "_transform"):
|
||||
return self._transform == other._transform
|
||||
elif hasattr(self, "_atransform") and hasattr(other, "_atransform"):
|
||||
return self._atransform == other._atransform
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "RunnableGenerator(...)"
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Output]:
|
||||
return self._transform_stream_with_config(
|
||||
input, self._transform, config, **kwargs
|
||||
)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Output]:
|
||||
return self.transform(iter([input]), config, **kwargs)
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
final = None
|
||||
for output in self.stream(input, config, **kwargs):
|
||||
if final is None:
|
||||
final = output
|
||||
else:
|
||||
final += output
|
||||
return final
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[Input],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any | None,
|
||||
) -> AsyncIterator[Output]:
|
||||
if not hasattr(self, "_atransform"):
|
||||
raise NotImplementedError("This runnable does not support async methods.")
|
||||
|
||||
return self._atransform_stream_with_config(
|
||||
input, self._atransform, config, **kwargs
|
||||
)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any | None,
|
||||
) -> AsyncIterator[Output]:
|
||||
async def input_aiter() -> AsyncIterator[Input]:
|
||||
yield input
|
||||
|
||||
return self.atransform(input_aiter(), config, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
final = None
|
||||
async for output in self.astream(input, config):
|
||||
if final is None:
|
||||
final = output
|
||||
else:
|
||||
final += output
|
||||
return final
|
||||
|
||||
|
||||
class RunnableLambda(Runnable[Input, Output]):
|
||||
"""
|
||||
A runnable that runs a callable.
|
||||
@ -2538,6 +2675,8 @@ RunnableLike = Union[
|
||||
Runnable[Input, Output],
|
||||
Callable[[Input], Output],
|
||||
Callable[[Input], Awaitable[Output]],
|
||||
Callable[[Iterator[Input]], Iterator[Output]],
|
||||
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
|
||||
Mapping[str, Any],
|
||||
]
|
||||
|
||||
@ -2545,6 +2684,8 @@ RunnableLike = Union[
|
||||
def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
|
||||
if isinstance(thing, Runnable):
|
||||
return thing
|
||||
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
|
||||
return RunnableGenerator(thing)
|
||||
elif callable(thing):
|
||||
return RunnableLambda(thing)
|
||||
elif isinstance(thing, dict):
|
||||
|
@ -152,9 +152,9 @@ def call_func_with_variable_args(
|
||||
input: Input,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Output:
|
||||
"""Call function that may optionally accept a run_manager and/or config."""
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if accepts_config(func):
|
||||
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
|
||||
if accepts_run_manager(func):
|
||||
@ -174,9 +174,9 @@ async def acall_func_with_variable_args(
|
||||
input: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Output:
|
||||
"""Call function that may optionally accept a run_manager and/or config."""
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if accepts_config(func):
|
||||
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
|
||||
if accepts_run_manager(func):
|
||||
|
Loading…
Reference in New Issue
Block a user