diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index a29a0677ee9..ee48234df3a 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -69,7 +69,6 @@ from langchain_core.runnables.utils import ( accepts_config, accepts_context, accepts_run_manager, - adapt_first_streaming_chunk, create_model, gather_with_concurrency, get_function_first_arg_dict_keys, @@ -1280,21 +1279,22 @@ class Runnable(Generic[Input, Output], ABC): final: Input got_first_val = False - for chunk in input: + for ichunk in input: + # The default implementation of transform is to buffer input and + # then call stream. + # It'll attempt to gather all input into a single chunk using + # the `+` operator. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk, + # and we'll iterate until we get to the last chunk. if not got_first_val: - final = adapt_first_streaming_chunk(chunk) # type: ignore + final = ichunk got_first_val = True else: - # Make a best effort to gather, for any type that supports `+` - # This method should throw an error if gathering fails. try: - final = final + chunk # type: ignore[operator] + final = final + ichunk # type: ignore[operator] except TypeError: - raise TypeError( - f"Failed while trying to add together " - f"type {type(final)} and {type(chunk)}." - f"These types should be addable for transform to work." - ) + final = ichunk if got_first_val: yield from self.stream(final, config, **kwargs) @@ -1313,21 +1313,22 @@ class Runnable(Generic[Input, Output], ABC): final: Input got_first_val = False - async for chunk in input: + async for ichunk in input: + # The default implementation of transform is to buffer input and + # then call stream. + # It'll attempt to gather all input into a single chunk using + # the `+` operator. 
+ # If the input is not addable, then we'll assume that we can + # only operate on the last chunk, + # and we'll iterate until we get to the last chunk. if not got_first_val: - final = adapt_first_streaming_chunk(chunk) # type: ignore + final = ichunk got_first_val = True else: - # Make a best effort to gather, for any type that supports `+` - # This method should throw an error if gathering fails. try: - final = final + chunk # type: ignore[operator] + final = final + ichunk # type: ignore[operator] except TypeError: - raise TypeError( - f"Failed while trying to add together " - f"type {type(final)} and {type(chunk)}." - f"These types should be addable for atransform to work." - ) + final = ichunk if got_first_val: async for output in self.astream(final, config, **kwargs): @@ -3998,10 +3999,16 @@ class RunnableLambda(Runnable[Input, Output]): config: RunnableConfig, **kwargs: Any, ) -> Iterator[Output]: - final: Optional[Input] = None + final: Input + got_first_val = False for ichunk in input: - if final is None: - final = adapt_first_streaming_chunk(ichunk) # type: ignore + # By definition, RunnableLambdas consume all input before emitting output. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk. + # So we'll iterate until we get to the last chunk! + if not got_first_val: + final = ichunk + got_first_val = True else: try: final = final + ichunk # type: ignore[operator] @@ -4082,10 +4089,16 @@ class RunnableLambda(Runnable[Input, Output]): config: RunnableConfig, **kwargs: Any, ) -> AsyncIterator[Output]: - final: Optional[Input] = None + final: Input + got_first_val = False async for ichunk in input: - if final is None: - final = adapt_first_streaming_chunk(ichunk) + # By definition, RunnableLambdas consume all input before emitting output. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk. + # So we'll iterate until we get to the last chunk! 
+ if not got_first_val: + final = ichunk + got_first_val = True else: try: final = final + ichunk # type: ignore[operator] diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index d2fbf30e4b9..ec081aea97f 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -40,7 +40,6 @@ from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import ( AddableDict, ConfigurableFieldSpec, - adapt_first_streaming_chunk, create_model, ) from langchain_core.utils.aiter import atee, py_anext @@ -243,16 +242,22 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): for chunk in self._transform_stream_with_config(input, identity, config): yield chunk else: - final = None + final: Other + got_first_chunk = False for chunk in self._transform_stream_with_config(input, identity, config): yield chunk - if final is None: - final = adapt_first_streaming_chunk(chunk) - else: - final = final + chunk - if final is not None: + if not got_first_chunk: + final = chunk + got_first_chunk = True + else: + try: + final = final + chunk # type: ignore[operator] + except TypeError: + final = chunk + + if got_first_chunk: call_func_with_variable_args( self.func, final, ensure_config(config), **kwargs ) @@ -269,18 +274,28 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): ): yield chunk else: - final = None + got_first_chunk = False async for chunk in self._atransform_stream_with_config( input, identity, config ): yield chunk - if final is None: - final = adapt_first_streaming_chunk(chunk) - else: - final = final + chunk - if final is not None: + # By definition, a function will operate on the aggregated + # input. So we'll aggregate the input until we get to the last + # chunk. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk. 
+ if not got_first_chunk: + final = chunk + got_first_chunk = True + else: + try: + final = final + chunk # type: ignore[operator] + except TypeError: + final = chunk + + if got_first_chunk: config = ensure_config(config) if self.afunc is not None: await acall_func_with_variable_args( diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index dff10ad0495..d5553e786f5 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -524,11 +524,3 @@ def _create_model_cached( return _create_model_base( __model_name, __config__=_SchemaConfig, **field_definitions ) - - -def adapt_first_streaming_chunk(chunk: Any) -> Any: - """This might transform the first chunk of a stream into an AddableDict.""" - if isinstance(chunk, dict) and not isinstance(chunk, AddableDict): - return AddableDict(chunk) - else: - return chunk diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index ca6d2a3adaf..07ebd37abf2 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -5401,11 +5401,21 @@ def test_transform_of_runnable_lambda_with_dicts() -> None: runnable = RunnableLambda(lambda x: x) chunks = iter( [ - {"foo": "a"}, {"foo": "n"}, ] ) - assert list(runnable.transform(chunks)) == [{"foo": "an"}] + assert list(runnable.transform(chunks)) == [{"foo": "n"}] + + # Test as part of a sequence + seq = runnable | runnable + chunks = iter( + [ + {"foo": "n"}, + ] + ) + assert list(seq.transform(chunks)) == [{"foo": "n"}] + # Test some other edge cases + assert list(seq.stream({"foo": "n"})) == [{"foo": "n"}] async def test_atransform_of_runnable_lambda_with_dicts() -> None: @@ -5420,7 +5430,11 @@ async def test_atransform_of_runnable_lambda_with_dicts() -> None: yield {"foo": "n"} chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] - 
assert chunks == [{"foo": "an"}] + assert chunks == [{"foo": "n"}] + + seq = runnable | runnable + chunks = [chunk async for chunk in seq.atransform(chunk_iterator())] + assert chunks == [{"foo": "n"}] def test_default_transform_with_dicts() -> None: @@ -5440,7 +5454,8 @@ def test_default_transform_with_dicts() -> None: ] ) - assert list(runnable.transform(chunks)) == [{"foo": "an"}] + assert list(runnable.transform(chunks)) == [{"foo": "n"}] + assert list(runnable.stream({"foo": "n"})) == [{"foo": "n"}] async def test_default_atransform_with_dicts() -> None: @@ -5460,6 +5475,17 @@ async def test_default_atransform_with_dicts() -> None: chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] + assert chunks == [{"foo": "n"}] + + # Test with addable dict + async def chunk_iterator_with_addable() -> AsyncIterator[Dict[str, str]]: + yield AddableDict({"foo": "a"}) + yield AddableDict({"foo": "n"}) + + chunks = [ + chunk async for chunk in runnable.atransform(chunk_iterator_with_addable()) + ] + assert chunks == [{"foo": "an"}]