core: Put Python version as a project requirement so it is considered by ruff (#26608)

Ruff doesn't know about the Python version declared in
`[tool.poetry.dependencies]`, but it can pick it up from
`project.requires-python`.
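
Roughly what the relevant `pyproject.toml` sections end up looking like (a sketch; the exact version bounds here are illustrative):

```toml
[project]
# Ruff infers its default target-version from this field.
requires-python = ">=3.9,<4.0"

[tool.poetry.dependencies]
# Duplicated because Poetry doesn't combine requires-python with
# per-dependency `python` markers (see the notes below).
python = ">=3.9,<4.0"
```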

Notes:
* Poetry seems to have issues combining the Python constraint from
`requires-python` with `python` markers in per-dependency constraints, so I
had to duplicate the info (as in the sketch above). I will open an issue on
Poetry.
* `inspect.isclass()` doesn't work correctly with `GenericAlias`
(`list[...]`, `dict[..., ...]`) on Python < 3.11, so I added some `not
isinstance(type, GenericAlias)` checks; see the version comparison and the
guard sketch below:

Python 3.11
```pycon
>>> import inspect
>>> inspect.isclass(list)
True
>>> inspect.isclass(list[str])
False
```

Python 3.9
```pycon
>>> import inspect
>>> inspect.isclass(list)
True
>>> inspect.isclass(list[str])
True
```
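
For reference, here is the shape of that guard as a standalone sketch (the helper name is made up for illustration; in the diff the check is inlined at each call site):

```python
import inspect
from types import GenericAlias
from typing import Any, Optional

from pydantic import BaseModel


def as_model_class(root_type: Any) -> Optional[type[BaseModel]]:
    """Return root_type if it is a real BaseModel subclass, else None."""
    if (
        inspect.isclass(root_type)
        # On Python < 3.11, inspect.isclass(list[str]) is True, so reject
        # GenericAlias values before they reach issubclass().
        and not isinstance(root_type, GenericAlias)
        and issubclass(root_type, BaseModel)
    ):
        return root_type
    return None
```

With the extra check, a root type like `list[str]` falls through to the `create_model_v2(...)` path instead of being returned as if it were a BaseModel subclass.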

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Authored by Christophe Bornet on 2024-09-18 16:37:57 +02:00, committed by GitHub
parent 0f07cf61da, commit a47b332841
162 changed files with 920 additions and 1002 deletions


@@ -6,36 +6,37 @@ import functools
import inspect
import threading
from abc import ABC, abstractmethod
from collections.abc import (
AsyncGenerator,
AsyncIterator,
Awaitable,
Coroutine,
Iterator,
Mapping,
Sequence,
)
from concurrent.futures import FIRST_COMPLETED, wait
from contextvars import copy_context
from functools import wraps
from itertools import groupby, tee
from operator import itemgetter
from types import GenericAlias
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
Generic,
Iterator,
List,
Mapping,
Optional,
Protocol,
Sequence,
Type,
TypeVar,
Union,
cast,
get_type_hints,
overload,
)
from pydantic import BaseModel, ConfigDict, Field, RootModel
from typing_extensions import Literal, get_args, get_type_hints
from typing_extensions import Literal, get_args
from langchain_core._api import beta_decorator
from langchain_core.load.serializable import (
@@ -340,7 +341,11 @@ class Runnable(Generic[Input, Output], ABC):
"""
root_type = self.InputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type
return create_model_v2(
@@ -408,7 +413,11 @@ class Runnable(Generic[Input, Output], ABC):
"""
root_type = self.OutputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type
return create_model_v2(
@@ -771,10 +780,10 @@ class Runnable(Generic[Input, Output], ABC):
# If there's only one input, don't bother with the executor
if len(inputs) == 1:
return cast(List[Output], [invoke(inputs[0], configs[0])])
return cast(list[Output], [invoke(inputs[0], configs[0])])
with get_executor_for_config(configs[0]) as executor:
return cast(List[Output], list(executor.map(invoke, inputs, configs)))
return cast(list[Output], list(executor.map(invoke, inputs, configs)))
@overload
def batch_as_completed(
@@ -2024,7 +2033,7 @@ class Runnable(Generic[Input, Output], ABC):
for run_manager in run_managers:
run_manager.on_chain_error(e)
if return_exceptions:
return cast(List[Output], [e for _ in input])
return cast(list[Output], [e for _ in input])
else:
raise
else:
@@ -2036,7 +2045,7 @@ class Runnable(Generic[Input, Output], ABC):
else:
run_manager.on_chain_end(out)
if return_exceptions or first_exception is None:
return cast(List[Output], output)
return cast(list[Output], output)
else:
raise first_exception
@@ -2099,7 +2108,7 @@ class Runnable(Generic[Input, Output], ABC):
*(run_manager.on_chain_error(e) for run_manager in run_managers)
)
if return_exceptions:
return cast(List[Output], [e for _ in input])
return cast(list[Output], [e for _ in input])
else:
raise
else:
@@ -2113,7 +2122,7 @@ class Runnable(Generic[Input, Output], ABC):
coros.append(run_manager.on_chain_end(out))
await asyncio.gather(*coros)
if return_exceptions or first_exception is None:
return cast(List[Output], output)
return cast(list[Output], output)
else:
raise first_exception
@@ -3171,7 +3180,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
for rm in run_managers:
rm.on_chain_error(e)
if return_exceptions:
return cast(List[Output], [e for _ in inputs])
return cast(list[Output], [e for _ in inputs])
else:
raise
else:
@@ -3183,7 +3192,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
else:
run_manager.on_chain_end(out)
if return_exceptions or first_exception is None:
return cast(List[Output], inputs)
return cast(list[Output], inputs)
else:
raise first_exception
@@ -3298,7 +3307,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
except BaseException as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
if return_exceptions:
return cast(List[Output], [e for _ in inputs])
return cast(list[Output], [e for _ in inputs])
else:
raise
else:
@@ -3312,7 +3321,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
coros.append(run_manager.on_chain_end(out))
await asyncio.gather(*coros)
if return_exceptions or first_exception is None:
return cast(List[Output], inputs)
return cast(list[Output], inputs)
else:
raise first_exception
@@ -3420,7 +3429,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
yield chunk
class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
"""Runnable that runs a mapping of Runnables in parallel, and returns a mapping
of their outputs.
@@ -4071,7 +4080,11 @@ class RunnableGenerator(Runnable[Input, Output]):
func = getattr(self, "_transform", None) or self._atransform
module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type
return create_model_v2(
@@ -4106,7 +4119,11 @@ class RunnableGenerator(Runnable[Input, Output]):
func = getattr(self, "_transform", None) or self._atransform
module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type
return create_model_v2(
@@ -4369,7 +4386,7 @@ class RunnableLambda(Runnable[Input, Output]):
module = getattr(func, "__module__", None)
return create_model_v2(
self.get_name("Input"),
root=List[Any],
root=list[Any],
# To create the schema, we need to provide the module
# where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately.
@@ -4420,7 +4437,11 @@ class RunnableLambda(Runnable[Input, Output]):
func = getattr(self, "func", None) or self.afunc
module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type
return create_model_v2(
@@ -4921,7 +4942,7 @@ class RunnableLambda(Runnable[Input, Output]):
yield chunk
class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
"""Runnable that delegates calls to another Runnable
with each element of the input sequence.
@@ -4938,7 +4959,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
@property
def InputType(self) -> Any:
return List[self.bound.InputType] # type: ignore[name-defined]
return list[self.bound.InputType] # type: ignore[name-defined]
def get_input_schema(
self, config: Optional[RunnableConfig] = None
@@ -4946,7 +4967,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
return create_model_v2(
self.get_name("Input"),
root=(
List[self.bound.get_input_schema(config)], # type: ignore
list[self.bound.get_input_schema(config)], # type: ignore
None,
),
# create model needs access to appropriate type annotations to be
@@ -4961,7 +4982,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
@property
def OutputType(self) -> type[list[Output]]:
return List[self.bound.OutputType] # type: ignore[name-defined]
return list[self.bound.OutputType] # type: ignore[name-defined]
def get_output_schema(
self, config: Optional[RunnableConfig] = None
@@ -4969,7 +4990,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
schema = self.bound.get_output_schema(config)
return create_model_v2(
self.get_name("Output"),
root=List[schema], # type: ignore[valid-type]
root=list[schema], # type: ignore[valid-type]
# create model needs access to appropriate type annotations to be
# able to construct the pydantic model.
# When we create the model, we pass information about the namespace
@@ -5255,7 +5276,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
@property
def InputType(self) -> type[Input]:
return (
cast(Type[Input], self.custom_input_type)
cast(type[Input], self.custom_input_type)
if self.custom_input_type is not None
else self.bound.InputType
)
@@ -5263,7 +5284,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
@property
def OutputType(self) -> type[Output]:
return (
cast(Type[Output], self.custom_output_type)
cast(type[Output], self.custom_output_type)
if self.custom_output_type is not None
else self.bound.OutputType
)
@@ -5336,7 +5357,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) -> list[Output]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
list[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:
@@ -5358,7 +5379,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) -> list[Output]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
list[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:
@@ -5400,7 +5421,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) -> Iterator[tuple[int, Union[Output, Exception]]]:
if isinstance(config, Sequence):
configs = cast(
List[RunnableConfig],
list[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:
@@ -5451,7 +5472,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]:
if isinstance(config, Sequence):
configs = cast(
List[RunnableConfig],
list[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:


@@ -1,15 +1,8 @@
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@@ -69,13 +62,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
branch.invoke(None) # "goodbye"
"""
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
branches: Sequence[tuple[Runnable[Input, bool], Runnable[Input, Output]]]
default: Runnable[Input, Output]
def __init__(
self,
*branches: Union[
Tuple[
tuple[
Union[
Runnable[Input, bool],
Callable[[Input], bool],
@@ -149,13 +142,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
runnables = (
[self.default]
+ [r for _, r in self.branches]
@@ -172,7 +165,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
return super().get_input_schema(config)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
from langchain_core.beta.runnables.context import (
CONTEXT_CONFIG_PREFIX,
CONTEXT_CONFIG_SUFFIX_SET,


@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import uuid
import warnings
from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager
from contextvars import ContextVar, copy_context
@@ -10,14 +11,8 @@ from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Generator,
Iterable,
Iterator,
List,
Optional,
Sequence,
TypeVar,
Union,
cast,
@@ -43,7 +38,7 @@ if TYPE_CHECKING:
else:
# Pydantic validates through typed dicts, but
# the callbacks need forward refs updated
Callbacks = Optional[Union[List, Any]]
Callbacks = Optional[Union[list, Any]]
class EmptyDict(TypedDict, total=False):


@@ -3,20 +3,16 @@ from __future__ import annotations
import enum
import threading
from abc import abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from collections.abc import Mapping as Mapping
from functools import wraps
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
List,
Optional,
Sequence,
Type,
Union,
cast,
)
from typing import Mapping as Mapping
from weakref import WeakValueDictionary
from pydantic import BaseModel, ConfigDict
@@ -176,10 +172,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
# If there's only one input, don't bother with the executor
if len(inputs) == 1:
return cast(List[Output], [invoke(prepared[0], inputs[0])])
return cast(list[Output], [invoke(prepared[0], inputs[0])])
with get_executor_for_config(configs[0]) as executor:
return cast(List[Output], list(executor.map(invoke, prepared, inputs)))
return cast(list[Output], list(executor.map(invoke, prepared, inputs)))
async def abatch(
self,
@@ -562,7 +558,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
for v in list(self.alternatives.keys()) + [self.default_key]
),
)
_enums_for_spec[self.which] = cast(Type[StrEnum], which_enum)
_enums_for_spec[self.which] = cast(type[StrEnum], which_enum)
return get_unique_config_specs(
# which alternative
[
@@ -694,7 +690,7 @@ def make_options_spec(
spec.name or spec.id,
((v, v) for v in list(spec.options.keys())),
)
_enums_for_spec[spec] = cast(Type[StrEnum], enum)
_enums_for_spec[spec] = cast(type[StrEnum], enum)
if isinstance(spec, ConfigurableFieldSingleOption):
return ConfigurableFieldSpec(
id=spec.id,


@@ -1,19 +1,13 @@
import asyncio
import inspect
import typing
from collections.abc import AsyncIterator, Iterator, Sequence
from contextvars import copy_context
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@@ -96,7 +90,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
"""The Runnable to run first."""
fallbacks: Sequence[Runnable[Input, Output]]
"""A sequence of fallbacks to try."""
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
exceptions_to_handle: tuple[type[BaseException], ...] = (Exception,)
"""The exceptions on which fallbacks should be tried.
Any exception that is not a subclass of these exceptions will be raised immediately.
@@ -112,25 +106,25 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
)
@property
def InputType(self) -> Type[Input]:
def InputType(self) -> type[Input]:
return self.runnable.InputType
@property
def OutputType(self) -> Type[Output]:
def OutputType(self) -> type[Output]:
return self.runnable.OutputType
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return self.runnable.get_input_schema(config)
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return self.runnable.get_output_schema(config)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec
for step in [self.runnable, *self.fallbacks]
@@ -142,7 +136,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@@ -252,12 +246,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
from langchain_core.callbacks.manager import CallbackManager
if self.exception_key is not None and not all(
@@ -296,9 +290,9 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
for cm, input, config in zip(callback_managers, inputs, configs)
]
to_return: Dict[int, Any] = {}
to_return: dict[int, Any] = {}
run_again = {i: input for i, input in enumerate(inputs)}
handled_exceptions: Dict[int, BaseException] = {}
handled_exceptions: dict[int, BaseException] = {}
first_to_raise = None
for runnable in self.runnables:
outputs = runnable.batch(
@@ -344,12 +338,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
from langchain_core.callbacks.manager import AsyncCallbackManager
if self.exception_key is not None and not all(
@@ -378,7 +372,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
for config in configs
]
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
None,
@@ -392,7 +386,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
to_return = {}
run_again = {i: input for i, input in enumerate(inputs)}
handled_exceptions: Dict[int, BaseException] = {}
handled_exceptions: dict[int, BaseException] = {}
first_to_raise = None
for runnable in self.runnables:
outputs = await runnable.abatch(


@@ -2,6 +2,7 @@ from __future__ import annotations
import inspect
from collections import Counter
from collections.abc import Sequence
from dataclasses import dataclass, field
from enum import Enum
from typing import (
@@ -11,7 +12,6 @@ from typing import (
NamedTuple,
Optional,
Protocol,
Sequence,
TypedDict,
Union,
overload,


@@ -3,7 +3,8 @@ Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py"""
import math
import os
from typing import Any, Mapping, Sequence
from collections.abc import Mapping, Sequence
from typing import Any
from langchain_core.runnables.graph import Edge as LangEdge


@@ -1,7 +1,7 @@
import base64
import re
from dataclasses import asdict
from typing import Dict, List, Optional
from typing import Optional
from langchain_core.runnables.graph import (
CurveStyle,
@@ -15,8 +15,8 @@ MARKDOWN_SPECIAL_CHARS = "*_`"
def draw_mermaid(
nodes: Dict[str, Node],
edges: List[Edge],
nodes: dict[str, Node],
edges: list[Edge],
*,
first_node: Optional[str] = None,
last_node: Optional[str] = None,
@@ -87,7 +87,7 @@ def draw_mermaid(
mermaid_graph += f"\t{node_label}\n"
# Group edges by their common prefixes
edge_groups: Dict[str, List[Edge]] = {}
edge_groups: dict[str, list[Edge]] = {}
for edge in edges:
src_parts = edge.source.split(":")
tgt_parts = edge.target.split(":")
@@ -98,7 +98,7 @@ def draw_mermaid(
seen_subgraphs = set()
def add_subgraph(edges: List[Edge], prefix: str) -> None:
def add_subgraph(edges: list[Edge], prefix: str) -> None:
nonlocal mermaid_graph
self_loop = len(edges) == 1 and edges[0].source == edges[0].target
if prefix and not self_loop:


@@ -1,13 +1,13 @@
from __future__ import annotations
import inspect
from collections.abc import Sequence
from types import GenericAlias
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Sequence,
Union,
)
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
from langchain_core.tracers.schemas import Run
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
@@ -419,7 +419,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
"""
root_type = self.OutputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type
return create_model_v2(


@@ -5,15 +5,11 @@ from __future__ import annotations
import asyncio
import inspect
import threading
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterator,
Mapping,
Optional,
Union,
cast,
@@ -349,7 +345,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
_graph_passthrough: RunnablePassthrough = RunnablePassthrough()
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
"""Runnable that assigns key-value pairs to Dict[str, Any] inputs.
The `RunnableAssign` class takes input dictionaries and, through a
@@ -564,7 +560,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
if filtered:
yield filtered
# yield map output
yield cast(Dict[str, Any], first_map_chunk_future.result())
yield cast(dict[str, Any], first_map_chunk_future.result())
for chunk in map_output:
yield chunk
@@ -650,7 +646,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
yield chunk
class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
"""Runnable that picks keys from Dict[str, Any] inputs.
RunnablePick class represents a Runnable that selectively picks keys from a


@@ -1,11 +1,7 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
@@ -98,7 +94,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
retryable_chain = chain.with_retry()
"""
retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,)
retry_exception_types: tuple[type[BaseException], ...] = (Exception,)
"""The exception types to retry on. By default all exceptions are retried.
In general you should only retry on exceptions that are likely to be
@@ -115,13 +111,13 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
"""The maximum number of attempts to retry the Runnable."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property
def _kwargs_retrying(self) -> Dict[str, Any]:
kwargs: Dict[str, Any] = dict()
def _kwargs_retrying(self) -> dict[str, Any]:
kwargs: dict[str, Any] = dict()
if self.max_attempt_number:
kwargs["stop"] = stop_after_attempt(self.max_attempt_number)
@@ -152,10 +148,10 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def _patch_config_list(
self,
config: List[RunnableConfig],
run_manager: List["T"],
config: list[RunnableConfig],
run_manager: list["T"],
retry_state: RetryCallState,
) -> List[RunnableConfig]:
) -> list[RunnableConfig]:
return [
self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
]
@@ -208,17 +204,17 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def _batch(
self,
inputs: List[Input],
run_manager: List["CallbackManagerForChainRun"],
config: List[RunnableConfig],
inputs: list[Input],
run_manager: list["CallbackManagerForChainRun"],
config: list[RunnableConfig],
**kwargs: Any,
) -> List[Union[Output, Exception]]:
results_map: Dict[int, Output] = {}
) -> list[Union[Output, Exception]]:
results_map: dict[int, Output] = {}
def pending(iterable: List[U]) -> List[U]:
def pending(iterable: list[U]) -> list[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map]
not_set: List[Output] = []
not_set: list[Output] = []
result = not_set
try:
for attempt in self._sync_retrying():
@@ -250,9 +246,9 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
attempt.retry_state.set_result(result)
except RetryError as e:
if result is not_set:
result = cast(List[Output], [e] * len(inputs))
result = cast(list[Output], [e] * len(inputs))
outputs: List[Union[Output, Exception]] = []
outputs: list[Union[Output, Exception]] = []
for idx, _ in enumerate(inputs):
if idx in results_map:
outputs.append(results_map[idx])
@@ -262,29 +258,29 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[Output]:
) -> list[Output]:
return self._batch_with_config(
self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
)
async def _abatch(
self,
inputs: List[Input],
run_manager: List["AsyncCallbackManagerForChainRun"],
config: List[RunnableConfig],
inputs: list[Input],
run_manager: list["AsyncCallbackManagerForChainRun"],
config: list[RunnableConfig],
**kwargs: Any,
) -> List[Union[Output, Exception]]:
results_map: Dict[int, Output] = {}
) -> list[Union[Output, Exception]]:
results_map: dict[int, Output] = {}
def pending(iterable: List[U]) -> List[U]:
def pending(iterable: list[U]) -> list[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map]
not_set: List[Output] = []
not_set: list[Output] = []
result = not_set
try:
async for attempt in self._async_retrying():
@@ -316,9 +312,9 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
attempt.retry_state.set_result(result)
except RetryError as e:
if result is not_set:
result = cast(List[Output], [e] * len(inputs))
result = cast(list[Output], [e] * len(inputs))
outputs: List[Union[Output, Exception]] = []
outputs: list[Union[Output, Exception]] = []
for idx, _ in enumerate(inputs):
if idx in results_map:
outputs.append(results_map[idx])
@@ -328,12 +324,12 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[Output]:
) -> list[Output]:
return await self._abatch_with_config(
self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
)


@@ -1,12 +1,9 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
List,
Mapping,
Optional,
Union,
cast,
@@ -154,7 +151,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
configs = get_config_list(config, len(inputs))
with get_executor_for_config(configs[0]) as executor:
return cast(
List[Output],
list[Output],
list(executor.map(invoke, runnables, actual_inputs, configs)),
)


@@ -2,7 +2,8 @@
from __future__ import annotations
from typing import Any, Literal, Sequence, Union
from collections.abc import Sequence
from typing import Any, Literal, Union
from typing_extensions import NotRequired, TypedDict


@@ -6,23 +6,24 @@ import ast
import asyncio
import inspect
import textwrap
from collections.abc import (
AsyncIterable,
AsyncIterator,
Awaitable,
Coroutine,
Iterable,
Mapping,
Sequence,
)
from functools import lru_cache
from inspect import signature
from itertools import groupby
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
Iterable,
Mapping,
NamedTuple,
Optional,
Protocol,
Sequence,
TypeVar,
Union,
)
@@ -430,7 +431,7 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
class AddableDict(Dict[str, Any]):
class AddableDict(dict[str, Any]):
"""
Dictionary that can be added to another dictionary.
"""