diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 72dd8a3eaa7..72975d6acc0 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -1050,12 +1050,19 @@ class Runnable(Generic[Input, Output], ABC): for chunk in input: if not got_first_val: - final = chunk + final = _adapt_first_streaming_chunk(chunk) # type: ignore got_first_val = True else: # Make a best effort to gather, for any type that supports `+` # This method should throw an error if gathering fails. - final = final + chunk # type: ignore[operator] + try: + final = final + chunk # type: ignore[operator] + except TypeError: + raise TypeError( + f"Failed while trying to add together " + f"type {type(final)} and {type(chunk)}." + f"These types should be addable for transform to work." + ) if got_first_val: yield from self.stream(final, config, **kwargs) @@ -1076,12 +1083,19 @@ class Runnable(Generic[Input, Output], ABC): async for chunk in input: if not got_first_val: - final = chunk + final = _adapt_first_streaming_chunk(chunk) # type: ignore got_first_val = True else: # Make a best effort to gather, for any type that supports `+` # This method should throw an error if gathering fails. - final = final + chunk # type: ignore[operator] + try: + final = final + chunk # type: ignore[operator] + except TypeError: + raise TypeError( + f"Failed while trying to add together " + f"type {type(final)} and {type(chunk)}." + f"These types should be addable for atransform to work." + ) if got_first_val: async for output in self.astream(final, config, **kwargs): @@ -3560,7 +3574,7 @@ class RunnableLambda(Runnable[Input, Output]): final: Optional[Input] = None for ichunk in input: if final is None: - final = ichunk + final = _adapt_first_streaming_chunk(ichunk) # type: ignore else: try: final = final + ichunk # type: ignore[operator] @@ -3644,7 +3658,7 @@ class RunnableLambda(Runnable[Input, Output]): final: Optional[Input] = None async for ichunk in input: if final is None: - final = ichunk + final = _adapt_first_streaming_chunk(ichunk) else: try: final = final + ichunk # type: ignore[operator] @@ -4445,3 +4459,11 @@ def chain( yield chunk """ return RunnableLambda(func) + + +def _adapt_first_streaming_chunk(chunk: Any) -> Any: + """This might transform the first chunk of a stream into an AddableDict.""" + if isinstance(chunk, dict) and not isinstance(chunk, AddableDict): + return AddableDict(chunk) + else: + return chunk diff --git a/libs/core/tests/unit_tests/fake/chat_model.py b/libs/core/tests/unit_tests/fake/chat_model.py index ec03c431e32..b1b362892eb 100644 --- a/libs/core/tests/unit_tests/fake/chat_model.py +++ b/libs/core/tests/unit_tests/fake/chat_model.py @@ -165,7 +165,7 @@ class GenericFakeChatModel(BaseChatModel): streaming. """ - messages: Iterator[AIMessage] + messages: Iterator[Union[AIMessage, str]] """Get an iterator over messages. This can be expanded to accept other types like Callables / dicts / strings @@ -187,7 +187,11 @@ class GenericFakeChatModel(BaseChatModel): ) -> ChatResult: """Top Level call""" message = next(self.messages) - generation = ChatGeneration(message=message) + if isinstance(message, str): + message_ = AIMessage(content=message) + else: + message_ = message + generation = ChatGeneration(message=message_) return ChatResult(generations=[generation]) def _stream( diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 696a70520b3..6daf18d17cd 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -70,6 +70,7 @@ from langchain_core.runnables import ( chain, ) from langchain_core.runnables.base import RunnableSerializable +from langchain_core.runnables.utils import Input, Output from langchain_core.tools import BaseTool, tool from langchain_core.tracers import ( BaseTracer, @@ -5183,3 +5184,70 @@ async def test_astream_log_deep_copies() -> None: "name": "add_one", "type": "chain", } + + +def test_transform_of_runnable_lambda_with_dicts() -> None: + """Test transform of runnable lamdbda.""" + runnable = RunnableLambda(lambda x: x) + chunks = iter( + [ + {"foo": "a"}, + {"foo": "n"}, + ] + ) + assert list(runnable.transform(chunks)) == [{"foo": "an"}] + + +async def test_atransform_of_runnable_lambda_with_dicts() -> None: + async def identity(x: Dict[str, str]) -> Dict[str, str]: + """Return x.""" + return x + + runnable = RunnableLambda[Dict[str, str], Dict[str, str]](identity) + + async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: + yield {"foo": "a"} + yield {"foo": "n"} + + chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] + assert chunks == [{"foo": "an"}] + + +def test_default_transform_with_dicts() -> None: + """Test that default transform works with dicts.""" + + class CustomRunnable(RunnableSerializable[Input, Output]): + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Output: + return cast(Output, input) # type: ignore + + runnable = CustomRunnable[Dict[str, str], Dict[str, str]]() + chunks = iter( + [ + {"foo": "a"}, + {"foo": "n"}, + ] + ) + + assert list(runnable.transform(chunks)) == [{"foo": "an"}] + + +async def test_defualt_atransform_with_dicts() -> None: + """Test that default transform works with dicts.""" + + class CustomRunnable(RunnableSerializable[Input, Output]): + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Output: + return cast(Output, input) + + runnable = CustomRunnable[Dict[str, str], Dict[str, str]]() + + async def chunk_iterator() -> AsyncIterator[Dict[str, str]]: + yield {"foo": "a"} + yield {"foo": "n"} + + chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] + + assert chunks == [{"foo": "an"}]