diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 5ace57e7634..717834b512a 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -25,6 +25,7 @@ from typing import ( List, Mapping, Optional, + Protocol, Sequence, Set, Tuple, @@ -5519,12 +5520,36 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): return attr +class _RunnableCallableSync(Protocol[Input, Output]): + def __call__(self, __in: Input, *, config: RunnableConfig) -> Output: ... + + +class _RunnableCallableAsync(Protocol[Input, Output]): + def __call__(self, __in: Input, *, config: RunnableConfig) -> Awaitable[Output]: ... + + +class _RunnableCallableIterator(Protocol[Input, Output]): + def __call__( + self, __in: Iterator[Input], *, config: RunnableConfig + ) -> Iterator[Output]: ... + + +class _RunnableCallableAsyncIterator(Protocol[Input, Output]): + def __call__( + self, __in: AsyncIterator[Input], *, config: RunnableConfig + ) -> AsyncIterator[Output]: ... + + RunnableLike = Union[ Runnable[Input, Output], Callable[[Input], Output], Callable[[Input], Awaitable[Output]], Callable[[Iterator[Input]], Iterator[Output]], Callable[[AsyncIterator[Input]], AsyncIterator[Output]], + _RunnableCallableSync[Input, Output], + _RunnableCallableAsync[Input, Output], + _RunnableCallableIterator[Input, Output], + _RunnableCallableAsyncIterator[Input, Output], Mapping[str, Any], ]