From 58b118544e19471392e2039a3623ade96cf1edb4 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 4 Jun 2024 08:04:09 -0700 Subject: [PATCH] Use immutable sequence type for batch/batch_as_completed types (#22433) Thank you for contributing to LangChain! - [ ] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [ ] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** a description of the change - **Issue:** the issue # it fixes, if applicable - **Dependencies:** any dependencies required for this change - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out! - [ ] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --- libs/core/langchain_core/runnables/base.py | 52 ++++++++++---------- libs/core/langchain_core/runnables/config.py | 7 +-- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 43b025ad441..4a70e1b545c 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -642,8 +642,8 @@ class Runnable(Generic[Input, Output], ABC): @overload def batch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[False] = False, **kwargs: Any, @@ -653,8 +653,8 @@ class Runnable(Generic[Input, Output], ABC): @overload def batch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[True], **kwargs: Any, @@ -663,8 +663,8 @@ class Runnable(Generic[Input, Output], ABC): def batch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], @@ -746,8 +746,8 @@ class Runnable(Generic[Input, Output], ABC): @overload def abatch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[False] = False, **kwargs: Optional[Any], @@ -757,8 +757,8 @@ class Runnable(Generic[Input, Output], ABC): @overload def abatch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[True], **kwargs: Optional[Any], @@ -767,8 +767,8 @@ class Runnable(Generic[Input, Output], ABC): async def abatch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], @@ -4506,8 +4506,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): @overload def batch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[False] = False, **kwargs: Any, @@ -4517,8 +4517,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): @overload def batch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[True], **kwargs: Any, @@ -4527,13 +4527,13 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): def batch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], ) -> Iterator[Tuple[int, Union[Output, Exception]]]: - if isinstance(config, list): + if isinstance(config, Sequence): configs = cast( List[RunnableConfig], [self._merge_configs(conf) for conf in config], @@ -4559,8 +4559,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): @overload def abatch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[False] = False, **kwargs: Optional[Any], @@ -4570,8 +4570,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): @overload def abatch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: Literal[True], **kwargs: Optional[Any], @@ -4580,13 +4580,13 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): async def abatch_as_completed( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: Sequence[Input], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], ) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: - if isinstance(config, list): + if isinstance(config, Sequence): configs = cast( List[RunnableConfig], [self._merge_configs(conf) for conf in config], diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index f69a0b00000..c9e7904a5a8 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -18,6 +18,7 @@ from typing import ( Iterator, List, Optional, + Sequence, TypeVar, Union, cast, @@ -159,7 +160,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: def get_config_list( - config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]], length: int ) -> List[RunnableConfig]: """Get a list of configs from a single config or a list of configs. @@ -179,13 +180,13 @@ def get_config_list( """ if length < 0: raise ValueError(f"length must be >= 0, but got {length}") - if isinstance(config, list) and len(config) != length: + if isinstance(config, Sequence) and len(config) != length: raise ValueError( f"config must be a list of the same length as inputs, " f"but got {len(config)} configs for {length} inputs" ) - if isinstance(config, list): + if isinstance(config, Sequence): return list(map(ensure_config, config)) if length > 1 and isinstance(config, dict) and config.get("run_id") is not None: warnings.warn(