mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-05 19:15:44 +00:00
core[patch]: simple fallback streaming (#16055)
This commit is contained in:
parent
4ef0ed4ddc
commit
1e29b676d5
@ -302,7 +302,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.4"
|
"version": "3.9.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -2,6 +2,8 @@ import asyncio
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Awaitable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
@ -30,6 +32,7 @@ from langchain_core.runnables.utils import (
|
|||||||
Output,
|
Output,
|
||||||
get_unique_config_specs,
|
get_unique_config_specs,
|
||||||
)
|
)
|
||||||
|
from langchain_core.utils.aiter import py_anext
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
|
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
|
||||||
@ -415,3 +418,118 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
raise sorted_handled_exceptions[0][1]
|
raise sorted_handled_exceptions[0][1]
|
||||||
to_return.update(handled_exceptions)
|
to_return.update(handled_exceptions)
|
||||||
return [output for _, output in sorted(to_return.items())] # type: ignore
|
return [output for _, output in sorted(to_return.items())] # type: ignore
|
||||||
|
|
||||||
|
def stream(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> Iterator[Output]:
|
||||||
|
""""""
|
||||||
|
if self.exception_key is not None and not isinstance(input, dict):
|
||||||
|
raise ValueError(
|
||||||
|
"If 'exception_key' is specified then input must be a dictionary."
|
||||||
|
f"However found a type of {type(input)} for input"
|
||||||
|
)
|
||||||
|
# setup callbacks
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
|
# start the root run
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
first_error = None
|
||||||
|
last_error = None
|
||||||
|
for runnable in self.runnables:
|
||||||
|
try:
|
||||||
|
if self.exception_key and last_error is not None:
|
||||||
|
input[self.exception_key] = last_error
|
||||||
|
stream = runnable.stream(
|
||||||
|
input,
|
||||||
|
patch_config(config, callbacks=run_manager.get_child()),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
chunk = next(stream)
|
||||||
|
except self.exceptions_to_handle as e:
|
||||||
|
first_error = e if first_error is None else first_error
|
||||||
|
last_error = e
|
||||||
|
except BaseException as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
first_error = None
|
||||||
|
break
|
||||||
|
if first_error:
|
||||||
|
run_manager.on_chain_error(first_error)
|
||||||
|
raise first_error
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
output: Optional[Output] = chunk
|
||||||
|
try:
|
||||||
|
for chunk in stream:
|
||||||
|
yield chunk
|
||||||
|
try:
|
||||||
|
output = output + chunk # type: ignore
|
||||||
|
except TypeError:
|
||||||
|
output = None
|
||||||
|
except BaseException as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise e
|
||||||
|
run_manager.on_chain_end(output)
|
||||||
|
|
||||||
|
async def astream(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> AsyncIterator[Output]:
|
||||||
|
if self.exception_key is not None and not isinstance(input, dict):
|
||||||
|
raise ValueError(
|
||||||
|
"If 'exception_key' is specified then input must be a dictionary."
|
||||||
|
f"However found a type of {type(input)} for input"
|
||||||
|
)
|
||||||
|
# setup callbacks
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
|
# start the root run
|
||||||
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
first_error = None
|
||||||
|
last_error = None
|
||||||
|
for runnable in self.runnables:
|
||||||
|
try:
|
||||||
|
if self.exception_key and last_error is not None:
|
||||||
|
input[self.exception_key] = last_error
|
||||||
|
stream = runnable.astream(
|
||||||
|
input,
|
||||||
|
patch_config(config, callbacks=run_manager.get_child()),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
chunk = await cast(Awaitable[Output], py_anext(stream))
|
||||||
|
except self.exceptions_to_handle as e:
|
||||||
|
first_error = e if first_error is None else first_error
|
||||||
|
last_error = e
|
||||||
|
except BaseException as e:
|
||||||
|
await run_manager.on_chain_error(e)
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
first_error = None
|
||||||
|
break
|
||||||
|
if first_error:
|
||||||
|
await run_manager.on_chain_error(first_error)
|
||||||
|
raise first_error
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
output: Optional[Output] = chunk
|
||||||
|
try:
|
||||||
|
async for chunk in stream:
|
||||||
|
yield chunk
|
||||||
|
try:
|
||||||
|
output = output + chunk # type: ignore
|
||||||
|
except TypeError:
|
||||||
|
output = None
|
||||||
|
except BaseException as e:
|
||||||
|
await run_manager.on_chain_error(e)
|
||||||
|
raise e
|
||||||
|
await run_manager.on_chain_end(output)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import sys
|
import sys
|
||||||
from typing import Any
|
from typing import Any, AsyncIterator, Iterator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
@ -8,6 +8,7 @@ from langchain_core.load import dumps
|
|||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
Runnable,
|
Runnable,
|
||||||
|
RunnableGenerator,
|
||||||
RunnableLambda,
|
RunnableLambda,
|
||||||
RunnableParallel,
|
RunnableParallel,
|
||||||
RunnablePassthrough,
|
RunnablePassthrough,
|
||||||
@ -229,3 +230,61 @@ async def test_abatch() -> None:
|
|||||||
|
|
||||||
expected = ["first", "second", RuntimeError()]
|
expected = ["first", "second", RuntimeError()]
|
||||||
_assert_potential_error(actual, expected)
|
_assert_potential_error(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def _generate(input: Iterator) -> Iterator[str]:
|
||||||
|
yield from "foo bar"
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
|
||||||
|
raise ValueError()
|
||||||
|
yield ""
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
|
||||||
|
yield ""
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallbacks_stream() -> None:
|
||||||
|
runnable = RunnableGenerator(_generate_immediate_error).with_fallbacks(
|
||||||
|
[RunnableGenerator(_generate)]
|
||||||
|
)
|
||||||
|
assert list(runnable.stream({})) == [c for c in "foo bar"]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks(
|
||||||
|
[RunnableGenerator(_generate)]
|
||||||
|
)
|
||||||
|
list(runnable.stream({}))
|
||||||
|
|
||||||
|
|
||||||
|
async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
|
||||||
|
for c in "foo bar":
|
||||||
|
yield c
|
||||||
|
|
||||||
|
|
||||||
|
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
|
||||||
|
raise ValueError()
|
||||||
|
yield ""
|
||||||
|
|
||||||
|
|
||||||
|
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
|
||||||
|
yield ""
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_fallbacks_astream() -> None:
|
||||||
|
runnable = RunnableGenerator(_agenerate_immediate_error).with_fallbacks(
|
||||||
|
[RunnableGenerator(_agenerate)]
|
||||||
|
)
|
||||||
|
expected = (c for c in "foo bar")
|
||||||
|
async for c in runnable.astream({}):
|
||||||
|
assert c == next(expected)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
|
||||||
|
[RunnableGenerator(_agenerate)]
|
||||||
|
)
|
||||||
|
async for c in runnable.astream({}):
|
||||||
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user