From a2cc9b55ba58e273c2f4303450534607652762f5 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 23 Apr 2024 10:35:06 -0400 Subject: [PATCH] core[patch]: Remove autoupgrade to addable dict in Runnable/RunnableLambda/RunnablePassthrough transform (#20677) Causes an issue for this code ```python from langchain.chat_models.openai import ChatOpenAI from langchain.output_parsers.openai_tools import JsonOutputToolsParser from langchain.schema import SystemMessage prompt = SystemMessage(content="You are a nice assistant.") + "{question}" llm = ChatOpenAI( model_kwargs={ "tools": [ { "type": "function", "function": { "name": "web_search", "description": "Searches the web for the answer to the question.", "parameters": { "type": "object", "properties": { "query": { "type": "string", "description": "The question to search for.", }, }, }, }, } ], }, streaming=True, ) parser = JsonOutputToolsParser(first_tool_only=True) llm_chain = prompt | llm | parser | (lambda x: x) for chunk in llm_chain.stream({"question": "tell me more about turtles"}): print(chunk) # message = llm_chain.invoke({"question": "tell me more about turtles"}) # print(message) ``` Instead by definition, we'll assume that RunnableLambdas consume the entire stream and that if the stream isn't addable then it's the last message of the stream that's in the usable format. --- If users want to use addable dicts, they can wrap the dict in an AddableDict class. --- Likely, need to follow up with the same change for other places in the code that do the upgrade --- libs/core/langchain_core/runnables/base.py | 67 +++++++++++-------- .../langchain_core/runnables/passthrough.py | 41 ++++++++---- libs/core/langchain_core/runnables/utils.py | 8 --- .../unit_tests/runnables/test_runnable.py | 34 ++++++++-- 4 files changed, 98 insertions(+), 52 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index a29a0677ee9..ee48234df3a 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -69,7 +69,6 @@ from langchain_core.runnables.utils import ( accepts_config, accepts_context, accepts_run_manager, - adapt_first_streaming_chunk, create_model, gather_with_concurrency, get_function_first_arg_dict_keys, @@ -1280,21 +1279,22 @@ class Runnable(Generic[Input, Output], ABC): final: Input got_first_val = False - for chunk in input: + for ichunk in input: + # The default implementation of transform is to buffer input and + # then call stream. + # It'll attempt to gather all input into a single chunk using + # the `+` operator. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk, + # and we'll iterate until we get to the last chunk. if not got_first_val: - final = adapt_first_streaming_chunk(chunk) # type: ignore + final = ichunk got_first_val = True else: - # Make a best effort to gather, for any type that supports `+` - # This method should throw an error if gathering fails. try: - final = final + chunk # type: ignore[operator] + final = final + ichunk # type: ignore[operator] except TypeError: - raise TypeError( - f"Failed while trying to add together " - f"type {type(final)} and {type(chunk)}." - f"These types should be addable for transform to work." - ) + final = ichunk if got_first_val: yield from self.stream(final, config, **kwargs) @@ -1313,21 +1313,22 @@ class Runnable(Generic[Input, Output], ABC): final: Input got_first_val = False - async for chunk in input: + async for ichunk in input: + # The default implementation of transform is to buffer input and + # then call stream. + # It'll attempt to gather all input into a single chunk using + # the `+` operator. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk, + # and we'll iterate until we get to the last chunk. if not got_first_val: - final = adapt_first_streaming_chunk(chunk) # type: ignore + final = ichunk got_first_val = True else: - # Make a best effort to gather, for any type that supports `+` - # This method should throw an error if gathering fails. try: - final = final + chunk # type: ignore[operator] + final = final + ichunk # type: ignore[operator] except TypeError: - raise TypeError( - f"Failed while trying to add together " - f"type {type(final)} and {type(chunk)}." - f"These types should be addable for atransform to work." - ) + final = ichunk if got_first_val: async for output in self.astream(final, config, **kwargs): @@ -3998,10 +3999,16 @@ class RunnableLambda(Runnable[Input, Output]): config: RunnableConfig, **kwargs: Any, ) -> Iterator[Output]: - final: Optional[Input] = None + final: Input + got_first_val = False for ichunk in input: - if final is None: - final = adapt_first_streaming_chunk(ichunk) # type: ignore + # By definitions, RunnableLambdas consume all input before emitting output. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk. + # So we'll iterate until we get to the last chunk! + if not got_first_val: + final = ichunk + got_first_val = True else: try: final = final + ichunk # type: ignore[operator] @@ -4082,10 +4089,16 @@ class RunnableLambda(Runnable[Input, Output]): config: RunnableConfig, **kwargs: Any, ) -> AsyncIterator[Output]: - final: Optional[Input] = None + final: Input + got_first_val = False async for ichunk in input: - if final is None: - final = adapt_first_streaming_chunk(ichunk) + # By definitions, RunnableLambdas consume all input before emitting output. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk. + # So we'll iterate until we get to the last chunk! + if not got_first_val: + final = ichunk + got_first_val = True else: try: final = final + ichunk # type: ignore[operator] diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index d2fbf30e4b9..ec081aea97f 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -40,7 +40,6 @@ from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import ( AddableDict, ConfigurableFieldSpec, - adapt_first_streaming_chunk, create_model, ) from langchain_core.utils.aiter import atee, py_anext @@ -243,16 +242,22 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): for chunk in self._transform_stream_with_config(input, identity, config): yield chunk else: - final = None + final: Other + got_first_chunk = False for chunk in self._transform_stream_with_config(input, identity, config): yield chunk - if final is None: - final = adapt_first_streaming_chunk(chunk) - else: - final = final + chunk - if final is not None: + if not got_first_chunk: + final = chunk + got_first_chunk = True + else: + try: + final = final + chunk # type: ignore[operator] + except TypeError: + final = chunk + + if got_first_chunk: call_func_with_variable_args( self.func, final, ensure_config(config), **kwargs ) @@ -269,18 +274,28 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): ): yield chunk else: - final = None + got_first_chunk = False async for chunk in self._atransform_stream_with_config( input, identity, config ): yield chunk - if final is None: - final = adapt_first_streaming_chunk(chunk) - else: - final = final + chunk - if final is not None: + # By definitions, a function will operate on the aggregated + # input. So we'll aggregate the input until we get to the last + # chunk. + # If the input is not addable, then we'll assume that we can + # only operate on the last chunk. + if not got_first_chunk: + final = chunk + got_first_chunk = True + else: + try: + final = final + chunk # type: ignore[operator] + except TypeError: + final = chunk + + if got_first_chunk: config = ensure_config(config) if self.afunc is not None: await acall_func_with_variable_args( diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index dff10ad0495..d5553e786f5 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -524,11 +524,3 @@ def _create_model_cached( return _create_model_base( __model_name, __config__=_SchemaConfig, **field_definitions ) - - -def adapt_first_streaming_chunk(chunk: Any) -> Any: - """This might transform the first chunk of a stream into an AddableDict.""" - if isinstance(chunk, dict) and not isinstance(chunk, AddableDict): - return AddableDict(chunk) - else: - return chunk diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index ca6d2a3adaf..07ebd37abf2 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -5401,11 +5401,21 @@ def test_transform_of_runnable_lambda_with_dicts() -> None: runnable = RunnableLambda(lambda x: x) chunks = iter( [ - {"foo": "a"}, {"foo": "n"}, ] ) - assert list(runnable.transform(chunks)) == [{"foo": "an"}] + assert list(runnable.transform(chunks)) == [{"foo": "n"}] + + # Test as part of a sequence + seq = runnable | runnable + chunks = iter( + [ + {"foo": "n"}, + ] + ) + assert list(seq.transform(chunks)) == [{"foo": "n"}] + # Test some other edge cases + assert list(seq.stream({"foo": "n"})) == [{"foo": "n"}] async def test_atransform_of_runnable_lambda_with_dicts() -> None: @@ -5420,7 +5430,11 @@ async def test_atransform_of_runnable_lambda_with_dicts() -> None: yield {"foo": "n"} chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] - assert chunks == [{"foo": "an"}] + assert chunks == [{"foo": "n"}] + + seq = runnable | runnable + chunks = [chunk async for chunk in seq.atransform(chunk_iterator())] + assert chunks == [{"foo": "n"}] def test_default_transform_with_dicts() -> None: @@ -5440,7 +5454,8 @@ def test_default_transform_with_dicts() -> None: ] ) - assert list(runnable.transform(chunks)) == [{"foo": "an"}] + assert list(runnable.transform(chunks)) == [{"foo": "n"}] + assert list(runnable.stream({"foo": "n"})) == [{"foo": "n"}] async def test_default_atransform_with_dicts() -> None: @@ -5460,6 +5475,17 @@ async def test_default_atransform_with_dicts() -> None: chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())] + assert chunks == [{"foo": "n"}] + + # Test with addable dict + async def chunk_iterator_with_addable() -> AsyncIterator[Dict[str, str]]: + yield AddableDict({"foo": "a"}) + yield AddableDict({"foo": "n"}) + + chunks = [ + chunk async for chunk in runnable.atransform(chunk_iterator_with_addable()) + ] + assert chunks == [{"foo": "an"}]