Implement stream and astream for RunnableLambda (#14794)

**Description:** Implement stream and astream methods for RunnableLambda
to make streaming work for functions returning Runnable
  - **Issue:** https://github.com/langchain-ai/langchain/issues/11998
  - **Dependencies:** No new dependencies
  - **Twitter handle:** https://twitter.com/qtangs

---------

Co-authored-by: Nuno Campos <nuno@langchain.dev>
This commit is contained in:
Quy Tang 2023-12-27 04:49:02 +08:00 committed by GitHub
parent 7e26559256
commit 7ef25a3c1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 275 additions and 0 deletions

View File

@ -2775,6 +2775,155 @@ class RunnableLambda(Runnable[Input, Output]):
**kwargs,
)
def _transform(
    self,
    input: Iterator[Input],
    run_manager: CallbackManagerForChainRun,
    config: RunnableConfig,
    **kwargs: Any,
) -> Iterator[Output]:
    """Collapse the input stream into one value, call the wrapped function,
    and stream its output.

    If the function returns a Runnable, that Runnable is streamed (with a
    decremented recursion limit to guard against self-reference); otherwise
    the plain return value is yielded as a single chunk.
    """
    # Aggregate the chunk stream: chunks that support `+` are summed; a
    # chunk that cannot be added simply replaces the accumulator.
    accumulated: Optional[Input] = None
    for chunk in input:
        if accumulated is None:
            accumulated = chunk
        else:
            try:
                accumulated = accumulated + chunk  # type: ignore[operator]
            except TypeError:
                accumulated = chunk

    result = call_func_with_variable_args(
        self.func, cast(Input, accumulated), config, run_manager, **kwargs
    )

    if isinstance(result, Runnable):
        # The function produced another Runnable: stream it, spending one
        # unit of the recursion budget.
        recursion_limit = config["recursion_limit"]
        if recursion_limit <= 0:
            raise RecursionError(
                f"Recursion limit reached when invoking "
                f"{self} with input {accumulated}."
            )
        yield from result.stream(
            accumulated,
            patch_config(
                config,
                callbacks=run_manager.get_child(),
                recursion_limit=recursion_limit - 1,
            ),
        )
    else:
        # Non-Runnable results are emitted as a single chunk.
        yield result
def transform(
    self,
    input: Iterator[Input],
    config: Optional[RunnableConfig] = None,
    **kwargs: Optional[Any],
) -> Iterator[Output]:
    """Stream-transform: accumulate ``input``, run the wrapped sync
    function on it via ``_transform``, and yield the output chunk(s).

    Raises:
        TypeError: if this lambda wraps only an async function (``afunc``)
            and therefore cannot be streamed synchronously.
    """
    if hasattr(self, "func"):
        for output in self._transform_stream_with_config(
            input,
            self._transform,
            self._config(config, self.func),
            **kwargs,
        ):
            yield output
    else:
        # Fixed: the implicitly-concatenated message previously rendered as
        # "synchronously.Use `astream` instead." (missing space).
        raise TypeError(
            "Cannot stream a coroutine function synchronously. "
            "Use `astream` instead."
        )
def stream(
    self,
    input: Input,
    config: Optional[RunnableConfig] = None,
    **kwargs: Optional[Any],
) -> Iterator[Output]:
    """Stream the output for a single input value."""
    # Wrap the lone input in a one-element iterator so ``transform``
    # can drive the real work.
    single_chunk_stream = iter([input])
    return self.transform(single_chunk_stream, config, **kwargs)
async def _atransform(
    self,
    input: AsyncIterator[Input],
    run_manager: AsyncCallbackManagerForChainRun,
    config: RunnableConfig,
    **kwargs: Any,
) -> AsyncIterator[Output]:
    """Async counterpart of ``_transform``: collapse the input stream into
    one value, call the wrapped function, and stream its output.

    Fixed: accept and forward ``**kwargs`` like the sync ``_transform``
    does — ``atransform`` forwards caller kwargs into
    ``_atransform_stream_with_config``, which passes them to this
    transformer; without this parameter any kwargs raised TypeError, and
    they were never forwarded to the wrapped function.
    """
    # Aggregate the chunk stream: chunks that support `+` are summed; a
    # chunk that cannot be added simply replaces the accumulator.
    final: Optional[Input] = None
    async for ichunk in input:
        if final is None:
            final = ichunk
        else:
            try:
                final = final + ichunk  # type: ignore[operator]
            except TypeError:
                final = ichunk

    if hasattr(self, "afunc"):
        afunc = self.afunc
    else:
        # Only a sync function is available: wrap it so it can be awaited
        # without blocking the event loop.
        @wraps(self.func)
        async def f(*args, **kwargs):  # type: ignore[no-untyped-def]
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self.func, **kwargs), *args
            )

        afunc = f

    output = await acall_func_with_variable_args(
        afunc, cast(Input, final), config, run_manager, **kwargs
    )
    # If the output is a runnable, use its astream output
    if isinstance(output, Runnable):
        recursion_limit = config["recursion_limit"]
        if recursion_limit <= 0:
            raise RecursionError(
                f"Recursion limit reached when invoking "
                f"{self} with input {final}."
            )
        async for chunk in output.astream(
            final,
            patch_config(
                config,
                callbacks=run_manager.get_child(),
                recursion_limit=recursion_limit - 1,
            ),
        ):
            yield chunk
    else:
        # Otherwise, just yield it
        yield output
async def atransform(
    self,
    input: AsyncIterator[Input],
    config: Optional[RunnableConfig] = None,
    **kwargs: Optional[Any],
) -> AsyncIterator[Output]:
    """Async stream-transform: accumulate ``input``, run the wrapped
    function on it via ``_atransform``, and yield the output chunk(s)."""
    # Prefer the async function for config/name resolution when present.
    func_for_config = self.afunc if hasattr(self, "afunc") else self.func
    async for chunk in self._atransform_stream_with_config(
        input,
        self._atransform,
        self._config(config, func_for_config),
        **kwargs,
    ):
        yield chunk
async def astream(
    self,
    input: Input,
    config: Optional[RunnableConfig] = None,
    **kwargs: Optional[Any],
) -> AsyncIterator[Output]:
    """Stream output asynchronously for a single input value."""

    async def _single_chunk() -> AsyncIterator[Input]:
        # Wrap the lone input so ``atransform`` can consume a stream.
        yield input

    async for chunk in self.atransform(_single_chunk(), config, **kwargs):
        yield chunk
class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
"""

View File

@ -4,6 +4,8 @@ from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterator,
List,
@ -3541,6 +3543,130 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
_lambda_mock.reset_mock()
def test_runnable_lambda_stream() -> None:
    """Test that stream works for both normal functions & those returning Runnable."""
    # A plain function's return value arrives as a single chunk.
    plain_chunks: List[Any] = list(RunnableLambda(range).stream(5))
    assert plain_chunks == [range(5)]

    # A function returning a Runnable streams that Runnable's output.
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    runnable_chunks = [chunk for chunk in RunnableLambda(lambda x: llm).stream("")]
    assert runnable_chunks == list(llm_res)
def test_runnable_lambda_stream_with_callbacks() -> None:
    """Test that stream works for RunnableLambda when using callbacks."""
    tracer = FakeTracer()
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    config: RunnableConfig = {"callbacks": [tracer]}

    streamed = [
        chunk for chunk in RunnableLambda(lambda x: llm).stream("", config=config)
    ]
    assert streamed == list(llm_res)

    # The successful run is traced with its final output.
    assert len(tracer.runs) == 1
    assert tracer.runs[0].error is None
    assert tracer.runs[0].outputs == {"output": llm_res}

    def always_raise(x: int) -> int:
        """Raise a value error."""
        raise ValueError("x is too large")

    # Check that the chain on error is invoked
    with pytest.raises(ValueError):
        for _ in RunnableLambda(always_raise).stream(1000, config=config):
            pass
    assert len(tracer.runs) == 2
    assert "ValueError('x is too large')" in str(tracer.runs[1].error)
    assert tracer.runs[1].outputs is None
async def test_runnable_lambda_astream() -> None:
    """Test that astream works for both normal functions & those returning Runnable."""

    def awrapper(func: Callable) -> Callable[..., Awaitable[Any]]:
        """Lift a sync callable into an awaitable one."""

        async def afunc(*args: Any, **kwargs: Any) -> Any:
            return func(*args, **kwargs)

        return afunc

    # A plain async function's return value arrives as a single chunk.
    results: List[Any] = [
        chunk
        async for chunk in RunnableLambda(
            func=id,
            afunc=awrapper(range),  # id func is just dummy
        ).astream(5)
    ]
    assert results == [range(5)]

    # A sync-only func should also work through astream.
    results = [_ async for _ in RunnableLambda(range).astream(5)]
    assert results == [range(5)]

    # A function returning a Runnable streams that Runnable's output.
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    results = [
        _
        async for _ in RunnableLambda(
            func=id,
            afunc=awrapper(lambda x: llm),
        ).astream("")
    ]
    assert results == list(llm_res)

    # Same behavior via the sync-only constructor.
    results = [
        chunk
        async for chunk in cast(
            AsyncIterator[str], RunnableLambda(lambda x: llm).astream("")
        )
    ]
    assert results == list(llm_res)
async def test_runnable_lambda_astream_with_callbacks() -> None:
    """Test that astream works for RunnableLambda when using callbacks."""
    tracer = FakeTracer()
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    config: RunnableConfig = {"callbacks": [tracer]}

    streamed = [
        _ async for _ in RunnableLambda(lambda x: llm).astream("", config=config)
    ]
    assert streamed == list(llm_res)

    # The successful run is traced with its final output.
    assert len(tracer.runs) == 1
    assert tracer.runs[0].error is None
    assert tracer.runs[0].outputs == {"output": llm_res}

    def always_raise(x: int) -> int:
        """Raise a value error."""
        raise ValueError("x is too large")

    # Check that the chain on error is invoked
    with pytest.raises(ValueError):
        async for _ in RunnableLambda(always_raise).astream(1000, config=config):
            pass
    assert len(tracer.runs) == 2
    assert "ValueError('x is too large')" in str(tracer.runs[1].error)
    assert tracer.runs[1].outputs is None
@freeze_time("2023-01-01")
def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
class ControlledExceptionRunnable(Runnable[str, str]):