diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index f6c94fa4be2..a74f72abedf 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -46,7 +46,7 @@ from langchain_core.runnables.config import ( ) from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.schema import StreamEvent -from langchain_core.runnables.utils import Input, Output +from langchain_core.runnables.utils import Addable from langchain_core.tools import tool from langchain_core.utils.aiter import aclosing from tests.unit_tests.runnables.test_runnable_events_v1 import ( @@ -2120,19 +2120,19 @@ async def test_sync_in_sync_lambdas() -> None: _assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS) -class StreamingRunnable(Runnable[Input, Output]): +class StreamingRunnable(Runnable[Any, Addable]): """A custom runnable used for testing purposes.""" - iterable: Iterable[Any] + iterable: Iterable[Addable] - def __init__(self, iterable: Iterable[Any]) -> None: + def __init__(self, iterable: Iterable[Addable]) -> None: """Initialize the runnable.""" self.iterable = iterable @override def invoke( - self, input: Input, config: RunnableConfig | None = None, **kwargs: Any - ) -> Output: + self, input: Any, config: RunnableConfig | None = None, **kwargs: Any + ) -> Addable: """Invoke the runnable.""" msg = "Server side error" raise ValueError(msg) @@ -2140,19 +2140,19 @@ class StreamingRunnable(Runnable[Input, Output]): @override def stream( self, - input: Input, + input: Any, config: RunnableConfig | None = None, **kwargs: Any | None, - ) -> Iterator[Output]: + ) -> Iterator[Addable]: raise NotImplementedError @override async def astream( self, - input: Input, + input: Any, config: RunnableConfig | None = None, **kwargs: Any | None, - ) -> AsyncIterator[Output]: + ) -> AsyncIterator[Addable]: config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) run_manager = await callback_manager.on_chain_start( @@ -2187,7 +2187,7 @@ class StreamingRunnable(Runnable[Input, Output]): async def test_astream_events_from_custom_runnable() -> None: """Test astream events from a custom runnable.""" iterator = ["1", "2", "3"] - runnable: Runnable[int, str] = StreamingRunnable(iterator) + runnable = StreamingRunnable(iterator) chunks = [chunk async for chunk in runnable.astream(1, version="v2")] assert chunks == ["1", "2", "3"] events = await _collect_events(runnable.astream_events(1, version="v2")) @@ -2390,7 +2390,7 @@ async def test_runnable_generator() -> None: yield "1" yield "2" - runnable: Runnable[str, str] = RunnableGenerator(transform=generator) + runnable = RunnableGenerator(transform=generator) events = await _collect_events(runnable.astream_events("hello", version="v2")) _assert_events_equal_allow_superset_metadata( events,