Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-15 22:44:36 +00:00)
Implement stream and astream for RunnableLambda (#14794)
**Description:** Implement stream and astream methods for RunnableLambda to make streaming work for functions that return a Runnable.
**Issue:** https://github.com/langchain-ai/langchain/issues/11998
**Dependencies:** No new dependencies
**Twitter handle:** https://twitter.com/qtangs

Co-authored-by: Nuno Campos <nuno@langchain.dev>
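The behavior this PR tests can be pictured with a minimal sketch (a hypothetical helper, not the actual RunnableLambda implementation): when the wrapped function's return value is itself a Runnable, its stream is re-yielded chunk by chunk instead of being emitted as a single value.

from typing import Any, Callable, Iterator

def stream_result(func: Callable[[Any], Any], value: Any) -> Iterator[Any]:
    # Hypothetical sketch: call the wrapped function; if the output is
    # itself runnable (exposes .stream()), re-yield its chunks so that
    # streaming propagates; otherwise emit the result as a single chunk.
    output = func(value)
    if hasattr(output, "stream"):
        yield from output.stream(value)
    else:
        yield output

# e.g. list(stream_result(range, 5)) == [range(0, 5)] -- one chunk,
# matching the test expectations below.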
@@ -4,6 +4,8 @@ from operator import itemgetter
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Iterator,
    List,
@@ -3541,6 +3543,130 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
    _lambda_mock.reset_mock()


def test_runnable_lambda_stream() -> None:
    """Test that stream works for both normal functions & those returning Runnable."""
    # Normal output should work
    output: List[Any] = [chunk for chunk in RunnableLambda(range).stream(5)]
    assert output == [range(5)]

    # Runnable output should also work
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)

    output = list(RunnableLambda(lambda x: llm).stream(""))
    assert output == list(llm_res)
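The FakeStreamingListLLM used here streams its response one character at a time, sleeping between chunks, which is why the assertion compares against list(llm_res). A self-contained stand-in (a hypothetical helper, not the real test fixture) might look like:

import time
from typing import Iterator

def fake_token_stream(text: str, sleep: float = 0.01) -> Iterator[str]:
    # Yield one character per chunk, pausing briefly between chunks to
    # mimic the latency of a real streaming LLM.
    for token in text:
        yield token
        time.sleep(sleep)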

def test_runnable_lambda_stream_with_callbacks() -> None:
    """Test that stream works for RunnableLambda when using callbacks."""
    tracer = FakeTracer()

    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    config: RunnableConfig = {"callbacks": [tracer]}

    assert list(RunnableLambda(lambda x: llm).stream("", config=config)) == list(
        llm_res
    )

    assert len(tracer.runs) == 1
    assert tracer.runs[0].error is None
    assert tracer.runs[0].outputs == {"output": llm_res}

    def raise_value_error(x: int) -> int:
        """Raise a value error."""
        raise ValueError("x is too large")

    # Check that the chain on error is invoked
    with pytest.raises(ValueError):
        for _ in RunnableLambda(raise_value_error).stream(1000, config=config):
            pass

    assert len(tracer.runs) == 2
    assert "ValueError('x is too large')" in str(tracer.runs[1].error)
    assert tracer.runs[1].outputs is None
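The assertions above depend only on the tracer appending one run per stream, with outputs recorded on success and the error recorded on failure. A minimal sketch of such a collector (a hypothetical MiniTracer, not LangChain's FakeTracer) could be:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class RunRecord:
    outputs: Optional[dict] = None
    error: Optional[BaseException] = None

@dataclass
class MiniTracer:
    runs: List[RunRecord] = field(default_factory=list)

    def on_run_end(self, outputs: dict) -> None:
        # Successful run: record outputs, leave error as None.
        self.runs.append(RunRecord(outputs=outputs))

    def on_run_error(self, error: BaseException) -> None:
        # Failed run: record the error, leave outputs as None.
        self.runs.append(RunRecord(error=error))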

async def test_runnable_lambda_astream() -> None:
    """Test that astream works for both normal functions & those returning Runnable."""

    # Wrapper to make a normal function async
    def awrapper(func: Callable) -> Callable[..., Awaitable[Any]]:
        async def afunc(*args: Any, **kwargs: Any) -> Any:
            return func(*args, **kwargs)

        return afunc

    # Normal output should work
    output: List[Any] = [
        chunk
        async for chunk in RunnableLambda(
            func=id,
            afunc=awrapper(range),  # id func is just a dummy
        ).astream(5)
    ]
    assert output == [range(5)]

    # Normal output using func should also work
    output = [_ async for _ in RunnableLambda(range).astream(5)]
    assert output == [range(5)]

    # Runnable output should also work
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)

    output = [
        _
        async for _ in RunnableLambda(
            func=id,
            afunc=awrapper(lambda x: llm),
        ).astream("")
    ]
    assert output == list(llm_res)

    output = [
        chunk
        async for chunk in cast(
            AsyncIterator[str], RunnableLambda(lambda x: llm).astream("")
        )
    ]
    assert output == list(llm_res)
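Note the typing.cast in the final block: RunnableLambda(lambda x: llm).astream("") is typed from the lambda's inferred signature, so the cast to AsyncIterator[str] only narrows the element type for the type checker; it has no effect at runtime.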

async def test_runnable_lambda_astream_with_callbacks() -> None:
    """Test that astream works for RunnableLambda when using callbacks."""
    tracer = FakeTracer()

    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    config: RunnableConfig = {"callbacks": [tracer]}

    assert [
        _ async for _ in RunnableLambda(lambda x: llm).astream("", config=config)
    ] == list(llm_res)

    assert len(tracer.runs) == 1
    assert tracer.runs[0].error is None
    assert tracer.runs[0].outputs == {"output": llm_res}

    def raise_value_error(x: int) -> int:
        """Raise a value error."""
        raise ValueError("x is too large")

    # Check that the chain on error is invoked
    with pytest.raises(ValueError):
        async for _ in RunnableLambda(raise_value_error).astream(1000, config=config):
            pass

    assert len(tracer.runs) == 2
    assert "ValueError('x is too large')" in str(tracer.runs[1].error)
    assert tracer.runs[1].outputs is None
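For symmetry with the synchronous sketch near the top, the async path can be pictured the same way (again a hypothetical helper, not the actual implementation):

from typing import Any, AsyncIterator, Awaitable, Callable

async def astream_result(
    afunc: Callable[[Any], Awaitable[Any]], value: Any
) -> AsyncIterator[Any]:
    # Hypothetical sketch: await the wrapped coroutine; if the result is
    # itself runnable (exposes .astream()), re-yield its chunks so that
    # streaming propagates; otherwise emit the result as a single chunk.
    output = await afunc(value)
    if hasattr(output, "astream"):
        async for chunk in output.astream(value):
            yield chunk
    else:
        yield output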

@freeze_time("2023-01-01")
def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
    class ControlledExceptionRunnable(Runnable[str, str]):