mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
chore(core): improve types for StreamingRunnable (#34540)
This commit is contained in:
committed by
GitHub
parent
5e9765d811
commit
2a2a4067ca
@@ -46,7 +46,7 @@ from langchain_core.runnables.config import (
|
||||
)
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langchain_core.runnables.utils import Input, Output
|
||||
from langchain_core.runnables.utils import Addable
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.utils.aiter import aclosing
|
||||
from tests.unit_tests.runnables.test_runnable_events_v1 import (
|
||||
@@ -2120,19 +2120,19 @@ async def test_sync_in_sync_lambdas() -> None:
|
||||
_assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS)
|
||||
|
||||
|
||||
class StreamingRunnable(Runnable[Input, Output]):
|
||||
class StreamingRunnable(Runnable[Any, Addable]):
|
||||
"""A custom runnable used for testing purposes."""
|
||||
|
||||
iterable: Iterable[Any]
|
||||
iterable: Iterable[Addable]
|
||||
|
||||
def __init__(self, iterable: Iterable[Any]) -> None:
|
||||
def __init__(self, iterable: Iterable[Addable]) -> None:
|
||||
"""Initialize the runnable."""
|
||||
self.iterable = iterable
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
|
||||
) -> Output:
|
||||
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
|
||||
) -> Addable:
|
||||
"""Invoke the runnable."""
|
||||
msg = "Server side error"
|
||||
raise ValueError(msg)
|
||||
@@ -2140,19 +2140,19 @@ class StreamingRunnable(Runnable[Input, Output]):
|
||||
@override
|
||||
def stream(
|
||||
self,
|
||||
input: Input,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Output]:
|
||||
) -> Iterator[Addable]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def astream(
|
||||
self,
|
||||
input: Input,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> AsyncIterator[Output]:
|
||||
) -> AsyncIterator[Addable]:
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
@@ -2187,7 +2187,7 @@ class StreamingRunnable(Runnable[Input, Output]):
|
||||
async def test_astream_events_from_custom_runnable() -> None:
|
||||
"""Test astream events from a custom runnable."""
|
||||
iterator = ["1", "2", "3"]
|
||||
runnable: Runnable[int, str] = StreamingRunnable(iterator)
|
||||
runnable = StreamingRunnable(iterator)
|
||||
chunks = [chunk async for chunk in runnable.astream(1, version="v2")]
|
||||
assert chunks == ["1", "2", "3"]
|
||||
events = await _collect_events(runnable.astream_events(1, version="v2"))
|
||||
@@ -2390,7 +2390,7 @@ async def test_runnable_generator() -> None:
|
||||
yield "1"
|
||||
yield "2"
|
||||
|
||||
runnable: Runnable[str, str] = RunnableGenerator(transform=generator)
|
||||
runnable = RunnableGenerator(transform=generator)
|
||||
events = await _collect_events(runnable.astream_events("hello", version="v2"))
|
||||
_assert_events_equal_allow_superset_metadata(
|
||||
events,
|
||||
|
||||
Reference in New Issue
Block a user