chore(core): improve types for StreamingRunnable (#34540)

This commit is contained in:
Christophe Bornet
2026-01-10 03:34:50 +01:00
committed by GitHub
parent 5e9765d811
commit 2a2a4067ca

View File

@@ -46,7 +46,7 @@ from langchain_core.runnables.config import (
)
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent
from langchain_core.runnables.utils import Input, Output
from langchain_core.runnables.utils import Addable
from langchain_core.tools import tool
from langchain_core.utils.aiter import aclosing
from tests.unit_tests.runnables.test_runnable_events_v1 import (
@@ -2120,19 +2120,19 @@ async def test_sync_in_sync_lambdas() -> None:
_assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS)
class StreamingRunnable(Runnable[Input, Output]):
class StreamingRunnable(Runnable[Any, Addable]):
"""A custom runnable used for testing purposes."""
iterable: Iterable[Any]
iterable: Iterable[Addable]
def __init__(self, iterable: Iterable[Any]) -> None:
def __init__(self, iterable: Iterable[Addable]) -> None:
"""Initialize the runnable."""
self.iterable = iterable
@override
def invoke(
self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
) -> Output:
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
) -> Addable:
"""Invoke the runnable."""
msg = "Server side error"
raise ValueError(msg)
@@ -2140,19 +2140,19 @@ class StreamingRunnable(Runnable[Input, Output]):
@override
def stream(
self,
input: Input,
input: Any,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Output]:
) -> Iterator[Addable]:
raise NotImplementedError
@override
async def astream(
self,
input: Input,
input: Any,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> AsyncIterator[Output]:
) -> AsyncIterator[Addable]:
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
@@ -2187,7 +2187,7 @@ class StreamingRunnable(Runnable[Input, Output]):
async def test_astream_events_from_custom_runnable() -> None:
"""Test astream events from a custom runnable."""
iterator = ["1", "2", "3"]
runnable: Runnable[int, str] = StreamingRunnable(iterator)
runnable = StreamingRunnable(iterator)
chunks = [chunk async for chunk in runnable.astream(1, version="v2")]
assert chunks == ["1", "2", "3"]
events = await _collect_events(runnable.astream_events(1, version="v2"))
@@ -2390,7 +2390,7 @@ async def test_runnable_generator() -> None:
yield "1"
yield "2"
runnable: Runnable[str, str] = RunnableGenerator(transform=generator)
runnable = RunnableGenerator(transform=generator)
events = await _collect_events(runnable.astream_events("hello", version="v2"))
_assert_events_equal_allow_superset_metadata(
events,