diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 43b025ad441..4a70e1b545c 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -642,8 +642,8 @@ class Runnable(Generic[Input, Output], ABC): @overload def batch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[False] = False, **kwargs: Any, @@ -653,8 +653,8 @@ class Runnable(Generic[Input, Output], ABC): @overload def batch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[True], **kwargs: Any, @@ -663,8 +663,8 @@ class Runnable(Generic[Input, Output], ABC): def batch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], @@ -746,8 +746,8 @@ class Runnable(Generic[Input, Output], ABC): @overload def abatch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[False] = False, **kwargs: Optional[Any], @@ -757,8 +757,8 @@ class Runnable(Generic[Input, Output], ABC): @overload def abatch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[True], **kwargs: Optional[Any], @@ -767,8 +767,8 @@ class Runnable(Generic[Input, Output], ABC): async def abatch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], @@ -4506,8 +4506,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): @overload def batch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[False] = False, **kwargs: Any, @@ -4517,8 +4517,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): @overload def batch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[True], **kwargs: Any, @@ -4527,13 +4527,13 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): def batch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], ) -> Iterator[Tuple[int, Union[Output, Exception]]]: - if isinstance(config, list): + if isinstance(config, Sequence): configs = cast( List[RunnableConfig], [self._merge_configs(conf) for conf in config], @@ -4559,8 +4559,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): @overload def abatch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[False] = False, **kwargs: Optional[Any], @@ -4570,8 +4570,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): @overload def abatch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[True], **kwargs: Optional[Any], @@ -4580,13 +4580,13 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): async def abatch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], ) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: - if isinstance(config, list): + if isinstance(config, Sequence): configs = cast( List[RunnableConfig], [self._merge_configs(conf) for conf in config], diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index f69a0b00000..c9e7904a5a8 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -18,6 +18,7 @@ from typing import ( Iterator, List, Optional, + Sequence, TypeVar, Union, cast, @@ -159,7 +160,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: def get_config_list( - config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]], length: int ) -> List[RunnableConfig]: """Get a list of configs from a single config or a list of configs. @@ -179,13 +180,13 @@ def get_config_list( """ if length < 0: raise ValueError(f"length must be >= 0, but got {length}") - if isinstance(config, list) and len(config) != length: + if isinstance(config, Sequence) and len(config) != length: raise ValueError( f"config must be a list of the same length as inputs, " f"but got {len(config)} configs for {length} inputs" ) - if isinstance(config, list): + if isinstance(config, Sequence): return list(map(ensure_config, config)) if length > 1 and isinstance(config, dict) and config.get("run_id") is not None: warnings.warn(