diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 21ec7caad38..ffb0c415b80 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -25,6 +25,7 @@ from typing import ( List, Mapping, Optional, + Protocol, Sequence, Set, Tuple, @@ -5519,12 +5520,36 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): return attr +class _SyncSingle(Protocol[Input, Output]): + def __call__(self, __in: Input, *, config: RunnableConfig) -> Output: ... + + +class _AsyncSingle(Protocol[Input, Output]): + def __call__(self, __in: Input, *, config: RunnableConfig) -> Awaitable[Output]: ... + + +class _SyncIterator(Protocol[Input, Output]): + def __call__( + self, __in: Iterator[Input], *, config: RunnableConfig + ) -> Iterator[Output]: ... + + +class _AsyncIterator(Protocol[Input, Output]): + def __call__( + self, __in: AsyncIterator[Input], *, config: RunnableConfig + ) -> AsyncIterator[Output]: ... + + RunnableLike = Union[ Runnable[Input, Output], Callable[[Input], Output], Callable[[Input], Awaitable[Output]], Callable[[Iterator[Input]], Iterator[Output]], Callable[[AsyncIterator[Input]], AsyncIterator[Output]], + _SyncSingle[Any, Any], + _AsyncSingle[Any, Any], + _SyncIterator[Any, Any], + _AsyncIterator[Any, Any], Mapping[str, Any], ]