mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-24 20:09:01 +00:00
core(minor): Implement stream and astream for RunnableBranch (#14805)
* This PR adds `stream` implementations to Runnable Branch. * Runnable Branch still does not support `transform` so it'll break streaming if it happens in middle or end of sequence, but will work if happens at beginning of sequence. * Fixes use the async callback manager for async methods * Handle BaseException rather than Exception, so more errors could be logged as errors when they are encountered --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -3981,6 +3981,140 @@ async def test_runnable_branch_abatch() -> None:
|
||||
assert await branch.abatch([1, 10, 0]) == [2, 100, -1]
|
||||
|
||||
|
||||
def test_runnable_branch_stream() -> None:
|
||||
"""Verify that stream works for RunnableBranch."""
|
||||
|
||||
llm_res = "i'm a textbot"
|
||||
# sleep to better simulate a real stream
|
||||
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||
|
||||
branch = RunnableBranch[str, Any](
|
||||
(lambda x: x == "hello", llm),
|
||||
lambda x: x,
|
||||
)
|
||||
|
||||
assert list(branch.stream("hello")) == list(llm_res)
|
||||
assert list(branch.stream("bye")) == ["bye"]
|
||||
|
||||
|
||||
def test_runnable_branch_stream_with_callbacks() -> None:
|
||||
"""Verify that stream works for RunnableBranch when using callbacks."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
def raise_value_error(x: str) -> Any:
|
||||
"""Raise a value error."""
|
||||
raise ValueError(f"x is {x}")
|
||||
|
||||
llm_res = "i'm a textbot"
|
||||
# sleep to better simulate a real stream
|
||||
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||
|
||||
branch = RunnableBranch[str, Any](
|
||||
(lambda x: x == "error", raise_value_error),
|
||||
(lambda x: x == "hello", llm),
|
||||
lambda x: x,
|
||||
)
|
||||
config: RunnableConfig = {"callbacks": [tracer]}
|
||||
|
||||
assert list(branch.stream("hello", config=config)) == list(llm_res)
|
||||
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].error is None
|
||||
assert tracer.runs[0].outputs == {"output": llm_res}
|
||||
|
||||
# Verify that the chain on error is invoked
|
||||
with pytest.raises(ValueError):
|
||||
for _ in branch.stream("error", config=config):
|
||||
pass
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
assert "ValueError('x is error')" in str(tracer.runs[1].error)
|
||||
assert tracer.runs[1].outputs is None
|
||||
|
||||
assert list(branch.stream("bye", config=config)) == ["bye"]
|
||||
|
||||
assert len(tracer.runs) == 3
|
||||
assert tracer.runs[2].error is None
|
||||
assert tracer.runs[2].outputs == {"output": "bye"}
|
||||
|
||||
|
||||
async def test_runnable_branch_astream() -> None:
|
||||
"""Verify that astream works for RunnableBranch."""
|
||||
|
||||
llm_res = "i'm a textbot"
|
||||
# sleep to better simulate a real stream
|
||||
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||
|
||||
branch = RunnableBranch[str, Any](
|
||||
(lambda x: x == "hello", llm),
|
||||
lambda x: x,
|
||||
)
|
||||
|
||||
assert [_ async for _ in branch.astream("hello")] == list(llm_res)
|
||||
assert [_ async for _ in branch.astream("bye")] == ["bye"]
|
||||
|
||||
# Verify that the async variant is used if available
|
||||
async def condition(x: str) -> bool:
|
||||
return x == "hello"
|
||||
|
||||
async def repeat(x: str) -> str:
|
||||
return x + x
|
||||
|
||||
async def reverse(x: str) -> str:
|
||||
return x[::-1]
|
||||
|
||||
branch = RunnableBranch[str, Any]((condition, repeat), llm)
|
||||
|
||||
assert [_ async for _ in branch.astream("hello")] == ["hello" * 2]
|
||||
assert [_ async for _ in branch.astream("bye")] == list(llm_res)
|
||||
|
||||
branch = RunnableBranch[str, Any]((condition, llm), reverse)
|
||||
|
||||
assert [_ async for _ in branch.astream("hello")] == list(llm_res)
|
||||
assert [_ async for _ in branch.astream("bye")] == ["eyb"]
|
||||
|
||||
|
||||
async def test_runnable_branch_astream_with_callbacks() -> None:
|
||||
"""Verify that astream works for RunnableBranch when using callbacks."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
def raise_value_error(x: str) -> Any:
|
||||
"""Raise a value error."""
|
||||
raise ValueError(f"x is {x}")
|
||||
|
||||
llm_res = "i'm a textbot"
|
||||
# sleep to better simulate a real stream
|
||||
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||
|
||||
branch = RunnableBranch[str, Any](
|
||||
(lambda x: x == "error", raise_value_error),
|
||||
(lambda x: x == "hello", llm),
|
||||
lambda x: x,
|
||||
)
|
||||
config: RunnableConfig = {"callbacks": [tracer]}
|
||||
|
||||
assert [_ async for _ in branch.astream("hello", config=config)] == list(llm_res)
|
||||
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].error is None
|
||||
assert tracer.runs[0].outputs == {"output": llm_res}
|
||||
|
||||
# Verify that the chain on error is invoked
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in branch.astream("error", config=config):
|
||||
pass
|
||||
|
||||
assert len(tracer.runs) == 2
|
||||
assert "ValueError('x is error')" in str(tracer.runs[1].error)
|
||||
assert tracer.runs[1].outputs is None
|
||||
|
||||
assert [_ async for _ in branch.astream("bye", config=config)] == ["bye"]
|
||||
|
||||
assert len(tracer.runs) == 3
|
||||
assert tracer.runs[2].error is None
|
||||
assert tracer.runs[2].outputs == {"output": "bye"}
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
|
Reference in New Issue
Block a user