mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 03:26:17 +00:00
core[patch]: simple fallback streaming (#16055)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Any, AsyncIterator, Iterator
|
||||
|
||||
import pytest
|
||||
from syrupy import SnapshotAssertion
|
||||
@@ -8,6 +8,7 @@ from langchain_core.load import dumps
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableGenerator,
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnablePassthrough,
|
||||
@@ -229,3 +230,61 @@ async def test_abatch() -> None:
|
||||
|
||||
expected = ["first", "second", RuntimeError()]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
|
||||
def _generate(input: Iterator) -> Iterator[str]:
|
||||
yield from "foo bar"
|
||||
|
||||
|
||||
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
|
||||
raise ValueError()
|
||||
yield ""
|
||||
|
||||
|
||||
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
|
||||
yield ""
|
||||
raise ValueError()
|
||||
|
||||
|
||||
def test_fallbacks_stream() -> None:
|
||||
runnable = RunnableGenerator(_generate_immediate_error).with_fallbacks(
|
||||
[RunnableGenerator(_generate)]
|
||||
)
|
||||
assert list(runnable.stream({})) == [c for c in "foo bar"]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks(
|
||||
[RunnableGenerator(_generate)]
|
||||
)
|
||||
list(runnable.stream({}))
|
||||
|
||||
|
||||
async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
|
||||
for c in "foo bar":
|
||||
yield c
|
||||
|
||||
|
||||
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
|
||||
raise ValueError()
|
||||
yield ""
|
||||
|
||||
|
||||
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
|
||||
yield ""
|
||||
raise ValueError()
|
||||
|
||||
|
||||
async def test_fallbacks_astream() -> None:
|
||||
runnable = RunnableGenerator(_agenerate_immediate_error).with_fallbacks(
|
||||
[RunnableGenerator(_agenerate)]
|
||||
)
|
||||
expected = (c for c in "foo bar")
|
||||
async for c in runnable.astream({}):
|
||||
assert c == next(expected)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
|
||||
[RunnableGenerator(_agenerate)]
|
||||
)
|
||||
async for c in runnable.astream({}):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user