diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index e41ef5877d7..c47018f05bd 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -69,6 +69,7 @@ from langchain_core.runnables.utils import ( accepts_config, accepts_context, accepts_run_manager, + adapt_first_streaming_chunk, create_model, gather_with_concurrency, get_function_first_arg_dict_keys, @@ -1207,7 +1208,7 @@ class Runnable(Generic[Input, Output], ABC): for chunk in input: if not got_first_val: - final = _adapt_first_streaming_chunk(chunk) # type: ignore + final = adapt_first_streaming_chunk(chunk) # type: ignore got_first_val = True else: # Make a best effort to gather, for any type that supports `+` @@ -1240,7 +1241,7 @@ class Runnable(Generic[Input, Output], ABC): async for chunk in input: if not got_first_val: - final = _adapt_first_streaming_chunk(chunk) # type: ignore + final = adapt_first_streaming_chunk(chunk) # type: ignore got_first_val = True else: # Make a best effort to gather, for any type that supports `+` @@ -3731,7 +3732,7 @@ class RunnableLambda(Runnable[Input, Output]): final: Optional[Input] = None for ichunk in input: if final is None: - final = _adapt_first_streaming_chunk(ichunk) # type: ignore + final = adapt_first_streaming_chunk(ichunk) # type: ignore else: try: final = final + ichunk # type: ignore[operator] @@ -3815,7 +3816,7 @@ class RunnableLambda(Runnable[Input, Output]): final: Optional[Input] = None async for ichunk in input: if final is None: - final = _adapt_first_streaming_chunk(ichunk) + final = adapt_first_streaming_chunk(ichunk) else: try: final = final + ichunk # type: ignore[operator] @@ -4727,11 +4728,3 @@ def chain( yield chunk """ return RunnableLambda(func) - - -def _adapt_first_streaming_chunk(chunk: Any) -> Any: - """This might transform the first chunk of a stream into an AddableDict.""" - if isinstance(chunk, dict) and not isinstance(chunk, AddableDict): - return AddableDict(chunk) - else: - return chunk diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 0f171795114..f3d4e5f78d5 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -40,6 +40,7 @@ from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import ( AddableDict, ConfigurableFieldSpec, + adapt_first_streaming_chunk, create_model, ) from langchain_core.utils.aiter import atee, py_anext @@ -248,7 +249,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): for chunk in self._transform_stream_with_config(input, identity, config): yield chunk if final is None: - final = chunk + final = adapt_first_streaming_chunk(chunk) else: final = final + chunk @@ -276,7 +277,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): ): yield chunk if final is None: - final = chunk + final = adapt_first_streaming_chunk(chunk) else: final = final + chunk diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 8ff332173cc..40f52c18167 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -521,3 +521,11 @@ def _create_model_cached( return _create_model_base( __model_name, __config__=_SchemaConfig, **field_definitions ) + + +def adapt_first_streaming_chunk(chunk: Any) -> Any: + """This might transform the first chunk of a stream into an AddableDict.""" + if isinstance(chunk, dict) and not isinstance(chunk, AddableDict): + return AddableDict(chunk) + else: + return chunk diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index fc63abd9af7..91cde649e77 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -5324,7 +5324,7 @@ def test_default_transform_with_dicts() -> None: assert list(runnable.transform(chunks)) == [{"foo": "an"}] -async def test_defualt_atransform_with_dicts() -> None: +async def test_default_atransform_with_dicts() -> None: """Test that default transform works with dicts.""" class CustomRunnable(RunnableSerializable[Input, Output]): @@ -5342,3 +5342,22 @@ async def test_defualt_atransform_with_dicts() -> None: chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] assert chunks == [{"foo": "an"}] + + +def test_passthrough_transform_with_dicts() -> None: + """Test that default transform works with dicts.""" + runnable = RunnablePassthrough(lambda x: x) + chunks = [chunk for chunk in runnable.transform(iter([{"foo": "a"}, {"foo": "n"}]))] + assert chunks == [{"foo": "a"}, {"foo": "n"}] + + +async def test_passthrough_atransform_with_dicts() -> None: + """Test that default transform works with dicts.""" + runnable = RunnablePassthrough(lambda x: x) + + async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: + yield {"foo": "a"} + yield {"foo": "n"} + + chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] + assert chunks == [{"foo": "a"}, {"foo": "n"}]