diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 558495173bc..9b69a9f7114 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -123,6 +123,7 @@ class Runnable(Generic[Input, Output], ABC): other: Union[ Runnable[Any, Other], Callable[[Any], Other], + Callable[[Iterator[Any]], Iterator[Other]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], ) -> RunnableSequence[Input, Other]: @@ -132,7 +133,8 @@ class Runnable(Generic[Input, Output], ABC): self, other: Union[ Runnable[Other, Any], - Callable[[Any], Other], + Callable[[Other], Any], + Callable[[Iterator[Other]], Iterator[Any]], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], ], ) -> RunnableSequence[Other, Output]: @@ -353,7 +355,7 @@ class Runnable(Generic[Input, Output], ABC): else: # Make a best effort to gather, for any type that supports `+` # This method should throw an error if gathering fails. - final += chunk # type: ignore[operator] + final = final + chunk # type: ignore[operator] if got_first_val: yield from self.stream(final, config, **kwargs) @@ -379,7 +381,7 @@ class Runnable(Generic[Input, Output], ABC): else: # Make a best effort to gather, for any type that supports `+` # This method should throw an error if gathering fails. - final += chunk # type: ignore[operator] + final = final + chunk # type: ignore[operator] if got_first_val: async for output in self.astream(final, config, **kwargs): @@ -710,7 +712,7 @@ class Runnable(Generic[Input, Output], ABC): final_output = chunk else: try: - final_output += chunk # type: ignore[operator] + final_output = final_output + chunk # type: ignore except TypeError: final_output = None final_output_supported = False @@ -720,7 +722,7 @@ class Runnable(Generic[Input, Output], ABC): final_input = ichunk else: try: - final_input += ichunk # type: ignore[operator] + final_input = final_input + ichunk # type: ignore except TypeError: final_input = None final_input_supported = False @@ -788,7 +790,7 @@ class Runnable(Generic[Input, Output], ABC): final_output = chunk else: try: - final_output += chunk # type: ignore[operator] + final_output = final_output + chunk # type: ignore except TypeError: final_output = None final_output_supported = False @@ -798,7 +800,7 @@ class Runnable(Generic[Input, Output], ABC): final_input = ichunk else: try: - final_input += ichunk # type: ignore[operator] + final_input = final_input + ichunk # type: ignore[operator] except TypeError: final_input = None final_input_supported = False @@ -1315,6 +1317,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): other: Union[ Runnable[Any, Other], Callable[[Any], Other], + Callable[[Iterator[Any]], Iterator[Other]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], ) -> RunnableSequence[Input, Other]: @@ -1335,7 +1338,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, other: Union[ Runnable[Other, Any], - Callable[[Any], Other], + Callable[[Other], Any], + Callable[[Iterator[Other]], Iterator[Any]], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], ], ) -> RunnableSequence[Other, Output]: @@ -1755,7 +1759,7 @@ class RunnableMapChunk(Dict[str, Any]): if key not in chunk or chunk[key] is None: chunk[key] = other[key] elif other[key] is not None: - chunk[key] += other[key] + chunk[key] = chunk[key] + other[key] return chunk def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk: @@ -1764,7 +1768,7 @@ class RunnableMapChunk(Dict[str, Any]): if key not in chunk or chunk[key] is None: chunk[key] = self[key] elif self[key] is not None: - chunk[key] += self[key] + chunk[key] = chunk[key] + self[key] return chunk @@ -2107,7 +2111,7 @@ class RunnableGenerator(Runnable[Input, Output]): return Any @property - def OutputType(self) -> Type[Output]: + def OutputType(self) -> Any: func = getattr(self, "_transform", None) or getattr(self, "_atransform") try: sig = inspect.signature(func) @@ -2137,7 +2141,7 @@ class RunnableGenerator(Runnable[Input, Output]): self, input: Iterator[Input], config: Optional[RunnableConfig] = None, - **kwargs: Any | None, + **kwargs: Any, ) -> Iterator[Output]: return self._transform_stream_with_config( input, self._transform, config, **kwargs @@ -2147,7 +2151,7 @@ class RunnableGenerator(Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None, - **kwargs: Any | None, + **kwargs: Any, ) -> Iterator[Output]: return self.transform(iter([input]), config, **kwargs) @@ -2159,14 +2163,14 @@ class RunnableGenerator(Runnable[Input, Output]): if final is None: final = output else: - final += output - return final + final = final + output + return cast(Output, final) def atransform( self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None, - **kwargs: Any | None, + **kwargs: Any, ) -> AsyncIterator[Output]: if not hasattr(self, "_atransform"): raise NotImplementedError("This runnable does not support async methods.") @@ -2179,7 +2183,7 @@ class RunnableGenerator(Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None, - **kwargs: Any | None, + **kwargs: Any, ) -> AsyncIterator[Output]: async def input_aiter() -> AsyncIterator[Input]: yield input @@ -2187,15 +2191,15 @@ class RunnableGenerator(Runnable[Input, Output]): return self.atransform(input_aiter(), config, **kwargs) async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: final = None - async for output in self.astream(input, config): + async for output in self.astream(input, config, **kwargs): if final is None: final = output else: - final += output - return final + final = final + output + return cast(Output, final) class RunnableLambda(Runnable[Input, Output]): @@ -2687,7 +2691,7 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing): return RunnableGenerator(thing) elif callable(thing): - return RunnableLambda(thing) + return RunnableLambda(cast(Callable[[Input], Output], thing)) elif isinstance(thing, dict): runnables: Mapping[str, Runnable[Any, Any]] = { key: coerce_to_runnable(r) for key, r in thing.items() diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index 6a43e61d69d..f697c0328c9 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -15,14 +15,7 @@ from typing import ( from typing_extensions import TypedDict from langchain.load.serializable import Serializable -from langchain.schema.runnable.base import ( - Input, - Other, - Output, - Runnable, - RunnableSequence, - coerce_to_runnable, -) +from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable from langchain.schema.runnable.config import ( RunnableConfig, get_config_list, @@ -71,28 +64,6 @@ class RouterRunnable(Serializable, Runnable[RouterInput, Output]): def get_lc_namespace(cls) -> List[str]: return cls.__module__.split(".")[:-1] - def __or__( - self, - other: Union[ - Runnable[Any, Other], - Callable[[Any], Other], - Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]], - Mapping[str, Any], - ], - ) -> RunnableSequence[RouterInput, Other]: - return RunnableSequence(first=self, last=coerce_to_runnable(other)) - - def __ror__( - self, - other: Union[ - Runnable[Other, Any], - Callable[[Any], Other], - Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]], - Mapping[str, Any], - ], - ) -> RunnableSequence[Other, Output]: - return RunnableSequence(first=coerce_to_runnable(other), last=self) - def invoke( self, input: RouterInput, config: Optional[RunnableConfig] = None ) -> Output: diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 316a9ecad66..4a63f92ff2a 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -12,7 +12,6 @@ from typing import ( cast, ) from uuid import UUID -from langchain.schema.runnable.base import RunnableGenerator import pytest from freezegun import freeze_time @@ -57,6 +56,7 @@ from langchain.schema.runnable import ( RunnableSequence, RunnableWithFallbacks, ) +from langchain.schema.runnable.base import RunnableGenerator from langchain.tools.base import BaseTool, tool from langchain.tools.json.tool import JsonListKeysTool, JsonSpec @@ -2876,7 +2876,7 @@ async def test_runnable_gen_transform() -> None: async for i in input: yield i + 1 - chain = RunnableGenerator(gen_indexes, agen_indexes) | plus_one + chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one assert chain.input_schema.schema() == {