diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py
index 04e9dce5843..e0f27aee9dc 100644
--- a/libs/core/langchain_core/runnables/base.py
+++ b/libs/core/langchain_core/runnables/base.py
@@ -2775,6 +2775,187 @@ class RunnableLambda(Runnable[Input, Output]):
             **kwargs,
         )
 
+    def _transform(
+        self,
+        input: Iterator[Input],
+        run_manager: CallbackManagerForChainRun,
+        config: RunnableConfig,
+        **kwargs: Any,
+    ) -> Iterator[Output]:
+        """Aggregate the input chunks, call `self.func`, and yield its output.
+
+        Chunks are combined with `+`; if they cannot be added (`TypeError`),
+        the latest chunk replaces the accumulated value. When the function
+        returns a `Runnable`, that runnable is streamed with the aggregated
+        input and its chunks are yielded; any other return value is yielded
+        as a single chunk.
+
+        Raises:
+            RecursionError: If the function returns a `Runnable` and
+                `config["recursion_limit"]` is not positive.
+        """
+        final: Optional[Input] = None
+        for ichunk in input:
+            if final is None:
+                final = ichunk
+            else:
+                try:
+                    final = final + ichunk  # type: ignore[operator]
+                except TypeError:
+                    final = ichunk
+
+        output = call_func_with_variable_args(
+            self.func, cast(Input, final), config, run_manager, **kwargs
+        )
+
+        # If the output is a runnable, use its stream output
+        if isinstance(output, Runnable):
+            recursion_limit = config["recursion_limit"]
+            if recursion_limit <= 0:
+                raise RecursionError(
+                    f"Recursion limit reached when invoking "
+                    f"{self} with input {final}."
+                )
+            for chunk in output.stream(
+                final,
+                patch_config(
+                    config,
+                    callbacks=run_manager.get_child(),
+                    recursion_limit=recursion_limit - 1,
+                ),
+            ):
+                yield chunk
+        else:
+            # Otherwise, just yield it
+            yield output
+
+    def transform(
+        self,
+        input: Iterator[Input],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
+    ) -> Iterator[Output]:
+        """Stream the output of the wrapped sync function for an input iterator.
+
+        Raises:
+            TypeError: On first iteration, if this lambda has no sync `func`
+                (only a coroutine function was provided).
+        """
+        if hasattr(self, "func"):
+            for output in self._transform_stream_with_config(
+                input,
+                self._transform,
+                self._config(config, self.func),
+                **kwargs,
+            ):
+                yield output
+        else:
+            raise TypeError(
+                "Cannot stream a coroutine function synchronously. "
+                "Use `astream` instead."
+            )
+
+    def stream(
+        self,
+        input: Input,
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
+    ) -> Iterator[Output]:
+        """Stream the output for a single input by delegating to `transform`."""
+        return self.transform(iter([input]), config, **kwargs)
+
+    async def _atransform(
+        self,
+        input: AsyncIterator[Input],
+        run_manager: AsyncCallbackManagerForChainRun,
+        config: RunnableConfig,
+        **kwargs: Any,
+    ) -> AsyncIterator[Output]:
+        """Async counterpart of `_transform`.
+
+        Aggregates the input chunks, awaits `self.afunc` (or `self.func` run
+        in the default executor when no `afunc` is set), and yields the
+        result, streaming it via `astream` when it is a `Runnable`.
+
+        Raises:
+            RecursionError: If the function returns a `Runnable` and
+                `config["recursion_limit"]` is not positive.
+        """
+        final: Optional[Input] = None
+        async for ichunk in input:
+            if final is None:
+                final = ichunk
+            else:
+                try:
+                    final = final + ichunk  # type: ignore[operator]
+                except TypeError:
+                    final = ichunk
+
+        if hasattr(self, "afunc"):
+            afunc = self.afunc
+        else:
+
+            @wraps(self.func)
+            async def f(*args, **kwargs):  # type: ignore[no-untyped-def]
+                return await asyncio.get_running_loop().run_in_executor(
+                    None, partial(self.func, **kwargs), *args
+                )
+
+            afunc = f
+
+        output = await acall_func_with_variable_args(
+            afunc, cast(Input, final), config, run_manager, **kwargs
+        )
+
+        # If the output is a runnable, use its astream output
+        if isinstance(output, Runnable):
+            recursion_limit = config["recursion_limit"]
+            if recursion_limit <= 0:
+                raise RecursionError(
+                    f"Recursion limit reached when invoking "
+                    f"{self} with input {final}."
+                )
+            async for chunk in output.astream(
+                final,
+                patch_config(
+                    config,
+                    callbacks=run_manager.get_child(),
+                    recursion_limit=recursion_limit - 1,
+                ),
+            ):
+                yield chunk
+        else:
+            # Otherwise, just yield it
+            yield output
+
+    async def atransform(
+        self,
+        input: AsyncIterator[Input],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
+    ) -> AsyncIterator[Output]:
+        """Stream the output for an async input iterator via `_atransform`."""
+        async for output in self._atransform_stream_with_config(
+            input,
+            self._atransform,
+            self._config(config, self.afunc if hasattr(self, "afunc") else self.func),
+            **kwargs,
+        ):
+            yield output
+
+    async def astream(
+        self,
+        input: Input,
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
+    ) -> AsyncIterator[Output]:
+        """Stream the output for a single input by delegating to `atransform`."""
+        async def input_aiter() -> AsyncIterator[Input]:
+            yield input
+
+        async for chunk in self.atransform(input_aiter(), config, **kwargs):
+            yield chunk
+
 
 class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
     """
diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py
index a67b91efac8..4130ed31901 100644
--- a/libs/core/tests/unit_tests/runnables/test_runnable.py
+++ b/libs/core/tests/unit_tests/runnables/test_runnable.py
@@ -4,6 +4,8 @@ from operator import itemgetter
 from typing import (
     Any,
     AsyncIterator,
+    Awaitable,
+    Callable,
     Dict,
     Iterator,
     List,
@@ -3541,6 +3543,130 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
     _lambda_mock.reset_mock()
 
 
+def test_runnable_lambda_stream() -> None:
+    """Test that stream works for both normal functions & those returning Runnable."""
+    # Normal output should work
+    output: List[Any] = [chunk for chunk in RunnableLambda(range).stream(5)]
+    assert output == [range(5)]
+
+    # Runnable output should also work
+    llm_res = "i'm a textbot"
+    # sleep to better simulate a real stream
+    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
+
+    output = list(RunnableLambda(lambda x: llm).stream(""))
+    assert output == list(llm_res)
+
+
+def test_runnable_lambda_stream_with_callbacks() -> None:
+    """Test that stream works for RunnableLambda when using callbacks."""
+    tracer = FakeTracer()
+
+    llm_res = "i'm a textbot"
+    # sleep to better simulate a real stream
+    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
+    config: RunnableConfig = {"callbacks": [tracer]}
+
+    assert list(RunnableLambda(lambda x: llm).stream("", config=config)) == list(
+        llm_res
+    )
+
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].error is None
+    assert tracer.runs[0].outputs == {"output": llm_res}
+
+    def raise_value_error(x: int) -> int:
+        """Raise a value error."""
+        raise ValueError("x is too large")
+
+    # Check that the chain on error is invoked
+    with pytest.raises(ValueError):
+        for _ in RunnableLambda(raise_value_error).stream(1000, config=config):
+            pass
+
+    assert len(tracer.runs) == 2
+    assert "ValueError('x is too large')" in str(tracer.runs[1].error)
+    assert tracer.runs[1].outputs is None
+
+
+async def test_runnable_lambda_astream() -> None:
+    """Test that astream works for both normal functions & those returning Runnable."""
+
+    # Wrapper to make a normal function async
+    def awrapper(func: Callable) -> Callable[..., Awaitable[Any]]:
+        async def afunc(*args: Any, **kwargs: Any) -> Any:
+            return func(*args, **kwargs)
+
+        return afunc
+
+    # Normal output should work
+    output: List[Any] = [
+        chunk
+        async for chunk in RunnableLambda(
+            func=id,
+            afunc=awrapper(range),  # id func is just dummy
+        ).astream(5)
+    ]
+    assert output == [range(5)]
+
+    # Normal output using func should also work
+    output = [_ async for _ in RunnableLambda(range).astream(5)]
+    assert output == [range(5)]
+
+    # Runnable output should also work
+    llm_res = "i'm a textbot"
+    # sleep to better simulate a real stream
+    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
+
+    output = [
+        _
+        async for _ in RunnableLambda(
+            func=id,
+            afunc=awrapper(lambda x: llm),
+        ).astream("")
+    ]
+    assert output == list(llm_res)
+
+    output = [
+        chunk
+        async for chunk in cast(
+            AsyncIterator[str], RunnableLambda(lambda x: llm).astream("")
+        )
+    ]
+    assert output == list(llm_res)
+
+
+async def test_runnable_lambda_astream_with_callbacks() -> None:
+    """Test that astream works for RunnableLambda when using callbacks."""
+    tracer = FakeTracer()
+
+    llm_res = "i'm a textbot"
+    # sleep to better simulate a real stream
+    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
+    config: RunnableConfig = {"callbacks": [tracer]}
+
+    assert [
+        _ async for _ in RunnableLambda(lambda x: llm).astream("", config=config)
+    ] == list(llm_res)
+
+    assert len(tracer.runs) == 1
+    assert tracer.runs[0].error is None
+    assert tracer.runs[0].outputs == {"output": llm_res}
+
+    def raise_value_error(x: int) -> int:
+        """Raise a value error."""
+        raise ValueError("x is too large")
+
+    # Check that the chain on error is invoked
+    with pytest.raises(ValueError):
+        async for _ in RunnableLambda(raise_value_error).astream(1000, config=config):
+            pass
+
+    assert len(tracer.runs) == 2
+    assert "ValueError('x is too large')" in str(tracer.runs[1].error)
+    assert tracer.runs[1].outputs is None
+
+
 @freeze_time("2023-01-01")
 def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
     class ControlledExceptionRunnable(Runnable[str, str]):