core[patch]: Remove autoupgrade to addable dict in Runnable/RunnableLambda/RunnablePassthrough transform (#20677)

The autoupgrade causes an issue for this code:

```python
from langchain.chat_models.openai import ChatOpenAI
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
from langchain.schema import SystemMessage

prompt = SystemMessage(content="You are a nice assistant.") + "{question}"

llm = ChatOpenAI(
    model_kwargs={
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "web_search",
                    "description": "Searches the web for the answer to the question.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "The question to search for.",
                            },
                        },
                    },
                },
            }
        ],
    },
    streaming=True,
)

parser = JsonOutputToolsParser(first_tool_only=True)

llm_chain = prompt | llm | parser | (lambda x: x)


for chunk in llm_chain.stream({"question": "tell me more about turtles"}):
    print(chunk)

# message = llm_chain.invoke({"question": "tell me more about turtles"})

# print(message)
```

Previously, `transform` silently upgraded plain dicts in the stream to `AddableDict` and summed them, which mangled outputs like the parser's in the example above. Instead, by definition, we'll assume that RunnableLambdas consume the entire stream, and that if the stream isn't addable, then it's the last chunk of the stream that's in the usable format.
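
To make the new semantics concrete, here's a minimal standalone sketch of the gathering logic (illustrative only, not the verbatim library code):

```python
from typing import Any, Iterable


def gather(chunks: Iterable[Any]) -> Any:
    """Fold a stream into one value, mirroring the new transform behavior."""
    final = None
    got_first_val = False
    for ichunk in chunks:
        if not got_first_val:
            final = ichunk
            got_first_val = True
        else:
            try:
                # Addable types (str, message chunks, AddableDict) accumulate.
                final = final + ichunk
            except TypeError:
                # Non-addable types (e.g. plain dicts): keep only the last chunk.
                final = ichunk
    return final


assert gather(["tur", "tles"]) == "turtles"  # addable: gathered
assert gather([{"foo": "a"}, {"foo": "n"}]) == {"foo": "n"}  # not addable: last wins
```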

---

If users want addable dicts, they can wrap their dicts in the AddableDict class.
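
For example (`AddableDict` is importable from `langchain_core.runnables.utils`):

```python
from langchain_core.runnables.utils import AddableDict

# AddableDict defines `+` as a per-key merge, so wrapped chunks are
# gathered across the stream instead of being overwritten:
assert AddableDict({"foo": "a"}) + AddableDict({"foo": "n"}) == {"foo": "an"}
```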

---

We'll likely need to follow up with the same change in the other places in the code that perform this upgrade.
Eugene Yurtsev 2024-04-23 10:35:06 -04:00 committed by GitHub
parent 9428923bab
commit a2cc9b55ba
4 changed files with 98 additions and 52 deletions

In `langchain_core/runnables/base.py`:

```diff
@@ -69,7 +69,6 @@ from langchain_core.runnables.utils import (
     accepts_config,
     accepts_context,
     accepts_run_manager,
-    adapt_first_streaming_chunk,
     create_model,
     gather_with_concurrency,
     get_function_first_arg_dict_keys,
@@ -1280,21 +1279,22 @@ class Runnable(Generic[Input, Output], ABC):
         final: Input
         got_first_val = False

-        for chunk in input:
+        for ichunk in input:
+            # The default implementation of transform is to buffer input and
+            # then call stream.
+            # It'll attempt to gather all input into a single chunk using
+            # the `+` operator.
+            # If the input is not addable, then we'll assume that we can
+            # only operate on the last chunk,
+            # and we'll iterate until we get to the last chunk.
             if not got_first_val:
-                final = adapt_first_streaming_chunk(chunk)  # type: ignore
+                final = ichunk
                 got_first_val = True
             else:
                 # Make a best effort to gather, for any type that supports `+`
                 # This method should throw an error if gathering fails.
                 try:
-                    final = final + chunk  # type: ignore[operator]
+                    final = final + ichunk  # type: ignore[operator]
                 except TypeError:
-                    raise TypeError(
-                        f"Failed while trying to add together "
-                        f"type {type(final)} and {type(chunk)}."
-                        f"These types should be addable for transform to work."
-                    )
+                    final = ichunk

         if got_first_val:
             yield from self.stream(final, config, **kwargs)
@@ -1313,21 +1313,22 @@ class Runnable(Generic[Input, Output], ABC):
         final: Input
         got_first_val = False

-        async for chunk in input:
+        async for ichunk in input:
+            # The default implementation of transform is to buffer input and
+            # then call stream.
+            # It'll attempt to gather all input into a single chunk using
+            # the `+` operator.
+            # If the input is not addable, then we'll assume that we can
+            # only operate on the last chunk,
+            # and we'll iterate until we get to the last chunk.
             if not got_first_val:
-                final = adapt_first_streaming_chunk(chunk)  # type: ignore
+                final = ichunk
                 got_first_val = True
             else:
                 # Make a best effort to gather, for any type that supports `+`
                 # This method should throw an error if gathering fails.
                 try:
-                    final = final + chunk  # type: ignore[operator]
+                    final = final + ichunk  # type: ignore[operator]
                 except TypeError:
-                    raise TypeError(
-                        f"Failed while trying to add together "
-                        f"type {type(final)} and {type(chunk)}."
-                        f"These types should be addable for atransform to work."
-                    )
+                    final = ichunk

         if got_first_val:
             async for output in self.astream(final, config, **kwargs):
@@ -3998,10 +3999,16 @@ class RunnableLambda(Runnable[Input, Output]):
         config: RunnableConfig,
         **kwargs: Any,
     ) -> Iterator[Output]:
-        final: Optional[Input] = None
+        final: Input
+        got_first_val = False
         for ichunk in input:
-            if final is None:
-                final = adapt_first_streaming_chunk(ichunk)  # type: ignore
+            # By definitions, RunnableLambdas consume all input before emitting output.
+            # If the input is not addable, then we'll assume that we can
+            # only operate on the last chunk.
+            # So we'll iterate until we get to the last chunk!
+            if not got_first_val:
+                final = ichunk
+                got_first_val = True
             else:
                 try:
                     final = final + ichunk  # type: ignore[operator]
@@ -4082,10 +4089,16 @@ class RunnableLambda(Runnable[Input, Output]):
         config: RunnableConfig,
         **kwargs: Any,
     ) -> AsyncIterator[Output]:
-        final: Optional[Input] = None
+        final: Input
+        got_first_val = False
         async for ichunk in input:
-            if final is None:
-                final = adapt_first_streaming_chunk(ichunk)
+            # By definitions, RunnableLambdas consume all input before emitting output.
+            # If the input is not addable, then we'll assume that we can
+            # only operate on the last chunk.
+            # So we'll iterate until we get to the last chunk!
+            if not got_first_val:
+                final = ichunk
+                got_first_val = True
             else:
                 try:
                     final = final + ichunk  # type: ignore[operator]
```

In `langchain_core/runnables/passthrough.py`:

```diff
@@ -40,7 +40,6 @@ from langchain_core.runnables.graph import Graph
 from langchain_core.runnables.utils import (
     AddableDict,
     ConfigurableFieldSpec,
-    adapt_first_streaming_chunk,
     create_model,
 )
 from langchain_core.utils.aiter import atee, py_anext
@@ -243,16 +242,22 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
             for chunk in self._transform_stream_with_config(input, identity, config):
                 yield chunk
         else:
-            final = None
+            final: Other
+            got_first_chunk = False

             for chunk in self._transform_stream_with_config(input, identity, config):
                 yield chunk
-                if final is None:
-                    final = adapt_first_streaming_chunk(chunk)
-                else:
-                    final = final + chunk
-            if final is not None:
+                if not got_first_chunk:
+                    final = chunk
+                    got_first_chunk = True
+                else:
+                    try:
+                        final = final + chunk  # type: ignore[operator]
+                    except TypeError:
+                        final = chunk
+            if got_first_chunk:
                 call_func_with_variable_args(
                     self.func, final, ensure_config(config), **kwargs
                 )
@@ -269,18 +274,28 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
             ):
                 yield chunk
         else:
-            final = None
+            got_first_chunk = False

             async for chunk in self._atransform_stream_with_config(
                 input, identity, config
             ):
                 yield chunk
-                if final is None:
-                    final = adapt_first_streaming_chunk(chunk)
-                else:
-                    final = final + chunk
-            if final is not None:
+                # By definitions, a function will operate on the aggregated
+                # input. So we'll aggregate the input until we get to the last
+                # chunk.
+                # If the input is not addable, then we'll assume that we can
+                # only operate on the last chunk.
+                if not got_first_chunk:
+                    final = chunk
+                    got_first_chunk = True
+                else:
+                    try:
+                        final = final + chunk  # type: ignore[operator]
+                    except TypeError:
+                        final = chunk
+            if got_first_chunk:
                 config = ensure_config(config)
                 if self.afunc is not None:
                     await acall_func_with_variable_args(
```
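
As a quick sketch of what this hunk affects (assuming a post-commit `langchain_core`; the `func` hook receives the aggregated input once the stream is exhausted):

```python
from langchain_core.runnables import RunnablePassthrough

seen = []
passthrough = RunnablePassthrough(func=lambda x: seen.append(x))

# Chunks pass through unchanged; with non-addable plain dicts, func now
# receives only the last chunk rather than an auto-upgraded AddableDict sum.
for _ in passthrough.transform(iter([{"foo": "a"}, {"foo": "n"}])):
    pass

assert seen == [{"foo": "n"}]
```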

In `langchain_core/runnables/utils.py`:

```diff
@@ -524,11 +524,3 @@ def _create_model_cached(
     return _create_model_base(
         __model_name, __config__=_SchemaConfig, **field_definitions
     )
-
-
-def adapt_first_streaming_chunk(chunk: Any) -> Any:
-    """This might transform the first chunk of a stream into an AddableDict."""
-    if isinstance(chunk, dict) and not isinstance(chunk, AddableDict):
-        return AddableDict(chunk)
-    else:
-        return chunk
```

In `tests/unit_tests/runnables/test_runnable.py`:

```diff
@@ -5401,11 +5401,21 @@ def test_transform_of_runnable_lambda_with_dicts() -> None:
     runnable = RunnableLambda(lambda x: x)
     chunks = iter(
         [
             {"foo": "a"},
             {"foo": "n"},
         ]
     )
-    assert list(runnable.transform(chunks)) == [{"foo": "an"}]
+    assert list(runnable.transform(chunks)) == [{"foo": "n"}]
+
+    # Test as part of a sequence
+    seq = runnable | runnable
+    chunks = iter(
+        [
+            {"foo": "n"},
+        ]
+    )
+    assert list(seq.transform(chunks)) == [{"foo": "n"}]
+
+    # Test some other edge cases
+    assert list(seq.stream({"foo": "n"})) == [{"foo": "n"}]


 async def test_atransform_of_runnable_lambda_with_dicts() -> None:
@@ -5420,7 +5430,11 @@ async def test_atransform_of_runnable_lambda_with_dicts() -> None:
         yield {"foo": "n"}

     chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
-    assert chunks == [{"foo": "an"}]
+    assert chunks == [{"foo": "n"}]
+
+    seq = runnable | runnable
+    chunks = [chunk async for chunk in seq.atransform(chunk_iterator())]
+    assert chunks == [{"foo": "n"}]


 def test_default_transform_with_dicts() -> None:
@@ -5440,7 +5454,8 @@ def test_default_transform_with_dicts() -> None:
         ]
     )
-    assert list(runnable.transform(chunks)) == [{"foo": "an"}]
+    assert list(runnable.transform(chunks)) == [{"foo": "n"}]
+    assert list(runnable.stream({"foo": "n"})) == [{"foo": "n"}]


 async def test_default_atransform_with_dicts() -> None:
@@ -5460,6 +5475,17 @@ async def test_default_atransform_with_dicts() -> None:
     chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
     assert chunks == [{"foo": "n"}]

+    # Test with addable dict
+    async def chunk_iterator_with_addable() -> AsyncIterator[Dict[str, str]]:
+        yield AddableDict({"foo": "a"})
+        yield AddableDict({"foo": "n"})
+
+    chunks = [
+        chunk async for chunk in runnable.atransform(chunk_iterator_with_addable())
+    ]
+    assert chunks == [{"foo": "an"}]
```
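
For reference, the behavioral change exercised by these tests can be reproduced against a post-commit `langchain_core` with a snippet like this:

```python
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.utils import AddableDict

runnable = RunnableLambda(lambda x: x)

# Plain dicts are no longer auto-upgraded: only the last chunk of a
# non-addable stream reaches the lambda.
chunks = iter([{"foo": "a"}, {"foo": "n"}])
assert list(runnable.transform(chunks)) == [{"foo": "n"}]

# Explicitly wrapped AddableDicts still gather across the stream.
chunks = iter([AddableDict({"foo": "a"}), AddableDict({"foo": "n"})])
assert list(runnable.transform(chunks)) == [{"foo": "an"}]
```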