mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-03 03:38:06 +00:00
Implement stream and astream for RunnableLambda (#14794)
**Description:** Implement stream and astream methods for RunnableLambda to make streaming work for functions returning Runnable - **Issue:** https://github.com/langchain-ai/langchain/issues/11998 - **Dependencies:** No new dependencies - **Twitter handle:** https://twitter.com/qtangs --------- Co-authored-by: Nuno Campos <nuno@langchain.dev>
This commit is contained in:
parent
7e26559256
commit
7ef25a3c1b
@ -2775,6 +2775,155 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _transform(
    self,
    input: Iterator[Input],
    run_manager: CallbackManagerForChainRun,
    config: RunnableConfig,
    **kwargs: Any,
) -> Iterator[Output]:
    """Fold the input stream into one value, call func once, stream the result.

    The chunks are aggregated with ``+`` when they support it (falling back
    to the most recent chunk otherwise). If func returns a Runnable, its
    stream() output is re-yielded; otherwise the value is yielded as a
    single chunk.
    """
    aggregate: Optional[Input] = None
    for chunk in input:
        if aggregate is None:
            aggregate = chunk
        else:
            try:
                aggregate = aggregate + chunk  # type: ignore[operator]
            except TypeError:
                # Chunks are not additive; keep only the latest one.
                aggregate = chunk

    result = call_func_with_variable_args(
        self.func, cast(Input, aggregate), config, run_manager, **kwargs
    )

    if isinstance(result, Runnable):
        # The function produced another Runnable: delegate to its stream,
        # decrementing the recursion budget to guard against infinite
        # self-reference.
        remaining = config["recursion_limit"]
        if remaining <= 0:
            raise RecursionError(
                f"Recursion limit reached when invoking "
                f"{self} with input {aggregate}."
            )
        child_config = patch_config(
            config,
            callbacks=run_manager.get_child(),
            recursion_limit=remaining - 1,
        )
        yield from result.stream(aggregate, child_config)
    else:
        # Plain value: emit it as the single output chunk.
        yield result
|
||||
|
||||
def transform(
    self,
    input: Iterator[Input],
    config: Optional[RunnableConfig] = None,
    **kwargs: Optional[Any],
) -> Iterator[Output]:
    """Synchronously transform an input stream into an output stream.

    Delegates to _transform via _transform_stream_with_config so callback
    managers and config propagation are handled uniformly.

    Raises:
        TypeError: if this RunnableLambda wraps only a coroutine function
            (no sync ``func``), which cannot be streamed synchronously.
    """
    if hasattr(self, "func"):
        for output in self._transform_stream_with_config(
            input,
            self._transform,
            self._config(config, self.func),
            **kwargs,
        ):
            yield output
    else:
        # Bug fix: the original adjacent string literals concatenated to
        # "synchronously.Use `astream` instead." with no separating space.
        raise TypeError(
            "Cannot stream a coroutine function synchronously. "
            "Use `astream` instead."
        )
|
||||
|
||||
def stream(
    self,
    input: Input,
    config: Optional[RunnableConfig] = None,
    **kwargs: Optional[Any],
) -> Iterator[Output]:
    """Stream the output for a single input by wrapping it as a one-item
    iterator and delegating to transform()."""
    single_item = iter((input,))
    return self.transform(single_item, config, **kwargs)
|
||||
|
||||
async def _atransform(
    self,
    input: AsyncIterator[Input],
    run_manager: AsyncCallbackManagerForChainRun,
    config: RunnableConfig,
    **kwargs: Any,
) -> AsyncIterator[Output]:
    """Async counterpart of _transform: aggregate the input stream, invoke
    the wrapped function once, and re-stream the result if it is a Runnable.

    Bug fix: accepts **kwargs and forwards them to the wrapped function.
    atransform() passes **kwargs into _atransform_stream_with_config, which
    hands them to this transformer; without the parameter, any extra kwargs
    raised TypeError and never reached the function — inconsistent with the
    sync _transform signature.
    """
    final: Optional[Input] = None
    async for ichunk in input:
        if final is None:
            final = ichunk
        else:
            try:
                final = final + ichunk  # type: ignore[operator]
            except TypeError:
                # Non-addable chunks: keep only the most recent one.
                final = ichunk

    if hasattr(self, "afunc"):
        afunc = self.afunc
    else:
        # Only a sync func is available: run it in the default executor so
        # the event loop is not blocked.
        @wraps(self.func)
        async def f(*args, **kwargs):  # type: ignore[no-untyped-def]
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self.func, **kwargs), *args
            )

        afunc = f

    output = await acall_func_with_variable_args(
        afunc, cast(Input, final), config, run_manager, **kwargs
    )

    # If the output is a runnable, use its astream output
    if isinstance(output, Runnable):
        recursion_limit = config["recursion_limit"]
        if recursion_limit <= 0:
            raise RecursionError(
                f"Recursion limit reached when invoking "
                f"{self} with input {final}."
            )
        async for chunk in output.astream(
            final,
            patch_config(
                config,
                callbacks=run_manager.get_child(),
                recursion_limit=recursion_limit - 1,
            ),
        ):
            yield chunk
    else:
        # Otherwise, just yield it
        yield output
|
||||
|
||||
async def atransform(
    self,
    input: AsyncIterator[Input],
    config: Optional[RunnableConfig] = None,
    **kwargs: Optional[Any],
) -> AsyncIterator[Output]:
    """Asynchronously transform an input stream into an output stream,
    delegating to _atransform via _atransform_stream_with_config."""
    # Prefer the async function for config resolution when one was provided.
    configured_func = self.afunc if hasattr(self, "afunc") else self.func
    stream = self._atransform_stream_with_config(
        input,
        self._atransform,
        self._config(config, configured_func),
        **kwargs,
    )
    async for output in stream:
        yield output
|
||||
|
||||
async def astream(
    self,
    input: Input,
    config: Optional[RunnableConfig] = None,
    **kwargs: Optional[Any],
) -> AsyncIterator[Output]:
    """Stream the output for a single input by wrapping it in a one-item
    async iterator and delegating to atransform()."""

    async def _single_item() -> AsyncIterator[Input]:
        yield input

    async for part in self.atransform(_single_item(), config, **kwargs):
        yield part
|
||||
|
||||
|
||||
class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
"""
|
||||
|
@ -4,6 +4,8 @@ from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
@ -3541,6 +3543,130 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
|
||||
def test_runnable_lambda_stream() -> None:
    """Test that stream works for both normal functions & those returning Runnable."""
    # A plain function's result arrives as a single chunk.
    output: List[Any] = list(RunnableLambda(range).stream(5))
    assert output == [range(5)]

    # A function returning a Runnable should stream that Runnable's output.
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)

    output = [chunk for chunk in RunnableLambda(lambda x: llm).stream("")]
    assert output == list(llm_res)
|
||||
|
||||
|
||||
def test_runnable_lambda_stream_with_callbacks() -> None:
    """Test that stream works for RunnableLambda when using callbacks."""
    tracer = FakeTracer()

    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    config: RunnableConfig = {"callbacks": [tracer]}

    chunks = list(RunnableLambda(lambda x: llm).stream("", config=config))
    assert chunks == list(llm_res)

    # The traced run completed successfully with the full text as output.
    assert len(tracer.runs) == 1
    assert tracer.runs[0].error is None
    assert tracer.runs[0].outputs == {"output": llm_res}

    def raise_value_error(x: int) -> int:
        """Raise a value error."""
        raise ValueError("x is too large")

    # Check that the chain on error is invoked
    with pytest.raises(ValueError):
        for _ in RunnableLambda(raise_value_error).stream(1000, config=config):
            pass

    assert len(tracer.runs) == 2
    assert "ValueError('x is too large')" in str(tracer.runs[1].error)
    assert tracer.runs[1].outputs is None
|
||||
|
||||
|
||||
async def test_runnable_lambda_astream() -> None:
    """Test that astream works for both normal functions & those returning Runnable."""

    # Wrapper to make a normal function async
    def awrapper(func: Callable) -> Callable[..., Awaitable[Any]]:
        async def afunc(*args: Any, **kwargs: Any) -> Any:
            return func(*args, **kwargs)

        return afunc

    # When an afunc is supplied, astream uses it; result arrives as one chunk.
    output: List[Any] = [
        chunk
        async for chunk in RunnableLambda(
            func=id,
            afunc=awrapper(range),  # id func is just dummy
        ).astream(5)
    ]
    assert output == [range(5)]

    # With only a sync func, astream still works via the executor fallback.
    output = [_ async for _ in RunnableLambda(range).astream(5)]
    assert output == [range(5)]

    # A function returning a Runnable should stream that Runnable's output.
    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)

    output = [
        _
        async for _ in RunnableLambda(
            func=id,
            afunc=awrapper(lambda x: llm),
        ).astream("")
    ]
    assert output == list(llm_res)

    # Same behavior when only a sync lambda is given.
    output = [
        chunk
        async for chunk in cast(
            AsyncIterator[str], RunnableLambda(lambda x: llm).astream("")
        )
    ]
    assert output == list(llm_res)
|
||||
|
||||
|
||||
async def test_runnable_lambda_astream_with_callbacks() -> None:
    """Test that astream works for RunnableLambda when using callbacks."""
    tracer = FakeTracer()

    llm_res = "i'm a textbot"
    # sleep to better simulate a real stream
    llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
    config: RunnableConfig = {"callbacks": [tracer]}

    chunks = [
        _ async for _ in RunnableLambda(lambda x: llm).astream("", config=config)
    ]
    assert chunks == list(llm_res)

    # The traced run completed successfully with the full text as output.
    assert len(tracer.runs) == 1
    assert tracer.runs[0].error is None
    assert tracer.runs[0].outputs == {"output": llm_res}

    def raise_value_error(x: int) -> int:
        """Raise a value error."""
        raise ValueError("x is too large")

    # Check that the chain on error is invoked
    with pytest.raises(ValueError):
        async for _ in RunnableLambda(raise_value_error).astream(1000, config=config):
            pass

    assert len(tracer.runs) == 2
    assert "ValueError('x is too large')" in str(tracer.runs[1].error)
    assert tracer.runs[1].outputs is None
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
class ControlledExceptionRunnable(Runnable[str, str]):
|
||||
|
Loading…
Reference in New Issue
Block a user