Implement stream and astream for RunnableLambda (#14794)

**Description:** Implement stream and astream methods for RunnableLambda
to make streaming work for functions returning Runnable
  - **Issue:** https://github.com/langchain-ai/langchain/issues/11998
  - **Dependencies:** No new dependencies
  - **Twitter handle:** https://twitter.com/qtangs

---------

Co-authored-by: Nuno Campos <nuno@langchain.dev>
This commit is contained in:
Quy Tang
2023-12-27 04:49:02 +08:00
committed by GitHub
parent 7e26559256
commit 7ef25a3c1b
2 changed files with 275 additions and 0 deletions

View File

@@ -4,6 +4,8 @@ from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterator,
List,
@@ -3541,6 +3543,130 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
_lambda_mock.reset_mock()
def test_runnable_lambda_stream() -> None:
"""Test that stream works for both normal functions & those returning Runnable."""
# Normal output should work
output: List[Any] = [chunk for chunk in RunnableLambda(range).stream(5)]
assert output == [range(5)]
# Runnable output should also work
llm_res = "i'm a textbot"
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
output = list(RunnableLambda(lambda x: llm).stream(""))
assert output == list(llm_res)
def test_runnable_lambda_stream_with_callbacks() -> None:
"""Test that stream works for RunnableLambda when using callbacks."""
tracer = FakeTracer()
llm_res = "i'm a textbot"
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
config: RunnableConfig = {"callbacks": [tracer]}
assert list(RunnableLambda(lambda x: llm).stream("", config=config)) == list(
llm_res
)
assert len(tracer.runs) == 1
assert tracer.runs[0].error is None
assert tracer.runs[0].outputs == {"output": llm_res}
def raise_value_error(x: int) -> int:
"""Raise a value error."""
raise ValueError("x is too large")
# Check that the chain on error is invoked
with pytest.raises(ValueError):
for _ in RunnableLambda(raise_value_error).stream(1000, config=config):
pass
assert len(tracer.runs) == 2
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
assert tracer.runs[1].outputs is None
async def test_runnable_lambda_astream() -> None:
"""Test that astream works for both normal functions & those returning Runnable."""
# Wrapper to make a normal function async
def awrapper(func: Callable) -> Callable[..., Awaitable[Any]]:
async def afunc(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)
return afunc
# Normal output should work
output: List[Any] = [
chunk
async for chunk in RunnableLambda(
func=id,
afunc=awrapper(range), # id func is just dummy
).astream(5)
]
assert output == [range(5)]
# Normal output using func should also work
output = [_ async for _ in RunnableLambda(range).astream(5)]
assert output == [range(5)]
# Runnable output should also work
llm_res = "i'm a textbot"
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
output = [
_
async for _ in RunnableLambda(
func=id,
afunc=awrapper(lambda x: llm),
).astream("")
]
assert output == list(llm_res)
output = [
chunk
async for chunk in cast(
AsyncIterator[str], RunnableLambda(lambda x: llm).astream("")
)
]
assert output == list(llm_res)
async def test_runnable_lambda_astream_with_callbacks() -> None:
"""Test that astream works for RunnableLambda when using callbacks."""
tracer = FakeTracer()
llm_res = "i'm a textbot"
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
config: RunnableConfig = {"callbacks": [tracer]}
assert [
_ async for _ in RunnableLambda(lambda x: llm).astream("", config=config)
] == list(llm_res)
assert len(tracer.runs) == 1
assert tracer.runs[0].error is None
assert tracer.runs[0].outputs == {"output": llm_res}
def raise_value_error(x: int) -> int:
"""Raise a value error."""
raise ValueError("x is too large")
# Check that the chain on error is invoked
with pytest.raises(ValueError):
async for _ in RunnableLambda(raise_value_error).astream(1000, config=config):
pass
assert len(tracer.runs) == 2
assert "ValueError('x is too large')" in str(tracer.runs[1].error)
assert tracer.runs[1].outputs is None
@freeze_time("2023-01-01")
def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
class ControlledExceptionRunnable(Runnable[str, str]):