core[patch]: simple fallback streaming (#16055)

This commit is contained in:
Bagatur
2024-01-19 16:31:54 -08:00
committed by GitHub
parent 4ef0ed4ddc
commit 1e29b676d5
3 changed files with 179 additions and 2 deletions

View File

@@ -1,5 +1,5 @@
import sys
from typing import Any
from typing import Any, AsyncIterator, Iterator
import pytest
from syrupy import SnapshotAssertion
@@ -8,6 +8,7 @@ from langchain_core.load import dumps
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import (
Runnable,
RunnableGenerator,
RunnableLambda,
RunnableParallel,
RunnablePassthrough,
@@ -229,3 +230,61 @@ async def test_abatch() -> None:
expected = ["first", "second", RuntimeError()]
_assert_potential_error(actual, expected)
def _generate(input: Iterator) -> Iterator[str]:
yield from "foo bar"
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
raise ValueError()
yield ""
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
yield ""
raise ValueError()
def test_fallbacks_stream() -> None:
runnable = RunnableGenerator(_generate_immediate_error).with_fallbacks(
[RunnableGenerator(_generate)]
)
assert list(runnable.stream({})) == [c for c in "foo bar"]
with pytest.raises(ValueError):
runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks(
[RunnableGenerator(_generate)]
)
list(runnable.stream({}))
async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
for c in "foo bar":
yield c
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
raise ValueError()
yield ""
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
yield ""
raise ValueError()
async def test_fallbacks_astream() -> None:
runnable = RunnableGenerator(_agenerate_immediate_error).with_fallbacks(
[RunnableGenerator(_agenerate)]
)
expected = (c for c in "foo bar")
async for c in runnable.astream({}):
assert c == next(expected)
with pytest.raises(ValueError):
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
[RunnableGenerator(_agenerate)]
)
async for c in runnable.astream({}):
pass