core: Put Python version as a project requirement so it is considered by ruff (#26608)

Ruff doesn't know about the python version in
`[tool.poetry.dependencies]`. It can get it from
`project.requires-python`.

Notes:
* poetry seems to have issues getting the python constraints from
`requires-python` and using `python` in per dependency constraints. So I
had to duplicate the info. I will open an issue on poetry.
* `inspect.isclass()` doesn't work correctly with `GenericAlias`
(`list[...]`, `dict[..., ...]`) on Python <3.11 so I added some `not
isinstance(type, GenericAlias)` checks:

Python 3.11
```pycon
>>> import inspect
>>> inspect.isclass(list)
True
>>> inspect.isclass(list[str])
False
```

Python 3.9
```pycon
>>> import inspect
>>> inspect.isclass(list)
True
>>> inspect.isclass(list[str])
True
```

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Christophe Bornet 2024-09-18 16:37:57 +02:00 committed by GitHub
parent 0f07cf61da
commit a47b332841
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
162 changed files with 920 additions and 1002 deletions

View File

@ -14,7 +14,8 @@ import contextlib
import functools import functools
import inspect import inspect
import warnings import warnings
from typing import Any, Callable, Generator, Type, TypeVar, Union, cast from collections.abc import Generator
from typing import Any, Callable, TypeVar, Union, cast
from langchain_core._api.internal import is_caller_internal from langchain_core._api.internal import is_caller_internal
@ -26,7 +27,7 @@ class LangChainBetaWarning(DeprecationWarning):
# PUBLIC API # PUBLIC API
T = TypeVar("T", bound=Union[Callable[..., Any], Type]) T = TypeVar("T", bound=Union[Callable[..., Any], type])
def beta( def beta(

View File

@ -14,11 +14,10 @@ import contextlib
import functools import functools
import inspect import inspect
import warnings import warnings
from collections.abc import Generator
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Generator,
Type,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -41,7 +40,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
# Last Any should be FieldInfoV1 but this leads to circular imports # Last Any should be FieldInfoV1 but this leads to circular imports
T = TypeVar("T", bound=Union[Type, Callable[..., Any], Any]) T = TypeVar("T", bound=Union[type, Callable[..., Any], Any])
def _validate_deprecation_params( def _validate_deprecation_params(
@ -262,7 +261,7 @@ def deprecated(
if not _obj_type: if not _obj_type:
_obj_type = "attribute" _obj_type = "attribute"
wrapped = None wrapped = None
_name = _name or cast(Union[Type, Callable], obj.fget).__qualname__ _name = _name or cast(Union[type, Callable], obj.fget).__qualname__
old_doc = obj.__doc__ old_doc = obj.__doc__
class _deprecated_property(property): class _deprecated_property(property):
@ -304,7 +303,7 @@ def deprecated(
) )
else: else:
_name = _name or cast(Union[Type, Callable], obj).__qualname__ _name = _name or cast(Union[type, Callable], obj).__qualname__
if not _obj_type: if not _obj_type:
# edge case: when a function is within another function # edge case: when a function is within another function
# within a test, this will call it a "method" not a "function" # within a test, this will call it a "method" not a "function"

View File

@ -25,7 +25,8 @@ The schemas for the agents themselves are defined in langchain.agents.agent.
from __future__ import annotations from __future__ import annotations
import json import json
from typing import Any, Literal, Sequence, Union from collections.abc import Sequence
from typing import Any, Literal, Union
from langchain_core.load.serializable import Serializable from langchain_core.load.serializable import Serializable
from langchain_core.messages import ( from langchain_core.messages import (

View File

@ -1,19 +1,13 @@
import asyncio import asyncio
import threading import threading
from collections import defaultdict from collections import defaultdict
from collections.abc import Awaitable, Mapping, Sequence
from functools import partial from functools import partial
from itertools import groupby from itertools import groupby
from typing import ( from typing import (
Any, Any,
Awaitable,
Callable, Callable,
DefaultDict,
Dict,
List,
Mapping,
Optional, Optional,
Sequence,
Type,
TypeVar, TypeVar,
Union, Union,
) )
@ -30,7 +24,7 @@ from langchain_core.runnables.config import RunnableConfig, ensure_config, patch
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
T = TypeVar("T") T = TypeVar("T")
Values = Dict[Union[asyncio.Event, threading.Event], Any] Values = dict[Union[asyncio.Event, threading.Event], Any]
CONTEXT_CONFIG_PREFIX = "__context__/" CONTEXT_CONFIG_PREFIX = "__context__/"
CONTEXT_CONFIG_SUFFIX_GET = "/get" CONTEXT_CONFIG_SUFFIX_GET = "/get"
CONTEXT_CONFIG_SUFFIX_SET = "/set" CONTEXT_CONFIG_SUFFIX_SET = "/set"
@ -70,10 +64,10 @@ def _key_from_id(id_: str) -> str:
def _config_with_context( def _config_with_context(
config: RunnableConfig, config: RunnableConfig,
steps: List[Runnable], steps: list[Runnable],
setter: Callable, setter: Callable,
getter: Callable, getter: Callable,
event_cls: Union[Type[threading.Event], Type[asyncio.Event]], event_cls: Union[type[threading.Event], type[asyncio.Event]],
) -> RunnableConfig: ) -> RunnableConfig:
if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})): if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
return config return config
@ -99,10 +93,10 @@ def _config_with_context(
} }
values: Values = {} values: Values = {}
events: DefaultDict[str, Union[asyncio.Event, threading.Event]] = defaultdict( events: defaultdict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
event_cls event_cls
) )
context_funcs: Dict[str, Callable[[], Any]] = {} context_funcs: dict[str, Callable[[], Any]] = {}
for key, group in grouped_by_key.items(): for key, group in grouped_by_key.items():
getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)] getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)] setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]
@ -129,7 +123,7 @@ def _config_with_context(
def aconfig_with_context( def aconfig_with_context(
config: RunnableConfig, config: RunnableConfig,
steps: List[Runnable], steps: list[Runnable],
) -> RunnableConfig: ) -> RunnableConfig:
"""Asynchronously patch a runnable config with context getters and setters. """Asynchronously patch a runnable config with context getters and setters.
@ -145,7 +139,7 @@ def aconfig_with_context(
def config_with_context( def config_with_context(
config: RunnableConfig, config: RunnableConfig,
steps: List[Runnable], steps: list[Runnable],
) -> RunnableConfig: ) -> RunnableConfig:
"""Patch a runnable config with context getters and setters. """Patch a runnable config with context getters and setters.
@ -165,13 +159,13 @@ class ContextGet(RunnableSerializable):
prefix: str = "" prefix: str = ""
key: Union[str, List[str]] key: Union[str, list[str]]
def __str__(self) -> str: def __str__(self) -> str:
return f"ContextGet({_print_keys(self.key)})" return f"ContextGet({_print_keys(self.key)})"
@property @property
def ids(self) -> List[str]: def ids(self) -> list[str]:
prefix = self.prefix + "/" if self.prefix else "" prefix = self.prefix + "/" if self.prefix else ""
keys = self.key if isinstance(self.key, list) else [self.key] keys = self.key if isinstance(self.key, list) else [self.key]
return [ return [
@ -180,7 +174,7 @@ class ContextGet(RunnableSerializable):
] ]
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
return super().config_specs + [ return super().config_specs + [
ConfigurableFieldSpec( ConfigurableFieldSpec(
id=id_, id=id_,
@ -256,7 +250,7 @@ class ContextSet(RunnableSerializable):
return f"ContextSet({_print_keys(list(self.keys.keys()))})" return f"ContextSet({_print_keys(list(self.keys.keys()))})"
@property @property
def ids(self) -> List[str]: def ids(self) -> list[str]:
prefix = self.prefix + "/" if self.prefix else "" prefix = self.prefix + "/" if self.prefix else ""
return [ return [
f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}" f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
@ -264,7 +258,7 @@ class ContextSet(RunnableSerializable):
] ]
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
mapper_config_specs = [ mapper_config_specs = [
s s
for mapper in self.keys.values() for mapper in self.keys.values()
@ -364,7 +358,7 @@ class Context:
return PrefixContext(prefix=scope) return PrefixContext(prefix=scope)
@staticmethod @staticmethod
def getter(key: Union[str, List[str]], /) -> ContextGet: def getter(key: Union[str, list[str]], /) -> ContextGet:
return ContextGet(key=key) return ContextGet(key=key)
@staticmethod @staticmethod
@ -385,7 +379,7 @@ class PrefixContext:
def __init__(self, prefix: str = ""): def __init__(self, prefix: str = ""):
self.prefix = prefix self.prefix = prefix
def getter(self, key: Union[str, List[str]], /) -> ContextGet: def getter(self, key: Union[str, list[str]], /) -> ContextGet:
return ContextGet(key=key, prefix=self.prefix) return ContextGet(key=key, prefix=self.prefix)
def setter( def setter(

View File

@ -23,7 +23,8 @@ Cache directly competes with Memory. See documentation for Pros and Cons.
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence from collections.abc import Sequence
from typing import Any, Optional
from langchain_core.outputs import Generation from langchain_core.outputs import Generation
from langchain_core.runnables import run_in_executor from langchain_core.runnables import run_in_executor

View File

@ -3,7 +3,8 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from uuid import UUID from uuid import UUID
from tenacity import RetryCallState from tenacity import RetryCallState
@ -1070,4 +1071,4 @@ class BaseCallbackManager(CallbackManagerMixin):
self.inheritable_metadata.pop(key) self.inheritable_metadata.pop(key)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]

View File

@ -5,19 +5,15 @@ import functools
import logging import logging
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Coroutine, Generator, Sequence
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager from contextlib import asynccontextmanager, contextmanager
from contextvars import copy_context from contextvars import copy_context
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncGenerator,
Callable, Callable,
Coroutine,
Generator,
Optional, Optional,
Sequence,
Type,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -2352,7 +2348,7 @@ def _configure(
and handler_class is not None and handler_class is not None
) )
if var.get() is not None or create_one: if var.get() is not None or create_one:
var_handler = var.get() or cast(Type[BaseCallbackHandler], handler_class)() var_handler = var.get() or cast(type[BaseCallbackHandler], handler_class)()
if handler_class is None: if handler_class is None:
if not any( if not any(
handler is var_handler # direct pointer comparison handler is var_handler # direct pointer comparison

View File

@ -18,7 +18,8 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Sequence, Union from collections.abc import Sequence
from typing import Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Iterator, List from collections.abc import Iterator
from langchain_core.chat_sessions import ChatSession from langchain_core.chat_sessions import ChatSession
@ -15,7 +15,7 @@ class BaseChatLoader(ABC):
An iterator of chat sessions. An iterator of chat sessions.
""" """
def load(self) -> List[ChatSession]: def load(self) -> list[ChatSession]:
"""Eagerly load the chat sessions into memory. """Eagerly load the chat sessions into memory.
Returns: Returns:

View File

@ -1,6 +1,7 @@
"""**Chat Sessions** are a collection of messages and function calls.""" """**Chat Sessions** are a collection of messages and function calls."""
from typing import Sequence, TypedDict from collections.abc import Sequence
from typing import TypedDict
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage

View File

@ -3,7 +3,8 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional from collections.abc import AsyncIterator, Iterator
from typing import TYPE_CHECKING, Optional
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.runnables import run_in_executor from langchain_core.runnables import run_in_executor

View File

@ -8,7 +8,7 @@ In addition, content loading code should provide a lazy loading interface by def
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Iterable from collections.abc import Iterable
# Re-export Blob and PathLike for backwards compatibility # Re-export Blob and PathLike for backwards compatibility
from langchain_core.documents.base import Blob as Blob from langchain_core.documents.base import Blob as Blob

View File

@ -1,7 +1,8 @@
import datetime import datetime
import json import json
import uuid import uuid
from typing import Any, Callable, Iterator, Optional, Sequence, Union from collections.abc import Iterator, Sequence
from typing import Any, Callable, Optional, Union
from langsmith import Client as LangSmithClient from langsmith import Client as LangSmithClient

View File

@ -2,9 +2,10 @@ from __future__ import annotations
import contextlib import contextlib
import mimetypes import mimetypes
from collections.abc import Generator
from io import BufferedReader, BytesIO from io import BufferedReader, BytesIO
from pathlib import PurePath from pathlib import PurePath
from typing import Any, Generator, Literal, Optional, Union, cast from typing import Any, Literal, Optional, Union, cast
from pydantic import ConfigDict, Field, field_validator, model_validator from pydantic import ConfigDict, Field, field_validator, model_validator

View File

@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional, Sequence from collections.abc import Sequence
from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.config import run_in_executor

View File

@ -1,7 +1,6 @@
"""**Embeddings** interface.""" """**Embeddings** interface."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List
from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.config import run_in_executor
@ -35,7 +34,7 @@ class Embeddings(ABC):
""" """
@abstractmethod @abstractmethod
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs. """Embed search docs.
Args: Args:
@ -46,7 +45,7 @@ class Embeddings(ABC):
""" """
@abstractmethod @abstractmethod
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
"""Embed query text. """Embed query text.
Args: Args:
@ -56,7 +55,7 @@ class Embeddings(ABC):
Embedding. Embedding.
""" """
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Asynchronous Embed search docs. """Asynchronous Embed search docs.
Args: Args:
@ -67,7 +66,7 @@ class Embeddings(ABC):
""" """
return await run_in_executor(None, self.embed_documents, texts) return await run_in_executor(None, self.embed_documents, texts)
async def aembed_query(self, text: str) -> List[float]: async def aembed_query(self, text: str) -> list[float]:
"""Asynchronous Embed query text. """Asynchronous Embed query text.
Args: Args:

View File

@ -2,7 +2,6 @@
# Please do not add additional fake embedding model implementations here. # Please do not add additional fake embedding model implementations here.
import hashlib import hashlib
from typing import List
from pydantic import BaseModel from pydantic import BaseModel
@ -51,15 +50,15 @@ class FakeEmbeddings(Embeddings, BaseModel):
size: int size: int
"""The size of the embedding vector.""" """The size of the embedding vector."""
def _get_embedding(self) -> List[float]: def _get_embedding(self) -> list[float]:
import numpy as np # type: ignore[import-not-found, import-untyped] import numpy as np # type: ignore[import-not-found, import-untyped]
return list(np.random.normal(size=self.size)) return list(np.random.normal(size=self.size))
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self._get_embedding() for _ in texts] return [self._get_embedding() for _ in texts]
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
return self._get_embedding() return self._get_embedding()
@ -106,7 +105,7 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
size: int size: int
"""The size of the embedding vector.""" """The size of the embedding vector."""
def _get_embedding(self, seed: int) -> List[float]: def _get_embedding(self, seed: int) -> list[float]:
import numpy as np # type: ignore[import-not-found, import-untyped] import numpy as np # type: ignore[import-not-found, import-untyped]
# set the seed for the random generator # set the seed for the random generator
@ -117,8 +116,8 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
"""Get a seed for the random generator, using the hash of the text.""" """Get a seed for the random generator, using the hash of the text."""
return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8 return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self._get_embedding(seed=self._get_seed(_)) for _ in texts] return [self._get_embedding(seed=self._get_seed(_)) for _ in texts]
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
return self._get_embedding(seed=self._get_seed(text)) return self._get_embedding(seed=self._get_seed(text))

View File

@ -1,7 +1,7 @@
"""Interface for selecting examples to include in prompts.""" """Interface for selecting examples to include in prompts."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any
from langchain_core.runnables import run_in_executor from langchain_core.runnables import run_in_executor
@ -10,14 +10,14 @@ class BaseExampleSelector(ABC):
"""Interface for selecting examples to include in prompts.""" """Interface for selecting examples to include in prompts."""
@abstractmethod @abstractmethod
def add_example(self, example: Dict[str, str]) -> Any: def add_example(self, example: dict[str, str]) -> Any:
"""Add new example to store. """Add new example to store.
Args: Args:
example: A dictionary with keys as input variables example: A dictionary with keys as input variables
and values as their values.""" and values as their values."""
async def aadd_example(self, example: Dict[str, str]) -> Any: async def aadd_example(self, example: dict[str, str]) -> Any:
"""Async add new example to store. """Async add new example to store.
Args: Args:
@ -27,14 +27,14 @@ class BaseExampleSelector(ABC):
return await run_in_executor(None, self.add_example, example) return await run_in_executor(None, self.add_example, example)
@abstractmethod @abstractmethod
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Select which examples to use based on the inputs. """Select which examples to use based on the inputs.
Args: Args:
input_variables: A dictionary with keys as input variables input_variables: A dictionary with keys as input variables
and values as their values.""" and values as their values."""
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]: async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Async select which examples to use based on the inputs. """Async select which examples to use based on the inputs.
Args: Args:

View File

@ -1,7 +1,7 @@
"""Select examples based on length.""" """Select examples based on length."""
import re import re
from typing import Callable, Dict, List from typing import Callable
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self from typing_extensions import Self
@ -17,7 +17,7 @@ def _get_length_based(text: str) -> int:
class LengthBasedExampleSelector(BaseExampleSelector, BaseModel): class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
"""Select examples based on length.""" """Select examples based on length."""
examples: List[dict] examples: list[dict]
"""A list of the examples that the prompt template expects.""" """A list of the examples that the prompt template expects."""
example_prompt: PromptTemplate example_prompt: PromptTemplate
@ -29,10 +29,10 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
max_length: int = 2048 max_length: int = 2048
"""Max length for the prompt, beyond which examples are cut.""" """Max length for the prompt, beyond which examples are cut."""
example_text_lengths: List[int] = Field(default_factory=list) # :meta private: example_text_lengths: list[int] = Field(default_factory=list) # :meta private:
"""Length of each example.""" """Length of each example."""
def add_example(self, example: Dict[str, str]) -> None: def add_example(self, example: dict[str, str]) -> None:
"""Add new example to list. """Add new example to list.
Args: Args:
@ -43,7 +43,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
string_example = self.example_prompt.format(**example) string_example = self.example_prompt.format(**example)
self.example_text_lengths.append(self.get_text_length(string_example)) self.example_text_lengths.append(self.get_text_length(string_example))
async def aadd_example(self, example: Dict[str, str]) -> None: async def aadd_example(self, example: dict[str, str]) -> None:
"""Async add new example to list. """Async add new example to list.
Args: Args:
@ -62,7 +62,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
self.example_text_lengths = [self.get_text_length(eg) for eg in string_examples] self.example_text_lengths = [self.get_text_length(eg) for eg in string_examples]
return self return self
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Select which examples to use based on the input lengths. """Select which examples to use based on the input lengths.
Args: Args:
@ -86,7 +86,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
i += 1 i += 1
return examples return examples
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]: async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Async select which examples to use based on the input lengths. """Async select which examples to use based on the input lengths.
Args: Args:

View File

@ -1,13 +1,10 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
from collections.abc import AsyncIterable, Collection, Iterable, Iterator
from typing import ( from typing import (
Any, Any,
AsyncIterable,
ClassVar, ClassVar,
Collection,
Iterable,
Iterator,
Optional, Optional,
) )

View File

@ -1,5 +1,6 @@
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable, List, Literal, Union from typing import Literal, Union
from langchain_core._api import beta from langchain_core._api import beta
from langchain_core.documents import Document from langchain_core.documents import Document
@ -41,7 +42,7 @@ METADATA_LINKS_KEY = "links"
@beta() @beta()
def get_links(doc: Document) -> List[Link]: def get_links(doc: Document) -> list[Link]:
"""Get the links from a document. """Get the links from a document.
Args: Args:

View File

@ -5,17 +5,13 @@ from __future__ import annotations
import hashlib import hashlib
import json import json
import uuid import uuid
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator, Sequence
from itertools import islice from itertools import islice
from typing import ( from typing import (
Any, Any,
AsyncIterable,
AsyncIterator,
Callable, Callable,
Iterable,
Iterator,
Literal, Literal,
Optional, Optional,
Sequence,
TypedDict, TypedDict,
TypeVar, TypeVar,
Union, Union,

View File

@ -3,7 +3,8 @@ from __future__ import annotations
import abc import abc
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence, TypedDict from collections.abc import Sequence
from typing import Any, Optional, TypedDict
from langchain_core._api import beta from langchain_core._api import beta
from langchain_core.documents import Document from langchain_core.documents import Document

View File

@ -1,5 +1,6 @@
import uuid import uuid
from typing import Any, Dict, List, Optional, Sequence, cast from collections.abc import Sequence
from typing import Any, Optional, cast
from pydantic import Field from pydantic import Field
@ -22,7 +23,7 @@ class InMemoryDocumentIndex(DocumentIndex):
.. versionadded:: 0.2.29 .. versionadded:: 0.2.29
""" """
store: Dict[str, Document] = Field(default_factory=dict) store: dict[str, Document] = Field(default_factory=dict)
top_k: int = 4 top_k: int = 4
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse: def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
@ -43,7 +44,7 @@ class InMemoryDocumentIndex(DocumentIndex):
return UpsertResponse(succeeded=ok_ids, failed=[]) return UpsertResponse(succeeded=ok_ids, failed=[])
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> DeleteResponse: def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
"""Delete by ID.""" """Delete by ID."""
if ids is None: if ids is None:
raise ValueError("IDs must be provided for deletion") raise ValueError("IDs must be provided for deletion")
@ -59,7 +60,7 @@ class InMemoryDocumentIndex(DocumentIndex):
succeeded=ok_ids, num_deleted=len(ok_ids), num_failed=0, failed=[] succeeded=ok_ids, num_deleted=len(ok_ids), num_failed=0, failed=[]
) )
def get(self, ids: Sequence[str], /, **kwargs: Any) -> List[Document]: def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
"""Get by ids.""" """Get by ids."""
found_documents = [] found_documents = []
@ -71,7 +72,7 @@ class InMemoryDocumentIndex(DocumentIndex):
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]: ) -> list[Document]:
counts_by_doc = [] counts_by_doc = []
for document in self.store.values(): for document in self.store.values():

View File

@ -1,16 +1,14 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import lru_cache from collections.abc import Mapping, Sequence
from functools import cache
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable, Callable,
List,
Literal, Literal,
Mapping,
Optional, Optional,
Sequence,
TypeVar, TypeVar,
Union, Union,
) )
@ -52,7 +50,7 @@ class LangSmithParams(TypedDict, total=False):
"""Stop words for generation.""" """Stop words for generation."""
@lru_cache(maxsize=None) # Cache the tokenizer @cache # Cache the tokenizer
def get_tokenizer() -> Any: def get_tokenizer() -> Any:
"""Get a GPT-2 tokenizer instance. """Get a GPT-2 tokenizer instance.
@ -158,7 +156,7 @@ class BaseLanguageModel(
return Union[ return Union[
str, str,
Union[StringPromptValue, ChatPromptValueConcrete], Union[StringPromptValue, ChatPromptValueConcrete],
List[AnyMessage], list[AnyMessage],
] ]
@abstractmethod @abstractmethod

View File

@ -3,21 +3,19 @@ from __future__ import annotations
import asyncio import asyncio
import inspect import inspect
import json import json
import typing
import uuid import uuid
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from functools import cached_property from functools import cached_property
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncIterator,
Callable, Callable,
Dict,
Iterator,
Literal, Literal,
Optional, Optional,
Sequence,
Union, Union,
cast, cast,
) )
@ -1121,18 +1119,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006 tools: Sequence[Union[typing.Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError() raise NotImplementedError()
def with_structured_output( def with_structured_output(
self, self,
schema: Union[Dict, type], # noqa: UP006 schema: Union[typing.Dict, type], # noqa: UP006
*, *,
include_raw: bool = False, include_raw: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: # noqa: UP006 ) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]: # noqa: UP006
"""Model wrapper that returns outputs formatted to match the given schema. """Model wrapper that returns outputs formatted to match the given schema.
Args: Args:

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import time import time
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional from collections.abc import AsyncIterator, Iterator, Mapping
from typing import Any, Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -14,7 +15,7 @@ from langchain_core.runnables import RunnableConfig
class FakeListLLM(LLM): class FakeListLLM(LLM):
"""Fake LLM for testing purposes.""" """Fake LLM for testing purposes."""
responses: List[str] responses: list[str]
"""List of responses to return in order.""" """List of responses to return in order."""
# This parameter should be removed from FakeListLLM since # This parameter should be removed from FakeListLLM since
# it's only used by sub-classes. # it's only used by sub-classes.
@ -37,7 +38,7 @@ class FakeListLLM(LLM):
def _call( def _call(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
@ -52,7 +53,7 @@ class FakeListLLM(LLM):
async def _acall( async def _acall(
self, self,
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
@ -90,7 +91,7 @@ class FakeStreamingListLLM(FakeListLLM):
input: LanguageModelInput, input: LanguageModelInput,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
*, *,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[str]: ) -> Iterator[str]:
result = self.invoke(input, config) result = self.invoke(input, config)
@ -110,7 +111,7 @@ class FakeStreamingListLLM(FakeListLLM):
input: LanguageModelInput, input: LanguageModelInput,
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
*, *,
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
result = await self.ainvoke(input, config) result = await self.ainvoke(input, config)

View File

@ -3,7 +3,8 @@
import asyncio import asyncio
import re import re
import time import time
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast from collections.abc import AsyncIterator, Iterator
from typing import Any, Optional, Union, cast
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -17,7 +18,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
class FakeMessagesListChatModel(BaseChatModel): class FakeMessagesListChatModel(BaseChatModel):
"""Fake ChatModel for testing purposes.""" """Fake ChatModel for testing purposes."""
responses: List[BaseMessage] responses: list[BaseMessage]
"""List of responses to **cycle** through in order.""" """List of responses to **cycle** through in order."""
sleep: Optional[float] = None sleep: Optional[float] = None
"""Sleep time in seconds between responses.""" """Sleep time in seconds between responses."""
@ -26,8 +27,8 @@ class FakeMessagesListChatModel(BaseChatModel):
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
@ -51,7 +52,7 @@ class FakeListChatModelError(Exception):
class FakeListChatModel(SimpleChatModel): class FakeListChatModel(SimpleChatModel):
"""Fake ChatModel for testing purposes.""" """Fake ChatModel for testing purposes."""
responses: List[str] responses: list[str]
"""List of responses to **cycle** through in order.""" """List of responses to **cycle** through in order."""
sleep: Optional[float] = None sleep: Optional[float] = None
i: int = 0 i: int = 0
@ -65,8 +66,8 @@ class FakeListChatModel(SimpleChatModel):
def _call( def _call(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
@ -80,8 +81,8 @@ class FakeListChatModel(SimpleChatModel):
def _stream( def _stream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Union[List[str], None] = None, stop: Union[list[str], None] = None,
run_manager: Union[CallbackManagerForLLMRun, None] = None, run_manager: Union[CallbackManagerForLLMRun, None] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
@ -103,8 +104,8 @@ class FakeListChatModel(SimpleChatModel):
async def _astream( async def _astream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Union[List[str], None] = None, stop: Union[list[str], None] = None,
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None, run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]: ) -> AsyncIterator[ChatGenerationChunk]:
@ -124,7 +125,7 @@ class FakeListChatModel(SimpleChatModel):
yield ChatGenerationChunk(message=AIMessageChunk(content=c)) yield ChatGenerationChunk(message=AIMessageChunk(content=c))
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> dict[str, Any]:
return {"responses": self.responses} return {"responses": self.responses}
@ -133,8 +134,8 @@ class FakeChatModel(SimpleChatModel):
def _call( def _call(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
@ -142,8 +143,8 @@ class FakeChatModel(SimpleChatModel):
async def _agenerate( async def _agenerate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
@ -157,7 +158,7 @@ class FakeChatModel(SimpleChatModel):
return "fake-chat-model" return "fake-chat-model"
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> dict[str, Any]:
return {"key": "fake"} return {"key": "fake"}
@ -186,8 +187,8 @@ class GenericFakeChatModel(BaseChatModel):
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
@ -202,8 +203,8 @@ class GenericFakeChatModel(BaseChatModel):
def _stream( def _stream(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[ChatGenerationChunk]: ) -> Iterator[ChatGenerationChunk]:
@ -231,7 +232,7 @@ class GenericFakeChatModel(BaseChatModel):
# Use a regular expression to split on whitespace with a capture group # Use a regular expression to split on whitespace with a capture group
# so that we can preserve the whitespace in the output. # so that we can preserve the whitespace in the output.
assert isinstance(content, str) assert isinstance(content, str)
content_chunks = cast(List[str], re.split(r"(\s)", content)) content_chunks = cast(list[str], re.split(r"(\s)", content))
for token in content_chunks: for token in content_chunks:
chunk = ChatGenerationChunk( chunk = ChatGenerationChunk(
@ -249,7 +250,7 @@ class GenericFakeChatModel(BaseChatModel):
for fkey, fvalue in value.items(): for fkey, fvalue in value.items():
if isinstance(fvalue, str): if isinstance(fvalue, str):
# Break function call by `,` # Break function call by `,`
fvalue_chunks = cast(List[str], re.split(r"(,)", fvalue)) fvalue_chunks = cast(list[str], re.split(r"(,)", fvalue))
for fvalue_chunk in fvalue_chunks: for fvalue_chunk in fvalue_chunks:
chunk = ChatGenerationChunk( chunk = ChatGenerationChunk(
message=AIMessageChunk( message=AIMessageChunk(
@ -306,8 +307,8 @@ class ParrotFakeChatModel(BaseChatModel):
def _generate( def _generate(
self, self,
messages: List[BaseMessage], messages: list[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:

View File

@ -10,16 +10,12 @@ import logging
import uuid import uuid
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Callable, Callable,
Dict,
Iterator,
List,
Optional, Optional,
Sequence,
Union, Union,
cast, cast,
) )
@ -448,7 +444,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return [g[0].text for g in llm_result.generations] return [g[0].text for g in llm_result.generations]
except Exception as e: except Exception as e:
if return_exceptions: if return_exceptions:
return cast(List[str], [e for _ in inputs]) return cast(list[str], [e for _ in inputs])
else: else:
raise e raise e
else: else:
@ -494,7 +490,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return [g[0].text for g in llm_result.generations] return [g[0].text for g in llm_result.generations]
except Exception as e: except Exception as e:
if return_exceptions: if return_exceptions:
return cast(List[str], [e for _ in inputs]) return cast(list[str], [e for _ in inputs])
else: else:
raise e raise e
else: else:
@ -883,13 +879,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
assert run_name is None or ( assert run_name is None or (
isinstance(run_name, list) and len(run_name) == len(prompts) isinstance(run_name, list) and len(run_name) == len(prompts)
) )
callbacks = cast(List[Callbacks], callbacks) callbacks = cast(list[Callbacks], callbacks)
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts))) tags_list = cast(list[Optional[list[str]]], tags or ([None] * len(prompts)))
metadata_list = cast( metadata_list = cast(
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts)) list[Optional[dict[str, Any]]], metadata or ([{}] * len(prompts))
) )
run_name_list = run_name or cast( run_name_list = run_name or cast(
List[Optional[str]], ([None] * len(prompts)) list[Optional[str]], ([None] * len(prompts))
) )
callback_managers = [ callback_managers = [
CallbackManager.configure( CallbackManager.configure(
@ -910,9 +906,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
cast(Callbacks, callbacks), cast(Callbacks, callbacks),
self.callbacks, self.callbacks,
self.verbose, self.verbose,
cast(List[str], tags), cast(list[str], tags),
self.tags, self.tags,
cast(Dict[str, Any], metadata), cast(dict[str, Any], metadata),
self.metadata, self.metadata,
) )
] * len(prompts) ] * len(prompts)
@ -1116,13 +1112,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
assert run_name is None or ( assert run_name is None or (
isinstance(run_name, list) and len(run_name) == len(prompts) isinstance(run_name, list) and len(run_name) == len(prompts)
) )
callbacks = cast(List[Callbacks], callbacks) callbacks = cast(list[Callbacks], callbacks)
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts))) tags_list = cast(list[Optional[list[str]]], tags or ([None] * len(prompts)))
metadata_list = cast( metadata_list = cast(
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts)) list[Optional[dict[str, Any]]], metadata or ([{}] * len(prompts))
) )
run_name_list = run_name or cast( run_name_list = run_name or cast(
List[Optional[str]], ([None] * len(prompts)) list[Optional[str]], ([None] * len(prompts))
) )
callback_managers = [ callback_managers = [
AsyncCallbackManager.configure( AsyncCallbackManager.configure(
@ -1143,9 +1139,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
cast(Callbacks, callbacks), cast(Callbacks, callbacks),
self.callbacks, self.callbacks,
self.verbose, self.verbose,
cast(List[str], tags), cast(list[str], tags),
self.tags, self.tags,
cast(Dict[str, Any], metadata), cast(dict[str, Any], metadata),
self.metadata, self.metadata,
) )
] * len(prompts) ] * len(prompts)

View File

@ -1,7 +1,7 @@
import importlib import importlib
import json import json
import os import os
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Optional
from langchain_core._api import beta from langchain_core._api import beta
from langchain_core.load.mapping import ( from langchain_core.load.mapping import (
@ -34,11 +34,11 @@ class Reviver:
def __init__( def __init__(
self, self,
secrets_map: Optional[Dict[str, str]] = None, secrets_map: Optional[dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None, valid_namespaces: Optional[list[str]] = None,
secrets_from_env: bool = True, secrets_from_env: bool = True,
additional_import_mappings: Optional[ additional_import_mappings: Optional[
Dict[Tuple[str, ...], Tuple[str, ...]] dict[tuple[str, ...], tuple[str, ...]]
] = None, ] = None,
) -> None: ) -> None:
"""Initialize the reviver. """Initialize the reviver.
@ -73,7 +73,7 @@ class Reviver:
else ALL_SERIALIZABLE_MAPPINGS else ALL_SERIALIZABLE_MAPPINGS
) )
def __call__(self, value: Dict[str, Any]) -> Any: def __call__(self, value: dict[str, Any]) -> Any:
if ( if (
value.get("lc", None) == 1 value.get("lc", None) == 1
and value.get("type", None) == "secret" and value.get("type", None) == "secret"
@ -154,10 +154,10 @@ class Reviver:
def loads( def loads(
text: str, text: str,
*, *,
secrets_map: Optional[Dict[str, str]] = None, secrets_map: Optional[dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None, valid_namespaces: Optional[list[str]] = None,
secrets_from_env: bool = True, secrets_from_env: bool = True,
additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None, additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None,
) -> Any: ) -> Any:
"""Revive a LangChain class from a JSON string. """Revive a LangChain class from a JSON string.
Equivalent to `load(json.loads(text))`. Equivalent to `load(json.loads(text))`.
@ -190,10 +190,10 @@ def loads(
def load( def load(
obj: Any, obj: Any,
*, *,
secrets_map: Optional[Dict[str, str]] = None, secrets_map: Optional[dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None, valid_namespaces: Optional[list[str]] = None,
secrets_from_env: bool = True, secrets_from_env: bool = True,
additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None, additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None,
) -> Any: ) -> Any:
"""Revive a LangChain class from a JSON object. Use this if you already """Revive a LangChain class from a JSON object. Use this if you already
have a parsed JSON object, eg. from `json.load` or `orjson.loads`. have a parsed JSON object, eg. from `json.load` or `orjson.loads`.

View File

@ -18,11 +18,9 @@ The mapping allows us to deserialize an AIMessage created with an older
version of LangChain where the code was in a different location. version of LangChain where the code was in a different location.
""" """
from typing import Dict, Tuple
# First value is the value that it is serialized as # First value is the value that it is serialized as
# Second value is the path to load it from # Second value is the path to load it from
SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
("langchain", "schema", "messages", "AIMessage"): ( ("langchain", "schema", "messages", "AIMessage"): (
"langchain_core", "langchain_core",
"messages", "messages",
@ -535,7 +533,7 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
# Needed for backwards compatibility for old versions of LangChain where things # Needed for backwards compatibility for old versions of LangChain where things
# Were in different place # Were in different place
_OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { _OG_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
("langchain", "schema", "AIMessage"): ( ("langchain", "schema", "AIMessage"): (
"langchain_core", "langchain_core",
"messages", "messages",
@ -583,7 +581,7 @@ _OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
# Needed for backwards compatibility for a few versions where we serialized # Needed for backwards compatibility for a few versions where we serialized
# with langchain_core paths. # with langchain_core paths.
OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { OLD_CORE_NAMESPACES_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
("langchain_core", "messages", "ai", "AIMessage"): ( ("langchain_core", "messages", "ai", "AIMessage"): (
"langchain_core", "langchain_core",
"messages", "messages",
@ -937,7 +935,7 @@ OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
), ),
} }
_JS_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = { _JS_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
("langchain_core", "messages", "AIMessage"): ( ("langchain_core", "messages", "AIMessage"): (
"langchain_core", "langchain_core",
"messages", "messages",

View File

@ -1,8 +1,6 @@
from abc import ABC from abc import ABC
from typing import ( from typing import (
Any, Any,
Dict,
List,
Literal, Literal,
Optional, Optional,
TypedDict, TypedDict,
@ -25,9 +23,9 @@ class BaseSerialized(TypedDict):
""" """
lc: int lc: int
id: List[str] id: list[str]
name: NotRequired[str] name: NotRequired[str]
graph: NotRequired[Dict[str, Any]] graph: NotRequired[dict[str, Any]]
class SerializedConstructor(BaseSerialized): class SerializedConstructor(BaseSerialized):
@ -39,7 +37,7 @@ class SerializedConstructor(BaseSerialized):
""" """
type: Literal["constructor"] type: Literal["constructor"]
kwargs: Dict[str, Any] kwargs: dict[str, Any]
class SerializedSecret(BaseSerialized): class SerializedSecret(BaseSerialized):
@ -125,7 +123,7 @@ class Serializable(BaseModel, ABC):
return False return False
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
For example, if the class is `langchain.llms.openai.OpenAI`, then the For example, if the class is `langchain.llms.openai.OpenAI`, then the
@ -134,7 +132,7 @@ class Serializable(BaseModel, ABC):
return cls.__module__.split(".") return cls.__module__.split(".")
@property @property
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> dict[str, str]:
"""A map of constructor argument names to secret ids. """A map of constructor argument names to secret ids.
For example, For example,
@ -143,7 +141,7 @@ class Serializable(BaseModel, ABC):
return dict() return dict()
@property @property
def lc_attributes(self) -> Dict: def lc_attributes(self) -> dict:
"""List of attribute names that should be included in the serialized kwargs. """List of attribute names that should be included in the serialized kwargs.
These attributes must be accepted by the constructor. These attributes must be accepted by the constructor.
@ -152,7 +150,7 @@ class Serializable(BaseModel, ABC):
return {} return {}
@classmethod @classmethod
def lc_id(cls) -> List[str]: def lc_id(cls) -> list[str]:
"""A unique identifier for this class for serialization purposes. """A unique identifier for this class for serialization purposes.
The unique identifier is a list of strings that describes the path The unique identifier is a list of strings that describes the path
@ -315,8 +313,8 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
def _replace_secrets( def _replace_secrets(
root: Dict[Any, Any], secrets_map: Dict[str, str] root: dict[Any, Any], secrets_map: dict[str, str]
) -> Dict[Any, Any]: ) -> dict[Any, Any]:
result = root.copy() result = root.copy()
for path, secret_id in secrets_map.items(): for path, secret_id in secrets_map.items():
[*parts, last] = path.split(".") [*parts, last] = path.split(".")
@ -344,7 +342,7 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
Returns: Returns:
SerializedNotImplemented SerializedNotImplemented
""" """
_id: List[str] = [] _id: list[str] = []
try: try:
if hasattr(obj, "__name__"): if hasattr(obj, "__name__"):
_id = [*obj.__module__.split("."), obj.__name__] _id = [*obj.__module__.split("."), obj.__name__]

View File

@ -1,5 +1,5 @@
import json import json
from typing import Any, Dict, List, Literal, Optional, Union from typing import Any, Literal, Optional, Union
from pydantic import model_validator from pydantic import model_validator
from typing_extensions import Self, TypedDict from typing_extensions import Self, TypedDict
@ -69,9 +69,9 @@ class AIMessage(BaseMessage):
At the moment, this is ignored by most models. Usage is discouraged. At the moment, this is ignored by most models. Usage is discouraged.
""" """
tool_calls: List[ToolCall] = [] tool_calls: list[ToolCall] = []
"""If provided, tool calls associated with the message.""" """If provided, tool calls associated with the message."""
invalid_tool_calls: List[InvalidToolCall] = [] invalid_tool_calls: list[InvalidToolCall] = []
"""If provided, tool calls with parsing errors associated with the message.""" """If provided, tool calls with parsing errors associated with the message."""
usage_metadata: Optional[UsageMetadata] = None usage_metadata: Optional[UsageMetadata] = None
"""If provided, usage metadata for a message, such as token counts. """If provided, usage metadata for a message, such as token counts.
@ -83,7 +83,7 @@ class AIMessage(BaseMessage):
"""The type of the message (used for deserialization). Defaults to "ai".""" """The type of the message (used for deserialization). Defaults to "ai"."""
def __init__( def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None: ) -> None:
"""Pass in content as positional arg. """Pass in content as positional arg.
@ -94,7 +94,7 @@ class AIMessage(BaseMessage):
super().__init__(content=content, **kwargs) super().__init__(content=content, **kwargs)
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Returns: Returns:
@ -104,7 +104,7 @@ class AIMessage(BaseMessage):
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
@property @property
def lc_attributes(self) -> Dict: def lc_attributes(self) -> dict:
"""Attrs to be serialized even if they are derived from other init args.""" """Attrs to be serialized even if they are derived from other init args."""
return { return {
"tool_calls": self.tool_calls, "tool_calls": self.tool_calls,
@ -137,7 +137,7 @@ class AIMessage(BaseMessage):
# Ensure "type" is properly set on all tool call-like dicts. # Ensure "type" is properly set on all tool call-like dicts.
if tool_calls := values.get("tool_calls"): if tool_calls := values.get("tool_calls"):
updated: List = [] updated: list = []
for tc in tool_calls: for tc in tool_calls:
updated.append( updated.append(
create_tool_call(**{k: v for k, v in tc.items() if k != "type"}) create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
@ -178,7 +178,7 @@ class AIMessage(BaseMessage):
base = super().pretty_repr(html=html) base = super().pretty_repr(html=html)
lines = [] lines = []
def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> List[str]: def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> list[str]:
lines = [ lines = [
f" {tc.get('name', 'Tool')} ({tc.get('id')})", f" {tc.get('name', 'Tool')} ({tc.get('id')})",
f" Call ID: {tc.get('id')}", f" Call ID: {tc.get('id')}",
@ -218,11 +218,11 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
"""The type of the message (used for deserialization). """The type of the message (used for deserialization).
Defaults to "AIMessageChunk".""" Defaults to "AIMessageChunk"."""
tool_call_chunks: List[ToolCallChunk] = [] tool_call_chunks: list[ToolCallChunk] = []
"""If provided, tool call chunks associated with the message.""" """If provided, tool call chunks associated with the message."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Returns: Returns:
@ -232,7 +232,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
@property @property
def lc_attributes(self) -> Dict: def lc_attributes(self) -> dict:
"""Attrs to be serialized even if they are derived from other init args.""" """Attrs to be serialized even if they are derived from other init args."""
return { return {
"tool_calls": self.tool_calls, "tool_calls": self.tool_calls,

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from pydantic import ConfigDict, Field, field_validator from pydantic import ConfigDict, Field, field_validator
@ -143,7 +144,7 @@ def merge_content(
merged = [merged] + content # type: ignore merged = [merged] + content # type: ignore
elif isinstance(content, list): elif isinstance(content, list):
# If both are lists # If both are lists
merged = merge_lists(cast(List, merged), content) # type: ignore merged = merge_lists(cast(list, merged), content) # type: ignore
# If the first content is a list, and the second content is a string # If the first content is a list, and the second content is a string
else: else:
# If the last element of the first content is a string # If the last element of the first content is a string

View File

@ -1,4 +1,4 @@
from typing import Any, List, Literal from typing import Any, Literal
from langchain_core.messages.base import ( from langchain_core.messages.base import (
BaseMessage, BaseMessage,
@ -18,7 +18,7 @@ class ChatMessage(BaseMessage):
"""The type of the message (used during serialization). Defaults to "chat".""" """The type of the message (used during serialization). Defaults to "chat"."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]. Default is ["langchain", "schema", "messages"].
""" """
@ -39,7 +39,7 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
Defaults to "ChatMessageChunk".""" Defaults to "ChatMessageChunk"."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]. Default is ["langchain", "schema", "messages"].
""" """

View File

@ -1,4 +1,4 @@
from typing import Any, List, Literal from typing import Any, Literal
from langchain_core.messages.base import ( from langchain_core.messages.base import (
BaseMessage, BaseMessage,
@ -26,7 +26,7 @@ class FunctionMessage(BaseMessage):
"""The type of the message (used for serialization). Defaults to "function".""" """The type of the message (used for serialization). Defaults to "function"."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
@ -46,7 +46,7 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
Defaults to "FunctionMessageChunk".""" Defaults to "FunctionMessageChunk"."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Literal, Union from typing import Any, Literal, Union
from langchain_core.messages.base import BaseMessage, BaseMessageChunk from langchain_core.messages.base import BaseMessage, BaseMessageChunk
@ -39,13 +39,13 @@ class HumanMessage(BaseMessage):
"""The type of the message (used for serialization). Defaults to "human".""" """The type of the message (used for serialization). Defaults to "human"."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
def __init__( def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None: ) -> None:
"""Pass in content as positional arg. """Pass in content as positional arg.
@ -70,7 +70,7 @@ class HumanMessageChunk(HumanMessage, BaseMessageChunk):
Defaults to "HumanMessageChunk".""" Defaults to "HumanMessageChunk"."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]

View File

@ -1,4 +1,4 @@
from typing import Any, List, Literal from typing import Any, Literal
from langchain_core.messages.base import BaseMessage from langchain_core.messages.base import BaseMessage
@ -25,7 +25,7 @@ class RemoveMessage(BaseMessage):
return super().__init__("", id=id, **kwargs) return super().__init__("", id=id, **kwargs)
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Literal, Union from typing import Any, Literal, Union
from langchain_core.messages.base import BaseMessage, BaseMessageChunk from langchain_core.messages.base import BaseMessage, BaseMessageChunk
@ -33,13 +33,13 @@ class SystemMessage(BaseMessage):
"""The type of the message (used for serialization). Defaults to "system".""" """The type of the message (used for serialization). Defaults to "system"."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
def __init__( def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None: ) -> None:
"""Pass in content as positional arg. """Pass in content as positional arg.
@ -64,7 +64,7 @@ class SystemMessageChunk(SystemMessage, BaseMessageChunk):
Defaults to "SystemMessageChunk".""" Defaults to "SystemMessageChunk"."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]

View File

@ -1,5 +1,5 @@
import json import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union from typing import Any, Literal, Optional, Union
from uuid import UUID from uuid import UUID
from pydantic import Field, model_validator from pydantic import Field, model_validator
@ -78,7 +78,7 @@ class ToolMessage(BaseMessage):
"""Currently inherited from BaseMessage, but not used.""" """Currently inherited from BaseMessage, but not used."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
@ -123,7 +123,7 @@ class ToolMessage(BaseMessage):
return values return values
def __init__( def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None: ) -> None:
super().__init__(content=content, **kwargs) super().__init__(content=content, **kwargs)
@ -140,7 +140,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment] type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
@ -187,7 +187,7 @@ class ToolCall(TypedDict):
name: str name: str
"""The name of the tool to be called.""" """The name of the tool to be called."""
args: Dict[str, Any] args: dict[str, Any]
"""The arguments to the tool call.""" """The arguments to the tool call."""
id: Optional[str] id: Optional[str]
"""An identifier associated with the tool call. """An identifier associated with the tool call.
@ -198,7 +198,7 @@ class ToolCall(TypedDict):
type: NotRequired[Literal["tool_call"]] type: NotRequired[Literal["tool_call"]]
def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall: def tool_call(*, name: str, args: dict[str, Any], id: Optional[str]) -> ToolCall:
return ToolCall(name=name, args=args, id=id, type="tool_call") return ToolCall(name=name, args=args, id=id, type="tool_call")
@ -276,8 +276,8 @@ def invalid_tool_call(
def default_tool_parser( def default_tool_parser(
raw_tool_calls: List[dict], raw_tool_calls: list[dict],
) -> Tuple[List[ToolCall], List[InvalidToolCall]]: ) -> tuple[list[ToolCall], list[InvalidToolCall]]:
"""Best-effort parsing of tools.""" """Best-effort parsing of tools."""
tool_calls = [] tool_calls = []
invalid_tool_calls = [] invalid_tool_calls = []
@ -306,7 +306,7 @@ def default_tool_parser(
return tool_calls, invalid_tool_calls return tool_calls, invalid_tool_calls
def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]: def default_tool_chunk_parser(raw_tool_calls: list[dict]) -> list[ToolCallChunk]:
"""Best-effort parsing of tool chunks.""" """Best-effort parsing of tool chunks."""
tool_call_chunks = [] tool_call_chunks = []
for tool_call in raw_tool_calls: for tool_call in raw_tool_calls:

View File

@ -11,25 +11,21 @@ from __future__ import annotations
import inspect import inspect
import json import json
from collections.abc import Iterable, Sequence
from functools import partial from functools import partial
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Annotated,
Any, Any,
Callable, Callable,
Dict,
Iterable,
List,
Literal, Literal,
Optional, Optional,
Sequence,
Tuple,
Union, Union,
cast, cast,
overload, overload,
) )
from pydantic import Discriminator, Field, Tag from pydantic import Discriminator, Field, Tag
from typing_extensions import Annotated
from langchain_core.messages.ai import AIMessage, AIMessageChunk from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.base import BaseMessage, BaseMessageChunk from langchain_core.messages.base import BaseMessage, BaseMessageChunk
@ -198,7 +194,7 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
MessageLikeRepresentation = Union[ MessageLikeRepresentation = Union[
BaseMessage, List[str], Tuple[str, str], str, Dict[str, Any] BaseMessage, list[str], tuple[str, str], str, dict[str, Any]
] ]

View File

@ -2,12 +2,11 @@ from __future__ import annotations
import json import json
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Optional, TypeVar, Union from typing import Annotated, Any, Optional, TypeVar, Union
import jsonpatch # type: ignore[import] import jsonpatch # type: ignore[import]
import pydantic import pydantic
from pydantic import SkipValidation from pydantic import SkipValidation
from typing_extensions import Annotated
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS

View File

@ -3,8 +3,9 @@ from __future__ import annotations
import re import re
from abc import abstractmethod from abc import abstractmethod
from collections import deque from collections import deque
from typing import AsyncIterator, Iterator, List, TypeVar, Union from collections.abc import AsyncIterator, Iterator
from typing import Optional as Optional from typing import Optional as Optional
from typing import TypeVar, Union
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser from langchain_core.output_parsers.transform import BaseTransformOutputParser
@ -29,7 +30,7 @@ def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
yield buffer.popleft() yield buffer.popleft()
class ListOutputParser(BaseTransformOutputParser[List[str]]): class ListOutputParser(BaseTransformOutputParser[list[str]]):
"""Parse the output of an LLM call to a list.""" """Parse the output of an LLM call to a list."""
@property @property

View File

@ -1,6 +1,7 @@
import copy import copy
import json import json
from typing import Any, Dict, List, Optional, Type, Union from types import GenericAlias
from typing import Any, Optional, Union
import jsonpatch # type: ignore[import] import jsonpatch # type: ignore[import]
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
@ -20,7 +21,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
args_only: bool = True args_only: bool = True
"""Whether to only return the arguments to the function call.""" """Whether to only return the arguments to the function call."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object. """Parse the result of an LLM call to a JSON object.
Args: Args:
@ -72,7 +73,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
def _diff(self, prev: Optional[Any], next: Any) -> Any: def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch return jsonpatch.make_patch(prev, next).patch
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object. """Parse the result of an LLM call to a JSON object.
Args: Args:
@ -166,7 +167,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str key_name: str
"""The name of the key to return.""" """The name of the key to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object. """Parse the result of an LLM call to a JSON object.
Args: Args:
@ -223,7 +224,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
result = parser.parse_result([chat_generation]) result = parser.parse_result([chat_generation])
""" """
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]] pydantic_schema: Union[type[BaseModel], dict[str, type[BaseModel]]]
"""The pydantic schema to parse the output with. """The pydantic schema to parse the output with.
If multiple schemas are provided, then the function name will be used to If multiple schemas are provided, then the function name will be used to
@ -232,7 +233,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_schema(cls, values: Dict) -> Any: def validate_schema(cls, values: dict) -> Any:
"""Validate the pydantic schema. """Validate the pydantic schema.
Args: Args:
@ -246,17 +247,19 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
""" """
schema = values["pydantic_schema"] schema = values["pydantic_schema"]
if "args_only" not in values: if "args_only" not in values:
values["args_only"] = isinstance(schema, type) and issubclass( values["args_only"] = (
schema, BaseModel isinstance(schema, type)
and not isinstance(schema, GenericAlias)
and issubclass(schema, BaseModel)
) )
elif values["args_only"] and isinstance(schema, Dict): elif values["args_only"] and isinstance(schema, dict):
raise ValueError( raise ValueError(
"If multiple pydantic schemas are provided then args_only should be" "If multiple pydantic schemas are provided then args_only should be"
" False." " False."
) )
return values return values
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object. """Parse the result of an LLM call to a JSON object.
Args: Args:
@ -292,7 +295,7 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
attr_name: str attr_name: str
"""The name of the attribute to return.""" """The name of the attribute to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object. """Parse the result of an LLM call to a JSON object.
Args: Args:

View File

@ -1,10 +1,9 @@
import copy import copy
import json import json
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Dict, List, Optional from typing import Annotated, Any, Optional
from pydantic import SkipValidation, ValidationError from pydantic import SkipValidation, ValidationError
from typing_extensions import Annotated
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall from langchain_core.messages import AIMessage, InvalidToolCall
@ -17,12 +16,12 @@ from langchain_core.utils.pydantic import TypeBaseModel
def parse_tool_call( def parse_tool_call(
raw_tool_call: Dict[str, Any], raw_tool_call: dict[str, Any],
*, *,
partial: bool = False, partial: bool = False,
strict: bool = False, strict: bool = False,
return_id: bool = True, return_id: bool = True,
) -> Optional[Dict[str, Any]]: ) -> Optional[dict[str, Any]]:
"""Parse a single tool call. """Parse a single tool call.
Args: Args:
@ -69,7 +68,7 @@ def parse_tool_call(
def make_invalid_tool_call( def make_invalid_tool_call(
raw_tool_call: Dict[str, Any], raw_tool_call: dict[str, Any],
error_msg: Optional[str], error_msg: Optional[str],
) -> InvalidToolCall: ) -> InvalidToolCall:
"""Create an InvalidToolCall from a raw tool call. """Create an InvalidToolCall from a raw tool call.
@ -90,12 +89,12 @@ def make_invalid_tool_call(
def parse_tool_calls( def parse_tool_calls(
raw_tool_calls: List[dict], raw_tool_calls: list[dict],
*, *,
partial: bool = False, partial: bool = False,
strict: bool = False, strict: bool = False,
return_id: bool = True, return_id: bool = True,
) -> List[Dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Parse a list of tool calls. """Parse a list of tool calls.
Args: Args:
@ -111,7 +110,7 @@ def parse_tool_calls(
Raises: Raises:
OutputParserException: If any of the tool calls are not valid JSON. OutputParserException: If any of the tool calls are not valid JSON.
""" """
final_tools: List[Dict[str, Any]] = [] final_tools: list[dict[str, Any]] = []
exceptions = [] exceptions = []
for tool_call in raw_tool_calls: for tool_call in raw_tool_calls:
try: try:
@ -151,7 +150,7 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
If no tool calls are found, None will be returned. If no tool calls are found, None will be returned.
""" """
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a list of tool calls. """Parse the result of an LLM call to a list of tool calls.
Args: Args:
@ -217,7 +216,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
key_name: str key_name: str
"""The type of tools to return.""" """The type of tools to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a list of tool calls. """Parse the result of an LLM call to a list of tool calls.
Args: Args:
@ -254,12 +253,12 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
class PydanticToolsParser(JsonOutputToolsParser): class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response.""" """Parse tools from OpenAI response."""
tools: Annotated[List[TypeBaseModel], SkipValidation()] tools: Annotated[list[TypeBaseModel], SkipValidation()]
"""The tools to parse.""" """The tools to parse."""
# TODO: Support more granular streaming of objects. Currently only streams once all # TODO: Support more granular streaming of objects. Currently only streams once all
# Pydantic object fields are present. # Pydantic object fields are present.
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a list of Pydantic objects. """Parse the result of an LLM call to a list of Pydantic objects.
Args: Args:

View File

@ -1,9 +1,8 @@
import json import json
from typing import Generic, List, Optional, Type from typing import Annotated, Generic, Optional
import pydantic import pydantic
from pydantic import SkipValidation from pydantic import SkipValidation
from typing_extensions import Annotated
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser from langchain_core.output_parsers import JsonOutputParser
@ -18,7 +17,7 @@ from langchain_core.utils.pydantic import (
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
"""Parse an output using a pydantic model.""" """Parse an output using a pydantic model."""
pydantic_object: Annotated[Type[TBaseModel], SkipValidation()] # type: ignore pydantic_object: Annotated[type[TBaseModel], SkipValidation()] # type: ignore
"""The pydantic model to parse.""" """The pydantic model to parse."""
def _parse_obj(self, obj: dict) -> TBaseModel: def _parse_obj(self, obj: dict) -> TBaseModel:
@ -50,7 +49,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return OutputParserException(msg, llm_output=json_string) return OutputParserException(msg, llm_output=json_string)
def parse_result( def parse_result(
self, result: List[Generation], *, partial: bool = False self, result: list[Generation], *, partial: bool = False
) -> Optional[TBaseModel]: ) -> Optional[TBaseModel]:
"""Parse the result of an LLM call to a pydantic object. """Parse the result of an LLM call to a pydantic object.
@ -108,7 +107,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return "pydantic" return "pydantic"
@property @property
def OutputType(self) -> Type[TBaseModel]: def OutputType(self) -> type[TBaseModel]:
"""Return the pydantic model.""" """Return the pydantic model."""
return self.pydantic_object return self.pydantic_object

View File

@ -1,4 +1,3 @@
from typing import List
from typing import Optional as Optional from typing import Optional as Optional
from langchain_core.output_parsers.transform import BaseTransformOutputParser from langchain_core.output_parsers.transform import BaseTransformOutputParser
@ -13,7 +12,7 @@ class StrOutputParser(BaseTransformOutputParser[str]):
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "output_parser"] return ["langchain", "schema", "output_parser"]

View File

@ -1,10 +1,9 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncIterator, Iterator
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncIterator,
Iterator,
Optional, Optional,
Union, Union,
) )

View File

@ -1,7 +1,8 @@
import re import re
import xml import xml
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union from collections.abc import AsyncIterator, Iterator
from typing import Any, Literal, Optional, Union
from xml.etree.ElementTree import TreeBuilder from xml.etree.ElementTree import TreeBuilder
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
@ -57,7 +58,7 @@ class _StreamingParser:
_parser = None _parser = None
self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser) self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
self.xml_start_re = re.compile(r"<[a-zA-Z:_]") self.xml_start_re = re.compile(r"<[a-zA-Z:_]")
self.current_path: List[str] = [] self.current_path: list[str] = []
self.current_path_has_children = False self.current_path_has_children = False
self.buffer = "" self.buffer = ""
self.xml_started = False self.xml_started = False
@ -140,7 +141,7 @@ class _StreamingParser:
class XMLOutputParser(BaseTransformOutputParser): class XMLOutputParser(BaseTransformOutputParser):
"""Parse an output using xml format.""" """Parse an output using xml format."""
tags: Optional[List[str]] = None tags: Optional[list[str]] = None
encoding_matcher: re.Pattern = re.compile( encoding_matcher: re.Pattern = re.compile(
r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
) )
@ -169,7 +170,7 @@ class XMLOutputParser(BaseTransformOutputParser):
"""Return the format instructions for the XML output.""" """Return the format instructions for the XML output."""
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]: def parse(self, text: str) -> dict[str, Union[str, list[Any]]]:
"""Parse the output of an LLM call. """Parse the output of an LLM call.
Args: Args:
@ -234,13 +235,13 @@ class XMLOutputParser(BaseTransformOutputParser):
yield output yield output
streaming_parser.close() streaming_parser.close()
def _root_to_dict(self, root: ET.Element) -> Dict[str, Union[str, List[Any]]]: def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]:
"""Converts xml tree to python dictionary.""" """Converts xml tree to python dictionary."""
if root.text and bool(re.search(r"\S", root.text)): if root.text and bool(re.search(r"\S", root.text)):
# If root text contains any non-whitespace character it # If root text contains any non-whitespace character it
# returns {root.tag: root.text} # returns {root.tag: root.text}
return {root.tag: root.text} return {root.tag: root.text}
result: Dict = {root.tag: []} result: dict = {root.tag: []}
for child in root: for child in root:
if len(child) == 0: if len(child) == 0:
result[root.tag].append({child.tag: child.text}) result[root.tag].append({child.tag: child.text})
@ -253,7 +254,7 @@ class XMLOutputParser(BaseTransformOutputParser):
return "xml" return "xml"
def nested_element(path: List[str], elem: ET.Element) -> Any: def nested_element(path: list[str], elem: ET.Element) -> Any:
"""Get nested element from path. """Get nested element from path.
Args: Args:

View File

@ -1,4 +1,4 @@
from typing import List, Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -18,7 +18,7 @@ class ChatResult(BaseModel):
for more information. for more information.
""" """
generations: List[ChatGeneration] generations: list[ChatGeneration]
"""List of the chat generations. """List of the chat generations.
Generations is a list to allow for multiple candidate generations for a single Generations is a list to allow for multiple candidate generations for a single

View File

@ -7,7 +7,8 @@ They can be used to represent text, images, or chat message pieces.
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Literal, Sequence, cast from collections.abc import Sequence
from typing import Literal, cast
from typing_extensions import TypedDict from typing_extensions import TypedDict

View File

@ -1,16 +1,16 @@
from __future__ import annotations from __future__ import annotations
import json import json
import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable, Callable,
Dict,
Generic, Generic,
Mapping,
Optional, Optional,
TypeVar, TypeVar,
Union, Union,
@ -39,7 +39,7 @@ FormatOutputType = TypeVar("FormatOutputType")
class BasePromptTemplate( class BasePromptTemplate(
RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC RunnableSerializable[dict, PromptValue], Generic[FormatOutputType], ABC
): ):
"""Base class for all prompt templates, returning a prompt.""" """Base class for all prompt templates, returning a prompt."""
@ -50,7 +50,7 @@ class BasePromptTemplate(
"""optional_variables: A list of the names of the variables for placeholder """optional_variables: A list of the names of the variables for placeholder
or MessagePlaceholder that are optional. These variables are auto inferred or MessagePlaceholder that are optional. These variables are auto inferred
from the prompt and user need not provide them.""" from the prompt and user need not provide them."""
input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006 input_types: typing.Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006
"""A dictionary of the types of the variables the prompt template expects. """A dictionary of the types of the variables the prompt template expects.
If not provided, all variables are assumed to be strings.""" If not provided, all variables are assumed to be strings."""
output_parser: Optional[BaseOutputParser] = None output_parser: Optional[BaseOutputParser] = None
@ -60,7 +60,7 @@ class BasePromptTemplate(
Partial variables populate the template so that you don't need to Partial variables populate the template so that you don't need to
pass them in every time you call the prompt.""" pass them in every time you call the prompt."""
metadata: Optional[Dict[str, Any]] = None # noqa: UP006 metadata: Optional[typing.Dict[str, Any]] = None # noqa: UP006
"""Metadata to be used for tracing.""" """Metadata to be used for tracing."""
tags: Optional[list[str]] = None tags: Optional[list[str]] = None
"""Tags to be used for tracing.""" """Tags to be used for tracing."""

View File

@ -3,15 +3,13 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Annotated,
Any, Any,
List,
Literal, Literal,
Optional, Optional,
Sequence,
Tuple,
Type,
TypedDict, TypedDict,
TypeVar, TypeVar,
Union, Union,
@ -25,7 +23,6 @@ from pydantic import (
SkipValidation, SkipValidation,
model_validator, model_validator,
) )
from typing_extensions import Annotated
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.load import Serializable from langchain_core.load import Serializable
@ -816,9 +813,9 @@ MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTempla
MessageLikeRepresentation = Union[ MessageLikeRepresentation = Union[
MessageLike, MessageLike,
Tuple[ tuple[
Union[str, Type], Union[str, type],
Union[str, List[dict], List[object]], Union[str, list[dict], list[object]],
], ],
str, str,
] ]
@ -1017,7 +1014,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
), ),
**kwargs, **kwargs,
} }
cast(Type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs) cast(type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)
@classmethod @classmethod
def get_lc_namespace(cls) -> list[str]: def get_lc_namespace(cls) -> list[str]:
@ -1083,7 +1080,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
values["partial_variables"][message.variable_name] = [] values["partial_variables"][message.variable_name] = []
optional_variables.add(message.variable_name) optional_variables.add(message.variable_name)
if message.variable_name not in input_types: if message.variable_name not in input_types:
input_types[message.variable_name] = List[AnyMessage] input_types[message.variable_name] = list[AnyMessage]
if "partial_variables" in values: if "partial_variables" in values:
input_vars = input_vars - set(values["partial_variables"]) input_vars = input_vars - set(values["partial_variables"])
if optional_variables: if optional_variables:

View File

@ -1,7 +1,7 @@
"""Prompt template that contains few shot examples.""" """Prompt template that contains few shot examples."""
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Optional, Union
from pydantic import ConfigDict, model_validator from pydantic import ConfigDict, model_validator
from typing_extensions import Self from typing_extensions import Self
@ -16,7 +16,7 @@ from langchain_core.prompts.string import (
class FewShotPromptWithTemplates(StringPromptTemplate): class FewShotPromptWithTemplates(StringPromptTemplate):
"""Prompt template that contains few shot examples.""" """Prompt template that contains few shot examples."""
examples: Optional[List[dict]] = None examples: Optional[list[dict]] = None
"""Examples to format into the prompt. """Examples to format into the prompt.
Either this or example_selector should be provided.""" Either this or example_selector should be provided."""
@ -43,13 +43,13 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
"""Whether or not to try validating the template.""" """Whether or not to try validating the template."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "prompts", "few_shot_with_templates"] return ["langchain", "prompts", "few_shot_with_templates"]
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_examples_and_selector(cls, values: Dict) -> Any: def check_examples_and_selector(cls, values: dict) -> Any:
"""Check that one and only one of examples/example_selector are provided.""" """Check that one and only one of examples/example_selector are provided."""
examples = values.get("examples", None) examples = values.get("examples", None)
example_selector = values.get("example_selector", None) example_selector = values.get("example_selector", None)
@ -93,7 +93,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
extra="forbid", extra="forbid",
) )
def _get_examples(self, **kwargs: Any) -> List[dict]: def _get_examples(self, **kwargs: Any) -> list[dict]:
if self.examples is not None: if self.examples is not None:
return self.examples return self.examples
elif self.example_selector is not None: elif self.example_selector is not None:
@ -101,7 +101,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
else: else:
raise ValueError raise ValueError
async def _aget_examples(self, **kwargs: Any) -> List[dict]: async def _aget_examples(self, **kwargs: Any) -> list[dict]:
if self.examples is not None: if self.examples is not None:
return self.examples return self.examples
elif self.example_selector is not None: elif self.example_selector is not None:

View File

@ -1,4 +1,4 @@
from typing import Any, List from typing import Any
from pydantic import Field from pydantic import Field
@ -33,7 +33,7 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
return "image-prompt" return "image-prompt"
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "prompts", "image"] return ["langchain", "prompts", "image"]

View File

@ -3,7 +3,7 @@
import json import json
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, Optional, Union from typing import Callable, Optional, Union
import yaml import yaml
@ -181,7 +181,7 @@ def _load_prompt_from_file(
return load_prompt_from_config(config) return load_prompt_from_config(config)
def _load_chat_prompt(config: Dict) -> ChatPromptTemplate: def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
"""Load chat prompt from config""" """Load chat prompt from config"""
messages = config.pop("messages") messages = config.pop("messages")
@ -194,7 +194,7 @@ def _load_chat_prompt(config: Dict) -> ChatPromptTemplate:
return ChatPromptTemplate.from_template(template=template, **config) return ChatPromptTemplate.from_template(template=template, **config)
type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = { type_to_loader_dict: dict[str, Callable[[dict], BasePromptTemplate]] = {
"prompt": _load_prompt, "prompt": _load_prompt,
"few_shot": _load_few_shot_prompt, "few_shot": _load_few_shot_prompt,
"chat": _load_chat_prompt, "chat": _load_chat_prompt,

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple from typing import Any
from typing import Optional as Optional from typing import Optional as Optional
from pydantic import model_validator from pydantic import model_validator
@ -8,7 +8,7 @@ from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.chat import BaseChatPromptTemplate from langchain_core.prompts.chat import BaseChatPromptTemplate
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict: def _get_inputs(inputs: dict, input_variables: list[str]) -> dict:
return {k: inputs[k] for k in input_variables} return {k: inputs[k] for k in input_variables}
@ -28,17 +28,17 @@ class PipelinePromptTemplate(BasePromptTemplate):
final_prompt: BasePromptTemplate final_prompt: BasePromptTemplate
"""The final prompt that is returned.""" """The final prompt that is returned."""
pipeline_prompts: List[Tuple[str, BasePromptTemplate]] pipeline_prompts: list[tuple[str, BasePromptTemplate]]
"""A list of tuples, consisting of a string (`name`) and a Prompt Template.""" """A list of tuples, consisting of a string (`name`) and a Prompt Template."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "prompts", "pipeline"] return ["langchain", "prompts", "pipeline"]
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def get_input_variables(cls, values: Dict) -> Any: def get_input_variables(cls, values: dict) -> Any:
"""Get input variables.""" """Get input variables."""
created_variables = set() created_variables = set()
all_variables = set() all_variables = set()

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import warnings import warnings
from abc import ABC from abc import ABC
from string import Formatter from string import Formatter
from typing import Any, Callable, Dict from typing import Any, Callable
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
@ -139,7 +139,7 @@ def mustache_template_vars(
return vars return vars
Defs = Dict[str, "Defs"] Defs = dict[str, "Defs"]
def mustache_schema( def mustache_schema(

View File

@ -1,13 +1,8 @@
from collections.abc import Iterator, Mapping, Sequence
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Dict,
Iterator,
List,
Mapping,
Optional, Optional,
Sequence,
Type,
Union, Union,
) )
@ -32,16 +27,16 @@ from langchain_core.utils import get_pydantic_field_names
class StructuredPrompt(ChatPromptTemplate): class StructuredPrompt(ChatPromptTemplate):
"""Structured prompt template for a language model.""" """Structured prompt template for a language model."""
schema_: Union[Dict, Type[BaseModel]] schema_: Union[dict, type[BaseModel]]
"""Schema for the structured prompt.""" """Schema for the structured prompt."""
structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict) structured_output_kwargs: dict[str, Any] = Field(default_factory=dict)
def __init__( def __init__(
self, self,
messages: Sequence[MessageLikeRepresentation], messages: Sequence[MessageLikeRepresentation],
schema_: Optional[Union[Dict, Type[BaseModel]]] = None, schema_: Optional[Union[dict, type[BaseModel]]] = None,
*, *,
structured_output_kwargs: Optional[Dict[str, Any]] = None, structured_output_kwargs: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
schema_ = schema_ or kwargs.pop("schema") schema_ = schema_ or kwargs.pop("schema")
@ -56,7 +51,7 @@ class StructuredPrompt(ChatPromptTemplate):
) )
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
For example, if the class is `langchain.llms.openai.OpenAI`, then the For example, if the class is `langchain.llms.openai.OpenAI`, then the
@ -68,7 +63,7 @@ class StructuredPrompt(ChatPromptTemplate):
def from_messages_and_schema( def from_messages_and_schema(
cls, cls,
messages: Sequence[MessageLikeRepresentation], messages: Sequence[MessageLikeRepresentation],
schema: Union[Dict, Type[BaseModel]], schema: Union[dict, type[BaseModel]],
**kwargs: Any, **kwargs: Any,
) -> ChatPromptTemplate: ) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats. """Create a chat prompt template from a variety of message formats.
@ -118,7 +113,7 @@ class StructuredPrompt(ChatPromptTemplate):
Callable[[Iterator[Any]], Iterator[Other]], Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
], ],
) -> RunnableSerializable[Dict, Other]: ) -> RunnableSerializable[dict, Other]:
return self.pipe(other) return self.pipe(other)
def pipe( def pipe(
@ -130,7 +125,7 @@ class StructuredPrompt(ChatPromptTemplate):
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
], ],
name: Optional[str] = None, name: Optional[str] = None,
) -> RunnableSerializable[Dict, Other]: ) -> RunnableSerializable[dict, Other]:
"""Pipe the structured prompt to a language model. """Pipe the structured prompt to a language model.
Args: Args:

View File

@ -24,7 +24,7 @@ from __future__ import annotations
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import TYPE_CHECKING, Any, List, Optional from typing import TYPE_CHECKING, Any, Optional
from pydantic import ConfigDict from pydantic import ConfigDict
from typing_extensions import TypedDict from typing_extensions import TypedDict
@ -47,7 +47,7 @@ if TYPE_CHECKING:
) )
RetrieverInput = str RetrieverInput = str
RetrieverOutput = List[Document] RetrieverOutput = list[Document]
RetrieverLike = Runnable[RetrieverInput, RetrieverOutput] RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
RetrieverOutputLike = Runnable[Any, RetrieverOutput] RetrieverOutputLike = Runnable[Any, RetrieverOutput]

View File

@ -6,36 +6,37 @@ import functools
import inspect import inspect
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import (
AsyncGenerator,
AsyncIterator,
Awaitable,
Coroutine,
Iterator,
Mapping,
Sequence,
)
from concurrent.futures import FIRST_COMPLETED, wait from concurrent.futures import FIRST_COMPLETED, wait
from contextvars import copy_context from contextvars import copy_context
from functools import wraps from functools import wraps
from itertools import groupby, tee from itertools import groupby, tee
from operator import itemgetter from operator import itemgetter
from types import GenericAlias
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable, Callable,
Coroutine,
Dict,
Generic, Generic,
Iterator,
List,
Mapping,
Optional, Optional,
Protocol, Protocol,
Sequence,
Type,
TypeVar, TypeVar,
Union, Union,
cast, cast,
get_type_hints,
overload, overload,
) )
from pydantic import BaseModel, ConfigDict, Field, RootModel from pydantic import BaseModel, ConfigDict, Field, RootModel
from typing_extensions import Literal, get_args, get_type_hints from typing_extensions import Literal, get_args
from langchain_core._api import beta_decorator from langchain_core._api import beta_decorator
from langchain_core.load.serializable import ( from langchain_core.load.serializable import (
@ -340,7 +341,11 @@ class Runnable(Generic[Input, Output], ABC):
""" """
root_type = self.InputType root_type = self.InputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel): if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type return root_type
return create_model_v2( return create_model_v2(
@ -408,7 +413,11 @@ class Runnable(Generic[Input, Output], ABC):
""" """
root_type = self.OutputType root_type = self.OutputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel): if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type return root_type
return create_model_v2( return create_model_v2(
@ -771,10 +780,10 @@ class Runnable(Generic[Input, Output], ABC):
# If there's only one input, don't bother with the executor # If there's only one input, don't bother with the executor
if len(inputs) == 1: if len(inputs) == 1:
return cast(List[Output], [invoke(inputs[0], configs[0])]) return cast(list[Output], [invoke(inputs[0], configs[0])])
with get_executor_for_config(configs[0]) as executor: with get_executor_for_config(configs[0]) as executor:
return cast(List[Output], list(executor.map(invoke, inputs, configs))) return cast(list[Output], list(executor.map(invoke, inputs, configs)))
@overload @overload
def batch_as_completed( def batch_as_completed(
@ -2024,7 +2033,7 @@ class Runnable(Generic[Input, Output], ABC):
for run_manager in run_managers: for run_manager in run_managers:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
if return_exceptions: if return_exceptions:
return cast(List[Output], [e for _ in input]) return cast(list[Output], [e for _ in input])
else: else:
raise raise
else: else:
@ -2036,7 +2045,7 @@ class Runnable(Generic[Input, Output], ABC):
else: else:
run_manager.on_chain_end(out) run_manager.on_chain_end(out)
if return_exceptions or first_exception is None: if return_exceptions or first_exception is None:
return cast(List[Output], output) return cast(list[Output], output)
else: else:
raise first_exception raise first_exception
@ -2099,7 +2108,7 @@ class Runnable(Generic[Input, Output], ABC):
*(run_manager.on_chain_error(e) for run_manager in run_managers) *(run_manager.on_chain_error(e) for run_manager in run_managers)
) )
if return_exceptions: if return_exceptions:
return cast(List[Output], [e for _ in input]) return cast(list[Output], [e for _ in input])
else: else:
raise raise
else: else:
@ -2113,7 +2122,7 @@ class Runnable(Generic[Input, Output], ABC):
coros.append(run_manager.on_chain_end(out)) coros.append(run_manager.on_chain_end(out))
await asyncio.gather(*coros) await asyncio.gather(*coros)
if return_exceptions or first_exception is None: if return_exceptions or first_exception is None:
return cast(List[Output], output) return cast(list[Output], output)
else: else:
raise first_exception raise first_exception
@ -3171,7 +3180,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
for rm in run_managers: for rm in run_managers:
rm.on_chain_error(e) rm.on_chain_error(e)
if return_exceptions: if return_exceptions:
return cast(List[Output], [e for _ in inputs]) return cast(list[Output], [e for _ in inputs])
else: else:
raise raise
else: else:
@ -3183,7 +3192,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
else: else:
run_manager.on_chain_end(out) run_manager.on_chain_end(out)
if return_exceptions or first_exception is None: if return_exceptions or first_exception is None:
return cast(List[Output], inputs) return cast(list[Output], inputs)
else: else:
raise first_exception raise first_exception
@ -3298,7 +3307,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
except BaseException as e: except BaseException as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
if return_exceptions: if return_exceptions:
return cast(List[Output], [e for _ in inputs]) return cast(list[Output], [e for _ in inputs])
else: else:
raise raise
else: else:
@ -3312,7 +3321,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
coros.append(run_manager.on_chain_end(out)) coros.append(run_manager.on_chain_end(out))
await asyncio.gather(*coros) await asyncio.gather(*coros)
if return_exceptions or first_exception is None: if return_exceptions or first_exception is None:
return cast(List[Output], inputs) return cast(list[Output], inputs)
else: else:
raise first_exception raise first_exception
@ -3420,7 +3429,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
yield chunk yield chunk
class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
"""Runnable that runs a mapping of Runnables in parallel, and returns a mapping """Runnable that runs a mapping of Runnables in parallel, and returns a mapping
of their outputs. of their outputs.
@ -4071,7 +4080,11 @@ class RunnableGenerator(Runnable[Input, Output]):
func = getattr(self, "_transform", None) or self._atransform func = getattr(self, "_transform", None) or self._atransform
module = getattr(func, "__module__", None) module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel): if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type return root_type
return create_model_v2( return create_model_v2(
@ -4106,7 +4119,11 @@ class RunnableGenerator(Runnable[Input, Output]):
func = getattr(self, "_transform", None) or self._atransform func = getattr(self, "_transform", None) or self._atransform
module = getattr(func, "__module__", None) module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel): if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type return root_type
return create_model_v2( return create_model_v2(
@ -4369,7 +4386,7 @@ class RunnableLambda(Runnable[Input, Output]):
module = getattr(func, "__module__", None) module = getattr(func, "__module__", None)
return create_model_v2( return create_model_v2(
self.get_name("Input"), self.get_name("Input"),
root=List[Any], root=list[Any],
# To create the schema, we need to provide the module # To create the schema, we need to provide the module
# where the underlying function is defined. # where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately. # This allows pydantic to resolve type annotations appropriately.
@ -4420,7 +4437,11 @@ class RunnableLambda(Runnable[Input, Output]):
func = getattr(self, "func", None) or self.afunc func = getattr(self, "func", None) or self.afunc
module = getattr(func, "__module__", None) module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel): if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type return root_type
return create_model_v2( return create_model_v2(
@ -4921,7 +4942,7 @@ class RunnableLambda(Runnable[Input, Output]):
yield chunk yield chunk
class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
"""Runnable that delegates calls to another Runnable """Runnable that delegates calls to another Runnable
with each element of the input sequence. with each element of the input sequence.
@ -4938,7 +4959,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
@property @property
def InputType(self) -> Any: def InputType(self) -> Any:
return List[self.bound.InputType] # type: ignore[name-defined] return list[self.bound.InputType] # type: ignore[name-defined]
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
@ -4946,7 +4967,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
return create_model_v2( return create_model_v2(
self.get_name("Input"), self.get_name("Input"),
root=( root=(
List[self.bound.get_input_schema(config)], # type: ignore list[self.bound.get_input_schema(config)], # type: ignore
None, None,
), ),
# create model needs access to appropriate type annotations to be # create model needs access to appropriate type annotations to be
@ -4961,7 +4982,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
@property @property
def OutputType(self) -> type[list[Output]]: def OutputType(self) -> type[list[Output]]:
return List[self.bound.OutputType] # type: ignore[name-defined] return list[self.bound.OutputType] # type: ignore[name-defined]
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
@ -4969,7 +4990,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
schema = self.bound.get_output_schema(config) schema = self.bound.get_output_schema(config)
return create_model_v2( return create_model_v2(
self.get_name("Output"), self.get_name("Output"),
root=List[schema], # type: ignore[valid-type] root=list[schema], # type: ignore[valid-type]
# create model needs access to appropriate type annotations to be # create model needs access to appropriate type annotations to be
# able to construct the pydantic model. # able to construct the pydantic model.
# When we create the model, we pass information about the namespace # When we create the model, we pass information about the namespace
@ -5255,7 +5276,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
@property @property
def InputType(self) -> type[Input]: def InputType(self) -> type[Input]:
return ( return (
cast(Type[Input], self.custom_input_type) cast(type[Input], self.custom_input_type)
if self.custom_input_type is not None if self.custom_input_type is not None
else self.bound.InputType else self.bound.InputType
) )
@ -5263,7 +5284,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
@property @property
def OutputType(self) -> type[Output]: def OutputType(self) -> type[Output]:
return ( return (
cast(Type[Output], self.custom_output_type) cast(type[Output], self.custom_output_type)
if self.custom_output_type is not None if self.custom_output_type is not None
else self.bound.OutputType else self.bound.OutputType
) )
@ -5336,7 +5357,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) -> list[Output]: ) -> list[Output]:
if isinstance(config, list): if isinstance(config, list):
configs = cast( configs = cast(
List[RunnableConfig], list[RunnableConfig],
[self._merge_configs(conf) for conf in config], [self._merge_configs(conf) for conf in config],
) )
else: else:
@ -5358,7 +5379,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) -> list[Output]: ) -> list[Output]:
if isinstance(config, list): if isinstance(config, list):
configs = cast( configs = cast(
List[RunnableConfig], list[RunnableConfig],
[self._merge_configs(conf) for conf in config], [self._merge_configs(conf) for conf in config],
) )
else: else:
@ -5400,7 +5421,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) -> Iterator[tuple[int, Union[Output, Exception]]]: ) -> Iterator[tuple[int, Union[Output, Exception]]]:
if isinstance(config, Sequence): if isinstance(config, Sequence):
configs = cast( configs = cast(
List[RunnableConfig], list[RunnableConfig],
[self._merge_configs(conf) for conf in config], [self._merge_configs(conf) for conf in config],
) )
else: else:
@ -5451,7 +5472,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]:
if isinstance(config, Sequence): if isinstance(config, Sequence):
configs = cast( configs = cast(
List[RunnableConfig], list[RunnableConfig],
[self._merge_configs(conf) for conf in config], [self._merge_configs(conf) for conf in config],
) )
else: else:

View File

@ -1,15 +1,8 @@
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Awaitable,
Callable, Callable,
Iterator,
List,
Mapping,
Optional, Optional,
Sequence,
Tuple,
Type,
Union, Union,
cast, cast,
) )
@ -69,13 +62,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
branch.invoke(None) # "goodbye" branch.invoke(None) # "goodbye"
""" """
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]] branches: Sequence[tuple[Runnable[Input, bool], Runnable[Input, Output]]]
default: Runnable[Input, Output] default: Runnable[Input, Output]
def __init__( def __init__(
self, self,
*branches: Union[ *branches: Union[
Tuple[ tuple[
Union[ Union[
Runnable[Input, bool], Runnable[Input, bool],
Callable[[Input], bool], Callable[[Input], bool],
@ -149,13 +142,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
runnables = ( runnables = (
[self.default] [self.default]
+ [r for _, r in self.branches] + [r for _, r in self.branches]
@ -172,7 +165,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
return super().get_input_schema(config) return super().get_input_schema(config)
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
from langchain_core.beta.runnables.context import ( from langchain_core.beta.runnables.context import (
CONTEXT_CONFIG_PREFIX, CONTEXT_CONFIG_PREFIX,
CONTEXT_CONFIG_SUFFIX_SET, CONTEXT_CONFIG_SUFFIX_SET,

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
import uuid import uuid
import warnings import warnings
from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence
from concurrent.futures import Executor, Future, ThreadPoolExecutor from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar, copy_context from contextvars import ContextVar, copy_context
@ -10,14 +11,8 @@ from functools import partial
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Awaitable,
Callable, Callable,
Generator,
Iterable,
Iterator,
List,
Optional, Optional,
Sequence,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -43,7 +38,7 @@ if TYPE_CHECKING:
else: else:
# Pydantic validates through typed dicts, but # Pydantic validates through typed dicts, but
# the callbacks need forward refs updated # the callbacks need forward refs updated
Callbacks = Optional[Union[List, Any]] Callbacks = Optional[Union[list, Any]]
class EmptyDict(TypedDict, total=False): class EmptyDict(TypedDict, total=False):

View File

@ -3,20 +3,16 @@ from __future__ import annotations
import enum import enum
import threading import threading
from abc import abstractmethod from abc import abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from collections.abc import Mapping as Mapping
from functools import wraps from functools import wraps
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Callable, Callable,
Iterator,
List,
Optional, Optional,
Sequence,
Type,
Union, Union,
cast, cast,
) )
from typing import Mapping as Mapping
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -176,10 +172,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
# If there's only one input, don't bother with the executor # If there's only one input, don't bother with the executor
if len(inputs) == 1: if len(inputs) == 1:
return cast(List[Output], [invoke(prepared[0], inputs[0])]) return cast(list[Output], [invoke(prepared[0], inputs[0])])
with get_executor_for_config(configs[0]) as executor: with get_executor_for_config(configs[0]) as executor:
return cast(List[Output], list(executor.map(invoke, prepared, inputs))) return cast(list[Output], list(executor.map(invoke, prepared, inputs)))
async def abatch( async def abatch(
self, self,
@ -562,7 +558,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
for v in list(self.alternatives.keys()) + [self.default_key] for v in list(self.alternatives.keys()) + [self.default_key]
), ),
) )
_enums_for_spec[self.which] = cast(Type[StrEnum], which_enum) _enums_for_spec[self.which] = cast(type[StrEnum], which_enum)
return get_unique_config_specs( return get_unique_config_specs(
# which alternative # which alternative
[ [
@ -694,7 +690,7 @@ def make_options_spec(
spec.name or spec.id, spec.name or spec.id,
((v, v) for v in list(spec.options.keys())), ((v, v) for v in list(spec.options.keys())),
) )
_enums_for_spec[spec] = cast(Type[StrEnum], enum) _enums_for_spec[spec] = cast(type[StrEnum], enum)
if isinstance(spec, ConfigurableFieldSingleOption): if isinstance(spec, ConfigurableFieldSingleOption):
return ConfigurableFieldSpec( return ConfigurableFieldSpec(
id=spec.id, id=spec.id,

View File

@ -1,19 +1,13 @@
import asyncio import asyncio
import inspect import inspect
import typing import typing
from collections.abc import AsyncIterator, Iterator, Sequence
from contextvars import copy_context from contextvars import copy_context
from functools import wraps from functools import wraps
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional, Optional,
Sequence,
Tuple,
Type,
Union, Union,
cast, cast,
) )
@ -96,7 +90,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
"""The Runnable to run first.""" """The Runnable to run first."""
fallbacks: Sequence[Runnable[Input, Output]] fallbacks: Sequence[Runnable[Input, Output]]
"""A sequence of fallbacks to try.""" """A sequence of fallbacks to try."""
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,) exceptions_to_handle: tuple[type[BaseException], ...] = (Exception,)
"""The exceptions on which fallbacks should be tried. """The exceptions on which fallbacks should be tried.
Any exception that is not a subclass of these exceptions will be raised immediately. Any exception that is not a subclass of these exceptions will be raised immediately.
@ -112,25 +106,25 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
) )
@property @property
def InputType(self) -> Type[Input]: def InputType(self) -> type[Input]:
return self.runnable.InputType return self.runnable.InputType
@property @property
def OutputType(self) -> Type[Output]: def OutputType(self) -> type[Output]:
return self.runnable.OutputType return self.runnable.OutputType
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
return self.runnable.get_input_schema(config) return self.runnable.get_input_schema(config)
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
return self.runnable.get_output_schema(config) return self.runnable.get_output_schema(config)
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> list[ConfigurableFieldSpec]:
return get_unique_config_specs( return get_unique_config_specs(
spec spec
for step in [self.runnable, *self.fallbacks] for step in [self.runnable, *self.fallbacks]
@ -142,7 +136,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
return True return True
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@ -252,12 +246,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
def batch( def batch(
self, self,
inputs: List[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
from langchain_core.callbacks.manager import CallbackManager from langchain_core.callbacks.manager import CallbackManager
if self.exception_key is not None and not all( if self.exception_key is not None and not all(
@ -296,9 +290,9 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
for cm, input, config in zip(callback_managers, inputs, configs) for cm, input, config in zip(callback_managers, inputs, configs)
] ]
to_return: Dict[int, Any] = {} to_return: dict[int, Any] = {}
run_again = {i: input for i, input in enumerate(inputs)} run_again = {i: input for i, input in enumerate(inputs)}
handled_exceptions: Dict[int, BaseException] = {} handled_exceptions: dict[int, BaseException] = {}
first_to_raise = None first_to_raise = None
for runnable in self.runnables: for runnable in self.runnables:
outputs = runnable.batch( outputs = runnable.batch(
@ -344,12 +338,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
async def abatch( async def abatch(
self, self,
inputs: List[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> list[Output]:
from langchain_core.callbacks.manager import AsyncCallbackManager from langchain_core.callbacks.manager import AsyncCallbackManager
if self.exception_key is not None and not all( if self.exception_key is not None and not all(
@ -378,7 +372,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
for config in configs for config in configs
] ]
# start the root runs, one per input # start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*( *(
cm.on_chain_start( cm.on_chain_start(
None, None,
@ -392,7 +386,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
to_return = {} to_return = {}
run_again = {i: input for i, input in enumerate(inputs)} run_again = {i: input for i, input in enumerate(inputs)}
handled_exceptions: Dict[int, BaseException] = {} handled_exceptions: dict[int, BaseException] = {}
first_to_raise = None first_to_raise = None
for runnable in self.runnables: for runnable in self.runnables:
outputs = await runnable.abatch( outputs = await runnable.abatch(

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import inspect import inspect
from collections import Counter from collections import Counter
from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import ( from typing import (
@ -11,7 +12,6 @@ from typing import (
NamedTuple, NamedTuple,
Optional, Optional,
Protocol, Protocol,
Sequence,
TypedDict, TypedDict,
Union, Union,
overload, overload,

View File

@ -3,7 +3,8 @@ Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py"""
import math import math
import os import os
from typing import Any, Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any
from langchain_core.runnables.graph import Edge as LangEdge from langchain_core.runnables.graph import Edge as LangEdge

View File

@ -1,7 +1,7 @@
import base64 import base64
import re import re
from dataclasses import asdict from dataclasses import asdict
from typing import Dict, List, Optional from typing import Optional
from langchain_core.runnables.graph import ( from langchain_core.runnables.graph import (
CurveStyle, CurveStyle,
@ -15,8 +15,8 @@ MARKDOWN_SPECIAL_CHARS = "*_`"
def draw_mermaid( def draw_mermaid(
nodes: Dict[str, Node], nodes: dict[str, Node],
edges: List[Edge], edges: list[Edge],
*, *,
first_node: Optional[str] = None, first_node: Optional[str] = None,
last_node: Optional[str] = None, last_node: Optional[str] = None,
@ -87,7 +87,7 @@ def draw_mermaid(
mermaid_graph += f"\t{node_label}\n" mermaid_graph += f"\t{node_label}\n"
# Group edges by their common prefixes # Group edges by their common prefixes
edge_groups: Dict[str, List[Edge]] = {} edge_groups: dict[str, list[Edge]] = {}
for edge in edges: for edge in edges:
src_parts = edge.source.split(":") src_parts = edge.source.split(":")
tgt_parts = edge.target.split(":") tgt_parts = edge.target.split(":")
@ -98,7 +98,7 @@ def draw_mermaid(
seen_subgraphs = set() seen_subgraphs = set()
def add_subgraph(edges: List[Edge], prefix: str) -> None: def add_subgraph(edges: list[Edge], prefix: str) -> None:
nonlocal mermaid_graph nonlocal mermaid_graph
self_loop = len(edges) == 1 and edges[0].source == edges[0].target self_loop = len(edges) == 1 and edges[0].source == edges[0].target
if prefix and not self_loop: if prefix and not self_loop:

View File

@ -1,13 +1,13 @@
from __future__ import annotations from __future__ import annotations
import inspect import inspect
from collections.abc import Sequence
from types import GenericAlias
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable, Callable,
Dict,
Optional, Optional,
Sequence,
Union, Union,
) )
@ -31,7 +31,7 @@ if TYPE_CHECKING:
from langchain_core.tracers.schemas import Run from langchain_core.tracers.schemas import Run
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]] MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
@ -419,7 +419,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
""" """
root_type = self.OutputType root_type = self.OutputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel): if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type return root_type
return create_model_v2( return create_model_v2(

View File

@ -5,15 +5,11 @@ from __future__ import annotations
import asyncio import asyncio
import inspect import inspect
import threading import threading
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncIterator,
Awaitable,
Callable, Callable,
Dict,
Iterator,
Mapping,
Optional, Optional,
Union, Union,
cast, cast,
@ -349,7 +345,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
_graph_passthrough: RunnablePassthrough = RunnablePassthrough() _graph_passthrough: RunnablePassthrough = RunnablePassthrough()
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
"""Runnable that assigns key-value pairs to Dict[str, Any] inputs. """Runnable that assigns key-value pairs to Dict[str, Any] inputs.
The `RunnableAssign` class takes input dictionaries and, through a The `RunnableAssign` class takes input dictionaries and, through a
@ -564,7 +560,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
if filtered: if filtered:
yield filtered yield filtered
# yield map output # yield map output
yield cast(Dict[str, Any], first_map_chunk_future.result()) yield cast(dict[str, Any], first_map_chunk_future.result())
for chunk in map_output: for chunk in map_output:
yield chunk yield chunk
@ -650,7 +646,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
yield chunk yield chunk
class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
"""Runnable that picks keys from Dict[str, Any] inputs. """Runnable that picks keys from Dict[str, Any] inputs.
RunnablePick class represents a Runnable that selectively picks keys from a RunnablePick class represents a Runnable that selectively picks keys from a

View File

@ -1,11 +1,7 @@
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Dict,
List,
Optional, Optional,
Tuple,
Type,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -98,7 +94,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
retryable_chain = chain.with_retry() retryable_chain = chain.with_retry()
""" """
retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,) retry_exception_types: tuple[type[BaseException], ...] = (Exception,)
"""The exception types to retry on. By default all exceptions are retried. """The exception types to retry on. By default all exceptions are retried.
In general you should only retry on exceptions that are likely to be In general you should only retry on exceptions that are likely to be
@ -115,13 +111,13 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
"""The maximum number of attempts to retry the Runnable.""" """The maximum number of attempts to retry the Runnable."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"] return ["langchain", "schema", "runnable"]
@property @property
def _kwargs_retrying(self) -> Dict[str, Any]: def _kwargs_retrying(self) -> dict[str, Any]:
kwargs: Dict[str, Any] = dict() kwargs: dict[str, Any] = dict()
if self.max_attempt_number: if self.max_attempt_number:
kwargs["stop"] = stop_after_attempt(self.max_attempt_number) kwargs["stop"] = stop_after_attempt(self.max_attempt_number)
@ -152,10 +148,10 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def _patch_config_list( def _patch_config_list(
self, self,
config: List[RunnableConfig], config: list[RunnableConfig],
run_manager: List["T"], run_manager: list["T"],
retry_state: RetryCallState, retry_state: RetryCallState,
) -> List[RunnableConfig]: ) -> list[RunnableConfig]:
return [ return [
self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager) self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
] ]
@ -208,17 +204,17 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def _batch( def _batch(
self, self,
inputs: List[Input], inputs: list[Input],
run_manager: List["CallbackManagerForChainRun"], run_manager: list["CallbackManagerForChainRun"],
config: List[RunnableConfig], config: list[RunnableConfig],
**kwargs: Any, **kwargs: Any,
) -> List[Union[Output, Exception]]: ) -> list[Union[Output, Exception]]:
results_map: Dict[int, Output] = {} results_map: dict[int, Output] = {}
def pending(iterable: List[U]) -> List[U]: def pending(iterable: list[U]) -> list[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map] return [item for idx, item in enumerate(iterable) if idx not in results_map]
not_set: List[Output] = [] not_set: list[Output] = []
result = not_set result = not_set
try: try:
for attempt in self._sync_retrying(): for attempt in self._sync_retrying():
@ -250,9 +246,9 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
attempt.retry_state.set_result(result) attempt.retry_state.set_result(result)
except RetryError as e: except RetryError as e:
if result is not_set: if result is not_set:
result = cast(List[Output], [e] * len(inputs)) result = cast(list[Output], [e] * len(inputs))
outputs: List[Union[Output, Exception]] = [] outputs: list[Union[Output, Exception]] = []
for idx, _ in enumerate(inputs): for idx, _ in enumerate(inputs):
if idx in results_map: if idx in results_map:
outputs.append(results_map[idx]) outputs.append(results_map[idx])
@ -262,29 +258,29 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def batch( def batch(
self, self,
inputs: List[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Any, **kwargs: Any,
) -> List[Output]: ) -> list[Output]:
return self._batch_with_config( return self._batch_with_config(
self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
) )
async def _abatch( async def _abatch(
self, self,
inputs: List[Input], inputs: list[Input],
run_manager: List["AsyncCallbackManagerForChainRun"], run_manager: list["AsyncCallbackManagerForChainRun"],
config: List[RunnableConfig], config: list[RunnableConfig],
**kwargs: Any, **kwargs: Any,
) -> List[Union[Output, Exception]]: ) -> list[Union[Output, Exception]]:
results_map: Dict[int, Output] = {} results_map: dict[int, Output] = {}
def pending(iterable: List[U]) -> List[U]: def pending(iterable: list[U]) -> list[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map] return [item for idx, item in enumerate(iterable) if idx not in results_map]
not_set: List[Output] = [] not_set: list[Output] = []
result = not_set result = not_set
try: try:
async for attempt in self._async_retrying(): async for attempt in self._async_retrying():
@ -316,9 +312,9 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
attempt.retry_state.set_result(result) attempt.retry_state.set_result(result)
except RetryError as e: except RetryError as e:
if result is not_set: if result is not_set:
result = cast(List[Output], [e] * len(inputs)) result = cast(list[Output], [e] * len(inputs))
outputs: List[Union[Output, Exception]] = [] outputs: list[Union[Output, Exception]] = []
for idx, _ in enumerate(inputs): for idx, _ in enumerate(inputs):
if idx in results_map: if idx in results_map:
outputs.append(results_map[idx]) outputs.append(results_map[idx])
@ -328,12 +324,12 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
async def abatch( async def abatch(
self, self,
inputs: List[Input], inputs: list[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Any, **kwargs: Any,
) -> List[Output]: ) -> list[Output]:
return await self._abatch_with_config( return await self._abatch_with_config(
self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
) )

View File

@ -1,12 +1,9 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Callable, Callable,
Iterator,
List,
Mapping,
Optional, Optional,
Union, Union,
cast, cast,
@ -154,7 +151,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
with get_executor_for_config(configs[0]) as executor: with get_executor_for_config(configs[0]) as executor:
return cast( return cast(
List[Output], list[Output],
list(executor.map(invoke, runnables, actual_inputs, configs)), list(executor.map(invoke, runnables, actual_inputs, configs)),
) )

View File

@ -2,7 +2,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Literal, Sequence, Union from collections.abc import Sequence
from typing import Any, Literal, Union
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict

View File

@ -6,23 +6,24 @@ import ast
import asyncio import asyncio
import inspect import inspect
import textwrap import textwrap
from collections.abc import (
AsyncIterable,
AsyncIterator,
Awaitable,
Coroutine,
Iterable,
Mapping,
Sequence,
)
from functools import lru_cache from functools import lru_cache
from inspect import signature from inspect import signature
from itertools import groupby from itertools import groupby
from typing import ( from typing import (
Any, Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable, Callable,
Coroutine,
Dict,
Iterable,
Mapping,
NamedTuple, NamedTuple,
Optional, Optional,
Protocol, Protocol,
Sequence,
TypeVar, TypeVar,
Union, Union,
) )
@ -430,7 +431,7 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
return "\n".join([lines[0]] + [spaces + line for line in lines[1:]]) return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
class AddableDict(Dict[str, Any]): class AddableDict(dict[str, Any]):
""" """
Dictionary that can be added to another dictionary. Dictionary that can be added to another dictionary.
""" """

View File

@ -7,16 +7,11 @@ The primary goal of these storages is to support implementation of caching.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Dict,
Generic, Generic,
Iterator,
List,
Optional, Optional,
Sequence,
Tuple,
TypeVar, TypeVar,
Union, Union,
) )
@ -84,7 +79,7 @@ class BaseStore(Generic[K, V], ABC):
""" """
@abstractmethod @abstractmethod
def mget(self, keys: Sequence[K]) -> List[Optional[V]]: def mget(self, keys: Sequence[K]) -> list[Optional[V]]:
"""Get the values associated with the given keys. """Get the values associated with the given keys.
Args: Args:
@ -95,7 +90,7 @@ class BaseStore(Generic[K, V], ABC):
If a key is not found, the corresponding value will be None. If a key is not found, the corresponding value will be None.
""" """
async def amget(self, keys: Sequence[K]) -> List[Optional[V]]: async def amget(self, keys: Sequence[K]) -> list[Optional[V]]:
"""Async get the values associated with the given keys. """Async get the values associated with the given keys.
Args: Args:
@ -108,14 +103,14 @@ class BaseStore(Generic[K, V], ABC):
return await run_in_executor(None, self.mget, keys) return await run_in_executor(None, self.mget, keys)
@abstractmethod @abstractmethod
def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None: def mset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
"""Set the values for the given keys. """Set the values for the given keys.
Args: Args:
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs. key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
""" """
async def amset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None: async def amset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
"""Async set the values for the given keys. """Async set the values for the given keys.
Args: Args:
@ -184,9 +179,9 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize an empty store.""" """Initialize an empty store."""
self.store: Dict[str, V] = {} self.store: dict[str, V] = {}
def mget(self, keys: Sequence[str]) -> List[Optional[V]]: def mget(self, keys: Sequence[str]) -> list[Optional[V]]:
"""Get the values associated with the given keys. """Get the values associated with the given keys.
Args: Args:
@ -198,7 +193,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
""" """
return [self.store.get(key) for key in keys] return [self.store.get(key) for key in keys]
async def amget(self, keys: Sequence[str]) -> List[Optional[V]]: async def amget(self, keys: Sequence[str]) -> list[Optional[V]]:
"""Async get the values associated with the given keys. """Async get the values associated with the given keys.
Args: Args:
@ -210,7 +205,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
""" """
return self.mget(keys) return self.mget(keys)
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: def mset(self, key_value_pairs: Sequence[tuple[str, V]]) -> None:
"""Set the values for the given keys. """Set the values for the given keys.
Args: Args:
@ -222,7 +217,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
for key, value in key_value_pairs: for key, value in key_value_pairs:
self.store[key] = value self.store[key] = value
async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: async def amset(self, key_value_pairs: Sequence[tuple[str, V]]) -> None:
"""Async set the values for the given keys. """Async set the values for the given keys.
Args: Args:

View File

@ -3,8 +3,9 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from enum import Enum from enum import Enum
from typing import Any, Optional, Sequence, Union from typing import Any, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel

View File

@ -2,10 +2,10 @@
for debugging purposes. for debugging purposes.
""" """
from typing import List, Sequence from collections.abc import Sequence
def _get_sub_deps(packages: Sequence[str]) -> List[str]: def _get_sub_deps(packages: Sequence[str]) -> list[str]:
"""Get any specified sub-dependencies.""" """Get any specified sub-dependencies."""
from importlib import metadata from importlib import metadata

View File

@ -7,15 +7,15 @@ import json
import uuid import uuid
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextvars import copy_context from contextvars import copy_context
from inspect import signature from inspect import signature
from typing import ( from typing import (
Annotated,
Any, Any,
Callable, Callable,
Dict,
Literal, Literal,
Optional, Optional,
Sequence,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -36,7 +36,6 @@ from pydantic import (
) )
from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import validate_arguments as validate_arguments_v1 from pydantic.v1 import validate_arguments as validate_arguments_v1
from typing_extensions import Annotated
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -324,7 +323,7 @@ class ToolException(Exception):
pass pass
class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]): class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
"""Interface LangChain tools must implement.""" """Interface LangChain tools must implement."""
def __init_subclass__(cls, **kwargs: Any) -> None: def __init_subclass__(cls, **kwargs: Any) -> None:

View File

@ -1,5 +1,5 @@
import inspect import inspect
from typing import Any, Callable, Dict, Literal, Optional, Type, Union, get_type_hints from typing import Any, Callable, Literal, Optional, Union, get_type_hints
from pydantic import BaseModel, Field, create_model from pydantic import BaseModel, Field, create_model
@ -13,7 +13,7 @@ from langchain_core.tools.structured import StructuredTool
def tool( def tool(
*args: Union[str, Callable, Runnable], *args: Union[str, Callable, Runnable],
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[Type] = None, args_schema: Optional[type] = None,
infer_schema: bool = True, infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content", response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False, parse_docstring: bool = False,
@ -160,7 +160,7 @@ def tool(
coroutine = ainvoke_wrapper coroutine = ainvoke_wrapper
func = invoke_wrapper func = invoke_wrapper
schema: Optional[Type[BaseModel]] = runnable.input_schema schema: Optional[type[BaseModel]] = runnable.input_schema
description = repr(runnable) description = repr(runnable)
elif inspect.iscoroutinefunction(dec_func): elif inspect.iscoroutinefunction(dec_func):
coroutine = dec_func coroutine = dec_func
@ -234,8 +234,8 @@ def _get_description_from_runnable(runnable: Runnable) -> str:
def _get_schema_from_runnable_and_arg_types( def _get_schema_from_runnable_and_arg_types(
runnable: Runnable, runnable: Runnable,
name: str, name: str,
arg_types: Optional[Dict[str, Type]] = None, arg_types: Optional[dict[str, type]] = None,
) -> Type[BaseModel]: ) -> type[BaseModel]:
"""Infer args_schema for tool.""" """Infer args_schema for tool."""
if arg_types is None: if arg_types is None:
try: try:
@ -252,11 +252,11 @@ def _get_schema_from_runnable_and_arg_types(
def convert_runnable_to_tool( def convert_runnable_to_tool(
runnable: Runnable, runnable: Runnable,
args_schema: Optional[Type[BaseModel]] = None, args_schema: Optional[type[BaseModel]] = None,
*, *,
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,
arg_types: Optional[Dict[str, Type]] = None, arg_types: Optional[dict[str, type]] = None,
) -> BaseTool: ) -> BaseTool:
"""Convert a Runnable into a BaseTool. """Convert a Runnable into a BaseTool.

View File

@ -1,11 +1,11 @@
from __future__ import annotations from __future__ import annotations
from inspect import signature from inspect import signature
from typing import Callable, List from typing import Callable
from langchain_core.tools.base import BaseTool from langchain_core.tools.base import BaseTool
ToolsRenderer = Callable[[List[BaseTool]], str] ToolsRenderer = Callable[[list[BaseTool]], str]
def render_text_description(tools: list[BaseTool]) -> str: def render_text_description(tools: list[BaseTool]) -> str:

View File

@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Awaitable
from inspect import signature from inspect import signature
from typing import ( from typing import (
Any, Any,
Awaitable,
Callable, Callable,
Optional, Optional,
Union, Union,

View File

@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
import textwrap import textwrap
from collections.abc import Awaitable
from inspect import signature from inspect import signature
from typing import ( from typing import (
Annotated,
Any, Any,
Awaitable,
Callable, Callable,
Literal, Literal,
Optional, Optional,
@ -12,7 +13,6 @@ from typing import (
) )
from pydantic import BaseModel, Field, SkipValidation from pydantic import BaseModel, Field, SkipValidation
from typing_extensions import Annotated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun, AsyncCallbackManagerForToolRun,

View File

@ -1,7 +1,8 @@
"""Internal tracers used for stream_log and astream events implementations.""" """Internal tracers used for stream_log and astream events implementations."""
import abc import abc
from typing import AsyncIterator, Iterator, TypeVar from collections.abc import AsyncIterator, Iterator
from typing import TypeVar
from uuid import UUID from uuid import UUID
T = TypeVar("T") T = TypeVar("T")

View File

@ -5,11 +5,11 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Optional, Optional,
Sequence,
Union, Union,
) )
from uuid import UUID from uuid import UUID

View File

@ -1,11 +1,11 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Generator
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Generator,
Optional, Optional,
Union, Union,
cast, cast,

View File

@ -6,14 +6,13 @@ import logging
import sys import sys
import traceback import traceback
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Coroutine, Sequence
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Coroutine,
Literal, Literal,
Optional, Optional,
Sequence,
Union, Union,
cast, cast,
) )

View File

@ -5,8 +5,9 @@ from __future__ import annotations
import logging import logging
import threading import threading
import weakref import weakref
from collections.abc import Sequence
from concurrent.futures import Future, ThreadPoolExecutor, wait from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, List, Optional, Sequence, Union, cast from typing import Any, Optional, Union, cast
from uuid import UUID from uuid import UUID
import langsmith import langsmith
@ -156,7 +157,7 @@ class EvaluatorCallbackHandler(BaseTracer):
if isinstance(results, EvaluationResult): if isinstance(results, EvaluationResult):
results_ = [results] results_ = [results]
elif isinstance(results, dict) and "results" in results: elif isinstance(results, dict) and "results" in results:
results_ = cast(List[EvaluationResult], results["results"]) results_ = cast(list[EvaluationResult], results["results"])
else: else:
raise TypeError( raise TypeError(
f"Invalid evaluation result type {type(results)}." f"Invalid evaluation result type {type(results)}."

View File

@ -4,14 +4,11 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from collections.abc import AsyncIterator, Iterator, Sequence
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncIterator,
Iterator,
List,
Optional, Optional,
Sequence,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -459,7 +456,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
output: Union[dict, BaseMessage] = {} output: Union[dict, BaseMessage] = {}
if run_info["run_type"] == "chat_model": if run_info["run_type"] == "chat_model":
generations = cast(List[List[ChatGenerationChunk]], response.generations) generations = cast(list[list[ChatGenerationChunk]], response.generations)
for gen in generations: for gen in generations:
if output != {}: if output != {}:
break break
@ -469,7 +466,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
event = "on_chat_model_end" event = "on_chat_model_end"
elif run_info["run_type"] == "llm": elif run_info["run_type"] == "llm":
generations = cast(List[List[GenerationChunk]], response.generations) generations = cast(list[list[GenerationChunk]], response.generations)
output = { output = {
"generations": [ "generations": [
[ [

View File

@ -4,13 +4,11 @@ import asyncio
import copy import copy
import threading import threading
from collections import defaultdict from collections import defaultdict
from collections.abc import AsyncIterator, Iterator, Sequence
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Iterator,
Literal, Literal,
Optional, Optional,
Sequence,
TypeVar, TypeVar,
Union, Union,
overload, overload,

View File

@ -11,7 +11,8 @@ used in the code.
import asyncio import asyncio
from asyncio import AbstractEventLoop, Queue from asyncio import AbstractEventLoop, Queue
from typing import AsyncIterator, Generic, TypeVar from collections.abc import AsyncIterator
from typing import Generic, TypeVar
T = TypeVar("T") T = TypeVar("T")

View File

@ -1,4 +1,5 @@
from typing import Awaitable, Callable, Optional, Union from collections.abc import Awaitable
from typing import Callable, Optional, Union
from uuid import UUID from uuid import UUID
from langchain_core.runnables.config import ( from langchain_core.runnables.config import (

View File

@ -1,6 +1,6 @@
"""A tracer that collects all nested runs in a list.""" """A tracer that collects all nested runs in a list."""
from typing import Any, List, Optional, Union from typing import Any, Optional, Union
from uuid import UUID from uuid import UUID
from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.base import BaseTracer
@ -38,7 +38,7 @@ class RunCollectorCallbackHandler(BaseTracer):
self.example_id = ( self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id UUID(example_id) if isinstance(example_id, str) else example_id
) )
self.traced_runs: List[Run] = [] self.traced_runs: list[Run] = []
def _persist_run(self, run: Run) -> None: def _persist_run(self, run: Run) -> None:
""" """

View File

@ -1,5 +1,5 @@
import json import json
from typing import Any, Callable, List from typing import Any, Callable
from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run from langchain_core.tracers.schemas import Run
@ -54,7 +54,7 @@ class FunctionCallbackHandler(BaseTracer):
def _persist_run(self, run: Run) -> None: def _persist_run(self, run: Run) -> None:
pass pass
def get_parents(self, run: Run) -> List[Run]: def get_parents(self, run: Run) -> list[Run]:
"""Get the parents of a run. """Get the parents of a run.
Args: Args:

View File

@ -5,23 +5,20 @@ MIT License
""" """
from collections import deque from collections import deque
from contextlib import AbstractAsyncContextManager from collections.abc import (
from types import TracebackType
from typing import (
Any,
AsyncContextManager,
AsyncGenerator, AsyncGenerator,
AsyncIterable, AsyncIterable,
AsyncIterator, AsyncIterator,
Awaitable, Awaitable,
Callable,
Deque,
Generic,
Iterator, Iterator,
List, )
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import (
Any,
Callable,
Generic,
Optional, Optional,
Tuple,
Type,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -95,10 +92,10 @@ class NoLock:
async def tee_peer( async def tee_peer(
iterator: AsyncIterator[T], iterator: AsyncIterator[T],
# the buffer specific to this peer # the buffer specific to this peer
buffer: Deque[T], buffer: deque[T],
# the buffers of all peers, including our own # the buffers of all peers, including our own
peers: List[Deque[T]], peers: list[deque[T]],
lock: AsyncContextManager[Any], lock: AbstractAsyncContextManager[Any],
) -> AsyncGenerator[T, None]: ) -> AsyncGenerator[T, None]:
"""An individual iterator of a :py:func:`~.tee`. """An individual iterator of a :py:func:`~.tee`.
@ -191,10 +188,10 @@ class Tee(Generic[T]):
iterable: AsyncIterator[T], iterable: AsyncIterator[T],
n: int = 2, n: int = 2,
*, *,
lock: Optional[AsyncContextManager[Any]] = None, lock: Optional[AbstractAsyncContextManager[Any]] = None,
): ):
self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist
self._buffers: List[Deque[T]] = [deque() for _ in range(n)] self._buffers: list[deque[T]] = [deque() for _ in range(n)]
self._children = tuple( self._children = tuple(
tee_peer( tee_peer(
iterator=self._iterator, iterator=self._iterator,
@ -212,11 +209,11 @@ class Tee(Generic[T]):
def __getitem__(self, item: int) -> AsyncIterator[T]: ... def __getitem__(self, item: int) -> AsyncIterator[T]: ...
@overload @overload
def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]: ... def __getitem__(self, item: slice) -> tuple[AsyncIterator[T], ...]: ...
def __getitem__( def __getitem__(
self, item: Union[int, slice] self, item: Union[int, slice]
) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]: ) -> Union[AsyncIterator[T], tuple[AsyncIterator[T], ...]]:
return self._children[item] return self._children[item]
def __iter__(self) -> Iterator[AsyncIterator[T]]: def __iter__(self) -> Iterator[AsyncIterator[T]]:
@ -267,7 +264,7 @@ class aclosing(AbstractAsyncContextManager):
async def __aexit__( async def __aexit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException], exc_value: Optional[BaseException],
traceback: Optional[TracebackType], traceback: Optional[TracebackType],
) -> None: ) -> None:
@ -277,7 +274,7 @@ class aclosing(AbstractAsyncContextManager):
async def abatch_iterate( async def abatch_iterate(
size: int, iterable: AsyncIterable[T] size: int, iterable: AsyncIterable[T]
) -> AsyncIterator[List[T]]: ) -> AsyncIterator[list[T]]:
"""Utility batching function for async iterables. """Utility batching function for async iterables.
Args: Args:
@ -287,7 +284,7 @@ async def abatch_iterate(
Returns: Returns:
An async iterator over the batches. An async iterator over the batches.
""" """
batch: List[T] = [] batch: list[T] = []
async for element in iterable: async for element in iterable:
if len(batch) < size: if len(batch) < size:
batch.append(element) batch.append(element)

View File

@ -1,7 +1,8 @@
"""Utilities for formatting strings.""" """Utilities for formatting strings."""
from collections.abc import Mapping, Sequence
from string import Formatter from string import Formatter
from typing import Any, List, Mapping, Sequence from typing import Any
class StrictFormatter(Formatter): class StrictFormatter(Formatter):
@ -31,7 +32,7 @@ class StrictFormatter(Formatter):
return super().vformat(format_string, args, kwargs) return super().vformat(format_string, args, kwargs)
def validate_input_variables( def validate_input_variables(
self, format_string: str, input_variables: List[str] self, format_string: str, input_variables: list[str]
) -> None: ) -> None:
"""Check that all input variables are used in the format string. """Check that all input variables are used in the format string.

View File

@ -10,21 +10,17 @@ import typing
import uuid import uuid
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Annotated,
Any, Any,
Callable, Callable,
Dict,
List,
Literal, Literal,
Optional, Optional,
Set,
Tuple,
Type,
Union, Union,
cast, cast,
) )
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
@ -201,7 +197,7 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript
from pydantic.v1 import BaseModel from pydantic.v1 import BaseModel
model = cast( model = cast(
Type[BaseModel], type[BaseModel],
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited), _convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
) )
return convert_pydantic_to_openai_function(model) # type: ignore return convert_pydantic_to_openai_function(model) # type: ignore
@ -383,15 +379,15 @@ def convert_to_openai_function(
"parameters": function, "parameters": function,
} }
elif isinstance(function, type) and is_basemodel_subclass(function): elif isinstance(function, type) and is_basemodel_subclass(function):
oai_function = cast(Dict, convert_pydantic_to_openai_function(function)) oai_function = cast(dict, convert_pydantic_to_openai_function(function))
elif is_typeddict(function): elif is_typeddict(function):
oai_function = cast( oai_function = cast(
Dict, _convert_typed_dict_to_openai_function(cast(Type, function)) dict, _convert_typed_dict_to_openai_function(cast(type, function))
) )
elif isinstance(function, BaseTool): elif isinstance(function, BaseTool):
oai_function = cast(Dict, format_tool_to_openai_function(function)) oai_function = cast(dict, format_tool_to_openai_function(function))
elif callable(function): elif callable(function):
oai_function = cast(Dict, convert_python_function_to_openai_function(function)) oai_function = cast(dict, convert_python_function_to_openai_function(function))
else: else:
raise ValueError( raise ValueError(
f"Unsupported function\n\n{function}\n\nFunctions must be passed in" f"Unsupported function\n\n{function}\n\nFunctions must be passed in"
@ -598,17 +594,17 @@ def _py_38_safe_origin(origin: type) -> type:
) )
origin_map: dict[type, Any] = { origin_map: dict[type, Any] = {
dict: Dict, dict: dict,
list: List, list: list,
tuple: Tuple, tuple: tuple,
set: Set, set: set,
collections.abc.Iterable: typing.Iterable, collections.abc.Iterable: typing.Iterable,
collections.abc.Mapping: typing.Mapping, collections.abc.Mapping: typing.Mapping,
collections.abc.Sequence: typing.Sequence, collections.abc.Sequence: typing.Sequence,
collections.abc.MutableMapping: typing.MutableMapping, collections.abc.MutableMapping: typing.MutableMapping,
**origin_union_type_map, **origin_union_type_map,
} }
return cast(Type, origin_map.get(origin, origin)) return cast(type, origin_map.get(origin, origin))
def _recursive_set_additional_properties_false( def _recursive_set_additional_properties_false(

View File

@ -1,6 +1,7 @@
import logging import logging
import re import re
from typing import List, Optional, Sequence, Union from collections.abc import Sequence
from typing import Optional, Union
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,7 +34,7 @@ DEFAULT_LINK_REGEX = (
def find_all_links( def find_all_links(
raw_html: str, *, pattern: Union[str, re.Pattern, None] = None raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
) -> List[str]: ) -> list[str]:
"""Extract all links from a raw HTML string. """Extract all links from a raw HTML string.
Args: Args:
@ -56,7 +57,7 @@ def extract_sub_links(
prevent_outside: bool = True, prevent_outside: bool = True,
exclude_prefixes: Sequence[str] = (), exclude_prefixes: Sequence[str] = (),
continue_on_failure: bool = False, continue_on_failure: bool = False,
) -> List[str]: ) -> list[str]:
"""Extract all links from a raw HTML string and convert into absolute paths. """Extract all links from a raw HTML string and convert into absolute paths.
Args: Args:

View File

@ -1,6 +1,6 @@
"""Handle chained inputs.""" """Handle chained inputs."""
from typing import Dict, List, Optional, TextIO from typing import Optional, TextIO
_TEXT_COLOR_MAPPING = { _TEXT_COLOR_MAPPING = {
"blue": "36;1", "blue": "36;1",
@ -12,8 +12,8 @@ _TEXT_COLOR_MAPPING = {
def get_color_mapping( def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None items: list[str], excluded_colors: Optional[list] = None
) -> Dict[str, str]: ) -> dict[str, str]:
"""Get mapping for items to a support color. """Get mapping for items to a support color.
Args: Args:

View File

@ -1,16 +1,11 @@
from collections import deque from collections import deque
from collections.abc import Generator, Iterable, Iterator
from contextlib import AbstractContextManager
from itertools import islice from itertools import islice
from typing import ( from typing import (
Any, Any,
ContextManager,
Deque,
Generator,
Generic, Generic,
Iterable,
Iterator,
List,
Optional, Optional,
Tuple,
TypeVar, TypeVar,
Union, Union,
overload, overload,
@ -34,10 +29,10 @@ class NoLock:
def tee_peer( def tee_peer(
iterator: Iterator[T], iterator: Iterator[T],
# the buffer specific to this peer # the buffer specific to this peer
buffer: Deque[T], buffer: deque[T],
# the buffers of all peers, including our own # the buffers of all peers, including our own
peers: List[Deque[T]], peers: list[deque[T]],
lock: ContextManager[Any], lock: AbstractContextManager[Any],
) -> Generator[T, None, None]: ) -> Generator[T, None, None]:
"""An individual iterator of a :py:func:`~.tee`. """An individual iterator of a :py:func:`~.tee`.
@ -130,7 +125,7 @@ class Tee(Generic[T]):
iterable: Iterator[T], iterable: Iterator[T],
n: int = 2, n: int = 2,
*, *,
lock: Optional[ContextManager[Any]] = None, lock: Optional[AbstractContextManager[Any]] = None,
): ):
"""Create a new ``tee``. """Create a new ``tee``.
@ -141,7 +136,7 @@ class Tee(Generic[T]):
Defaults to None. Defaults to None.
""" """
self._iterator = iter(iterable) self._iterator = iter(iterable)
self._buffers: List[Deque[T]] = [deque() for _ in range(n)] self._buffers: list[deque[T]] = [deque() for _ in range(n)]
self._children = tuple( self._children = tuple(
tee_peer( tee_peer(
iterator=self._iterator, iterator=self._iterator,
@ -159,11 +154,11 @@ class Tee(Generic[T]):
def __getitem__(self, item: int) -> Iterator[T]: ... def __getitem__(self, item: int) -> Iterator[T]: ...
@overload @overload
def __getitem__(self, item: slice) -> Tuple[Iterator[T], ...]: ... def __getitem__(self, item: slice) -> tuple[Iterator[T], ...]: ...
def __getitem__( def __getitem__(
self, item: Union[int, slice] self, item: Union[int, slice]
) -> Union[Iterator[T], Tuple[Iterator[T], ...]]: ) -> Union[Iterator[T], tuple[Iterator[T], ...]]:
return self._children[item] return self._children[item]
def __iter__(self) -> Iterator[Iterator[T]]: def __iter__(self) -> Iterator[Iterator[T]]:
@ -185,7 +180,7 @@ class Tee(Generic[T]):
safetee = Tee safetee = Tee
def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]: def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[list[T]]:
"""Utility batching function. """Utility batching function.
Args: Args:

Some files were not shown because too many files have changed in this diff Show More