diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py
index 51ccc58f9fb..71a80b20ed3 100644
--- a/libs/langchain/langchain/schema/runnable/base.py
+++ b/libs/langchain/langchain/schema/runnable/base.py
@@ -1231,11 +1231,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             else:
                 raise first_exception
 
-    def stream(
+    def _transform(
         self,
-        input: Input,
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        input: Iterator[Input],
+        run_manager: CallbackManagerForChainRun,
+        config: RunnableConfig,
     ) -> Iterator[Output]:
         # setup callbacks
         config = ensure_config(config)
@@ -1254,37 +1254,50 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
             else:
                 break
 
-        # invoke the first steps
-        try:
-            for step in steps[0:streaming_start_index]:
-                input = step.invoke(
-                    input,
-                    # mark each step as a child run
+        final_pipeline = None
+        gathered_input = None
+        if streaming_start_index == 0:
+            final_pipeline = steps[streaming_start_index].transform(
+                input,
+                patch_config(config, callbacks=run_manager.get_child("seq:step:1")),
+            )
+        else:
+            try:
+                for input_chunk in input:
+                    if gathered_input is None:
+                        gathered_input = input_chunk
+                    else:
+                        gathered_input += input_chunk
+                # invoke the first steps
+                for step in steps[0:streaming_start_index]:
+                    gathered_input = step.invoke(
+                        gathered_input,
+                        # mark each step as a child run
+                        patch_config(
+                            config,
+                            callbacks=run_manager.get_child(
+                                f"seq:step:{steps.index(step)+1}"
+                            ),
+                        ),
+                    )
+                # stream the first of the last steps with the final non-streaming input
+                final_pipeline = steps[streaming_start_index].stream(
+                    gathered_input,
                     patch_config(
                         config,
                         callbacks=run_manager.get_child(
-                            f"seq:step:{steps.index(step)+1}"
+                            f"seq:step:{streaming_start_index+1}"
                         ),
                     ),
                 )
-        except (KeyboardInterrupt, Exception) as e:
-            run_manager.on_chain_error(e)
-            raise
+            except (KeyboardInterrupt, Exception) as e:
+                run_manager.on_chain_error(e)
+                raise
 
         # stream the last steps
         final: Union[Output, None] = None
         final_supported = True
         try:
-            # stream the first of the last steps with non-streaming input
-            final_pipeline = steps[streaming_start_index].stream(
-                input,
-                patch_config(
-                    config,
-                    callbacks=run_manager.get_child(
-                        f"seq:step:{streaming_start_index+1}"
-                    ),
-                ),
-            )
             # stream the rest of the last steps with streaming input
             for step in steps[streaming_start_index + 1 :]:
                 final_pipeline = step.transform(
@@ -1296,6 +1309,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
                         ),
                     ),
                 )
+
             for output in final_pipeline:
                 yield output
                 # Accumulate output if possible, otherwise disable accumulation
@@ -1316,11 +1330,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         else:
             run_manager.on_chain_end(final)
 
-    async def astream(
+    async def _atransform(
         self,
-        input: Input,
-        config: Optional[RunnableConfig] = None,
-        **kwargs: Optional[Any],
+        input: AsyncIterator[Input],
+        run_manager: AsyncCallbackManagerForChainRun,
+        config: RunnableConfig,
     ) -> AsyncIterator[Output]:
         # setup callbacks
         config = ensure_config(config)
@@ -1334,42 +1348,55 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         streaming_start_index = len(steps) - 1
 
         for i in range(len(steps) - 1, 0, -1):
-            if type(steps[i]).transform != Runnable.transform:
+            if type(steps[i]).atransform != Runnable.atransform:
                 streaming_start_index = i - 1
             else:
                 break
 
-        # invoke the first steps
-        try:
-            for step in steps[0:streaming_start_index]:
-                input = await step.ainvoke(
-                    input,
-                    # mark each step as a child run
+        final_pipeline = None
+        gathered_input = None
+        if streaming_start_index == 0:
+            final_pipeline = steps[0].atransform(
+                input,
+                patch_config(config, callbacks=run_manager.get_child("seq:step:1")),
+            )
+        else:
+            try:
+                async for input_chunk in input:
+                    if gathered_input is None:
+                        gathered_input = input_chunk
+                    else:
+                        gathered_input += input_chunk
+                # invoke the first steps
+                for step in steps[0:streaming_start_index]:
+                    gathered_input = await step.ainvoke(
+                        gathered_input,
+                        # mark each step as a child run
+                        patch_config(
+                            config,
+                            callbacks=run_manager.get_child(
+                                f"seq:step:{steps.index(step)+1}"
+                            ),
+                        ),
+                    )
+                # stream the first of the last steps with the final non-streaming input
+                final_pipeline = steps[streaming_start_index].astream(
+                    gathered_input,
                     patch_config(
                         config,
                         callbacks=run_manager.get_child(
-                            f"seq:step:{steps.index(step)+1}"
+                            f"seq:step:{streaming_start_index+1}"
                         ),
                     ),
                 )
-        except (KeyboardInterrupt, Exception) as e:
-            await run_manager.on_chain_error(e)
-            raise
+            except (KeyboardInterrupt, Exception) as e:
+                await run_manager.on_chain_error(e)
+                raise
 
         # stream the last steps
         final: Union[Output, None] = None
         final_supported = True
         try:
-            # stream the first of the last steps with non-streaming input
-            final_pipeline = steps[streaming_start_index].astream(
-                input,
-                patch_config(
-                    config,
-                    callbacks=run_manager.get_child(
-                        f"seq:step:{streaming_start_index+1}"
-                    ),
-                ),
-            )
             # stream the rest of the last steps with streaming input
             for step in steps[streaming_start_index + 1 :]:
                 final_pipeline = step.atransform(
@@ -1401,6 +1428,47 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
         else:
             await run_manager.on_chain_end(final)
 
+    def transform(
+        self,
+        input: Iterator[Input],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
+    ) -> Iterator[Output]:
+        yield from self._transform_stream_with_config(
+            input, self._transform, config, **kwargs
+        )
+
+    def stream(
+        self,
+        input: Input,
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
+    ) -> Iterator[Output]:
+        yield from self.transform(iter([input]), config, **kwargs)
+
+    async def atransform(
+        self,
+        input: AsyncIterator[Input],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
+    ) -> AsyncIterator[Output]:
+        async for chunk in self._atransform_stream_with_config(
+            input, self._atransform, config, **kwargs
+        ):
+            yield chunk
+
+    async def astream(
+        self,
+        input: Input,
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Optional[Any],
+    ) -> AsyncIterator[Output]:
+        async def input_aiter() -> AsyncIterator[Input]:
+            yield input
+
+        async for chunk in self.atransform(input_aiter(), config, **kwargs):
+            yield chunk
+
 
 class RunnableMapChunk(Dict[str, Any]):
     """
diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
index 98bf284fd05..9b1275e95bf 100644
--- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
+++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
@@ -1315,6 +1315,37 @@ async def test_deep_astream() -> None:
     assert "".join(chunks) == "foo-lish"
 
 
+def test_runnable_sequence_transform() -> None:
+    llm = FakeStreamingListLLM(responses=["foo-lish"])
+
+    chain = llm | StrOutputParser()
+
+    stream = chain.transform(llm.stream("Hi there!"))
+
+    chunks = []
+    for chunk in stream:
+        chunks.append(chunk)
+
+    assert len(chunks) == len("foo-lish")
+    assert "".join(chunks) == "foo-lish"
+
+
+@pytest.mark.asyncio
+async def test_runnable_sequence_atransform() -> None:
+    llm = FakeStreamingListLLM(responses=["foo-lish"])
+
+    chain = llm | StrOutputParser()
+
+    stream = chain.atransform(llm.astream("Hi there!"))
+
+    chunks = []
+    async for chunk in stream:
+        chunks.append(chunk)
+
+    assert len(chunks) == len("foo-lish")
+    assert "".join(chunks) == "foo-lish"
+
+
 @pytest.fixture()
 def llm_with_fallbacks() -> RunnableWithFallbacks:
     error_llm = FakeListLLM(responses=["foo"], i=1)