core: Put Python version as a project requirement so it is considered by ruff (#26608)

Ruff doesn't know about the python version in
`[tool.poetry.dependencies]`. It can get it from
`project.requires-python`.

Notes:
* poetry seems to have issues getting the python constraints from
`requires-python` and using `python` in per dependency constraints. So I
had to duplicate the info. I will open an issue on poetry.
* `inspect.isclass()` doesn't work correctly with `GenericAlias`
(`list[...]`, `dict[..., ...]`) on Python <3.11 so I added some `not
isinstance(type, GenericAlias)` checks:

Python 3.11
```pycon
>>> import inspect
>>> inspect.isclass(list)
True
>>> inspect.isclass(list[str])
False
```

Python 3.9
```pycon
>>> import inspect
>>> inspect.isclass(list)
True
>>> inspect.isclass(list[str])
True
```

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Christophe Bornet 2024-09-18 16:37:57 +02:00 committed by GitHub
parent 0f07cf61da
commit a47b332841
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
162 changed files with 920 additions and 1002 deletions

View File

@ -14,7 +14,8 @@ import contextlib
import functools
import inspect
import warnings
from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
from collections.abc import Generator
from typing import Any, Callable, TypeVar, Union, cast
from langchain_core._api.internal import is_caller_internal
@ -26,7 +27,7 @@ class LangChainBetaWarning(DeprecationWarning):
# PUBLIC API
T = TypeVar("T", bound=Union[Callable[..., Any], Type])
T = TypeVar("T", bound=Union[Callable[..., Any], type])
def beta(

View File

@ -14,11 +14,10 @@ import contextlib
import functools
import inspect
import warnings
from collections.abc import Generator
from typing import (
Any,
Callable,
Generator,
Type,
TypeVar,
Union,
cast,
@ -41,7 +40,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
# Last Any should be FieldInfoV1 but this leads to circular imports
T = TypeVar("T", bound=Union[Type, Callable[..., Any], Any])
T = TypeVar("T", bound=Union[type, Callable[..., Any], Any])
def _validate_deprecation_params(
@ -262,7 +261,7 @@ def deprecated(
if not _obj_type:
_obj_type = "attribute"
wrapped = None
_name = _name or cast(Union[Type, Callable], obj.fget).__qualname__
_name = _name or cast(Union[type, Callable], obj.fget).__qualname__
old_doc = obj.__doc__
class _deprecated_property(property):
@ -304,7 +303,7 @@ def deprecated(
)
else:
_name = _name or cast(Union[Type, Callable], obj).__qualname__
_name = _name or cast(Union[type, Callable], obj).__qualname__
if not _obj_type:
# edge case: when a function is within another function
# within a test, this will call it a "method" not a "function"

View File

@ -25,7 +25,8 @@ The schemas for the agents themselves are defined in langchain.agents.agent.
from __future__ import annotations
import json
from typing import Any, Literal, Sequence, Union
from collections.abc import Sequence
from typing import Any, Literal, Union
from langchain_core.load.serializable import Serializable
from langchain_core.messages import (

View File

@ -1,19 +1,13 @@
import asyncio
import threading
from collections import defaultdict
from collections.abc import Awaitable, Mapping, Sequence
from functools import partial
from itertools import groupby
from typing import (
Any,
Awaitable,
Callable,
DefaultDict,
Dict,
List,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
Union,
)
@ -30,7 +24,7 @@ from langchain_core.runnables.config import RunnableConfig, ensure_config, patch
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
T = TypeVar("T")
Values = Dict[Union[asyncio.Event, threading.Event], Any]
Values = dict[Union[asyncio.Event, threading.Event], Any]
CONTEXT_CONFIG_PREFIX = "__context__/"
CONTEXT_CONFIG_SUFFIX_GET = "/get"
CONTEXT_CONFIG_SUFFIX_SET = "/set"
@ -70,10 +64,10 @@ def _key_from_id(id_: str) -> str:
def _config_with_context(
config: RunnableConfig,
steps: List[Runnable],
steps: list[Runnable],
setter: Callable,
getter: Callable,
event_cls: Union[Type[threading.Event], Type[asyncio.Event]],
event_cls: Union[type[threading.Event], type[asyncio.Event]],
) -> RunnableConfig:
if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
return config
@ -99,10 +93,10 @@ def _config_with_context(
}
values: Values = {}
events: DefaultDict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
events: defaultdict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
event_cls
)
context_funcs: Dict[str, Callable[[], Any]] = {}
context_funcs: dict[str, Callable[[], Any]] = {}
for key, group in grouped_by_key.items():
getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]
@ -129,7 +123,7 @@ def _config_with_context(
def aconfig_with_context(
config: RunnableConfig,
steps: List[Runnable],
steps: list[Runnable],
) -> RunnableConfig:
"""Asynchronously patch a runnable config with context getters and setters.
@ -145,7 +139,7 @@ def aconfig_with_context(
def config_with_context(
config: RunnableConfig,
steps: List[Runnable],
steps: list[Runnable],
) -> RunnableConfig:
"""Patch a runnable config with context getters and setters.
@ -165,13 +159,13 @@ class ContextGet(RunnableSerializable):
prefix: str = ""
key: Union[str, List[str]]
key: Union[str, list[str]]
def __str__(self) -> str:
return f"ContextGet({_print_keys(self.key)})"
@property
def ids(self) -> List[str]:
def ids(self) -> list[str]:
prefix = self.prefix + "/" if self.prefix else ""
keys = self.key if isinstance(self.key, list) else [self.key]
return [
@ -180,7 +174,7 @@ class ContextGet(RunnableSerializable):
]
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
return super().config_specs + [
ConfigurableFieldSpec(
id=id_,
@ -256,7 +250,7 @@ class ContextSet(RunnableSerializable):
return f"ContextSet({_print_keys(list(self.keys.keys()))})"
@property
def ids(self) -> List[str]:
def ids(self) -> list[str]:
prefix = self.prefix + "/" if self.prefix else ""
return [
f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
@ -264,7 +258,7 @@ class ContextSet(RunnableSerializable):
]
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
mapper_config_specs = [
s
for mapper in self.keys.values()
@ -364,7 +358,7 @@ class Context:
return PrefixContext(prefix=scope)
@staticmethod
def getter(key: Union[str, List[str]], /) -> ContextGet:
def getter(key: Union[str, list[str]], /) -> ContextGet:
return ContextGet(key=key)
@staticmethod
@ -385,7 +379,7 @@ class PrefixContext:
def __init__(self, prefix: str = ""):
self.prefix = prefix
def getter(self, key: Union[str, List[str]], /) -> ContextGet:
def getter(self, key: Union[str, list[str]], /) -> ContextGet:
return ContextGet(key=key, prefix=self.prefix)
def setter(

View File

@ -23,7 +23,8 @@ Cache directly competes with Memory. See documentation for Pros and Cons.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence
from collections.abc import Sequence
from typing import Any, Optional
from langchain_core.outputs import Generation
from langchain_core.runnables import run_in_executor

View File

@ -3,7 +3,8 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from uuid import UUID
from tenacity import RetryCallState
@ -1070,4 +1071,4 @@ class BaseCallbackManager(CallbackManagerMixin):
self.inheritable_metadata.pop(key)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]

View File

@ -5,19 +5,15 @@ import functools
import logging
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Coroutine, Generator, Sequence
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from contextvars import copy_context
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Coroutine,
Generator,
Optional,
Sequence,
Type,
TypeVar,
Union,
cast,
@ -2352,7 +2348,7 @@ def _configure(
and handler_class is not None
)
if var.get() is not None or create_one:
var_handler = var.get() or cast(Type[BaseCallbackHandler], handler_class)()
var_handler = var.get() or cast(type[BaseCallbackHandler], handler_class)()
if handler_class is None:
if not any(
handler is var_handler # direct pointer comparison

View File

@ -18,7 +18,8 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Sequence, Union
from collections.abc import Sequence
from typing import Union
from pydantic import BaseModel, Field

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Iterator, List
from collections.abc import Iterator
from langchain_core.chat_sessions import ChatSession
@ -15,7 +15,7 @@ class BaseChatLoader(ABC):
An iterator of chat sessions.
"""
def load(self) -> List[ChatSession]:
def load(self) -> list[ChatSession]:
"""Eagerly load the chat sessions into memory.
Returns:

View File

@ -1,6 +1,7 @@
"""**Chat Sessions** are a collection of messages and function calls."""
from typing import Sequence, TypedDict
from collections.abc import Sequence
from typing import TypedDict
from langchain_core.messages import BaseMessage

View File

@ -3,7 +3,8 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional
from collections.abc import AsyncIterator, Iterator
from typing import TYPE_CHECKING, Optional
from langchain_core.documents import Document
from langchain_core.runnables import run_in_executor

View File

@ -8,7 +8,7 @@ In addition, content loading code should provide a lazy loading interface by def
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Iterable
from collections.abc import Iterable
# Re-export Blob and PathLike for backwards compatibility
from langchain_core.documents.base import Blob as Blob

View File

@ -1,7 +1,8 @@
import datetime
import json
import uuid
from typing import Any, Callable, Iterator, Optional, Sequence, Union
from collections.abc import Iterator, Sequence
from typing import Any, Callable, Optional, Union
from langsmith import Client as LangSmithClient

View File

@ -2,9 +2,10 @@ from __future__ import annotations
import contextlib
import mimetypes
from collections.abc import Generator
from io import BufferedReader, BytesIO
from pathlib import PurePath
from typing import Any, Generator, Literal, Optional, Union, cast
from typing import Any, Literal, Optional, Union, cast
from pydantic import ConfigDict, Field, field_validator, model_validator

View File

@ -1,7 +1,8 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Sequence
from collections.abc import Sequence
from typing import Optional
from pydantic import BaseModel

View File

@ -1,7 +1,8 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Sequence
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
from langchain_core.runnables.config import run_in_executor

View File

@ -1,7 +1,6 @@
"""**Embeddings** interface."""
from abc import ABC, abstractmethod
from typing import List
from langchain_core.runnables.config import run_in_executor
@ -35,7 +34,7 @@ class Embeddings(ABC):
"""
@abstractmethod
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs.
Args:
@ -46,7 +45,7 @@ class Embeddings(ABC):
"""
@abstractmethod
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text.
Args:
@ -56,7 +55,7 @@ class Embeddings(ABC):
Embedding.
"""
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Asynchronous Embed search docs.
Args:
@ -67,7 +66,7 @@ class Embeddings(ABC):
"""
return await run_in_executor(None, self.embed_documents, texts)
async def aembed_query(self, text: str) -> List[float]:
async def aembed_query(self, text: str) -> list[float]:
"""Asynchronous Embed query text.
Args:

View File

@ -2,7 +2,6 @@
# Please do not add additional fake embedding model implementations here.
import hashlib
from typing import List
from pydantic import BaseModel
@ -51,15 +50,15 @@ class FakeEmbeddings(Embeddings, BaseModel):
size: int
"""The size of the embedding vector."""
def _get_embedding(self) -> List[float]:
def _get_embedding(self) -> list[float]:
import numpy as np # type: ignore[import-not-found, import-untyped]
return list(np.random.normal(size=self.size))
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self._get_embedding() for _ in texts]
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
return self._get_embedding()
@ -106,7 +105,7 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
size: int
"""The size of the embedding vector."""
def _get_embedding(self, seed: int) -> List[float]:
def _get_embedding(self, seed: int) -> list[float]:
import numpy as np # type: ignore[import-not-found, import-untyped]
# set the seed for the random generator
@ -117,8 +116,8 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
"""Get a seed for the random generator, using the hash of the text."""
return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self._get_embedding(seed=self._get_seed(_)) for _ in texts]
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
return self._get_embedding(seed=self._get_seed(text))

View File

@ -1,7 +1,7 @@
"""Interface for selecting examples to include in prompts."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any
from langchain_core.runnables import run_in_executor
@ -10,14 +10,14 @@ class BaseExampleSelector(ABC):
"""Interface for selecting examples to include in prompts."""
@abstractmethod
def add_example(self, example: Dict[str, str]) -> Any:
def add_example(self, example: dict[str, str]) -> Any:
"""Add new example to store.
Args:
example: A dictionary with keys as input variables
and values as their values."""
async def aadd_example(self, example: Dict[str, str]) -> Any:
async def aadd_example(self, example: dict[str, str]) -> Any:
"""Async add new example to store.
Args:
@ -27,14 +27,14 @@ class BaseExampleSelector(ABC):
return await run_in_executor(None, self.add_example, example)
@abstractmethod
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Select which examples to use based on the inputs.
Args:
input_variables: A dictionary with keys as input variables
and values as their values."""
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Async select which examples to use based on the inputs.
Args:

View File

@ -1,7 +1,7 @@
"""Select examples based on length."""
import re
from typing import Callable, Dict, List
from typing import Callable
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
@ -17,7 +17,7 @@ def _get_length_based(text: str) -> int:
class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
"""Select examples based on length."""
examples: List[dict]
examples: list[dict]
"""A list of the examples that the prompt template expects."""
example_prompt: PromptTemplate
@ -29,10 +29,10 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
max_length: int = 2048
"""Max length for the prompt, beyond which examples are cut."""
example_text_lengths: List[int] = Field(default_factory=list) # :meta private:
example_text_lengths: list[int] = Field(default_factory=list) # :meta private:
"""Length of each example."""
def add_example(self, example: Dict[str, str]) -> None:
def add_example(self, example: dict[str, str]) -> None:
"""Add new example to list.
Args:
@ -43,7 +43,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
string_example = self.example_prompt.format(**example)
self.example_text_lengths.append(self.get_text_length(string_example))
async def aadd_example(self, example: Dict[str, str]) -> None:
async def aadd_example(self, example: dict[str, str]) -> None:
"""Async add new example to list.
Args:
@ -62,7 +62,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
self.example_text_lengths = [self.get_text_length(eg) for eg in string_examples]
return self
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Select which examples to use based on the input lengths.
Args:
@ -86,7 +86,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
i += 1
return examples
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Async select which examples to use based on the input lengths.
Args:

View File

@ -1,13 +1,10 @@
from __future__ import annotations
from abc import abstractmethod
from collections.abc import AsyncIterable, Collection, Iterable, Iterator
from typing import (
Any,
AsyncIterable,
ClassVar,
Collection,
Iterable,
Iterator,
Optional,
)

View File

@ -1,5 +1,6 @@
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Iterable, List, Literal, Union
from typing import Literal, Union
from langchain_core._api import beta
from langchain_core.documents import Document
@ -41,7 +42,7 @@ METADATA_LINKS_KEY = "links"
@beta()
def get_links(doc: Document) -> List[Link]:
def get_links(doc: Document) -> list[Link]:
"""Get the links from a document.
Args:

View File

@ -5,17 +5,13 @@ from __future__ import annotations
import hashlib
import json
import uuid
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator, Sequence
from itertools import islice
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Callable,
Iterable,
Iterator,
Literal,
Optional,
Sequence,
TypedDict,
TypeVar,
Union,

View File

@ -3,7 +3,8 @@ from __future__ import annotations
import abc
import time
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence, TypedDict
from collections.abc import Sequence
from typing import Any, Optional, TypedDict
from langchain_core._api import beta
from langchain_core.documents import Document

View File

@ -1,5 +1,6 @@
import uuid
from typing import Any, Dict, List, Optional, Sequence, cast
from collections.abc import Sequence
from typing import Any, Optional, cast
from pydantic import Field
@ -22,7 +23,7 @@ class InMemoryDocumentIndex(DocumentIndex):
.. versionadded:: 0.2.29
"""
store: Dict[str, Document] = Field(default_factory=dict)
store: dict[str, Document] = Field(default_factory=dict)
top_k: int = 4
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
@ -43,7 +44,7 @@ class InMemoryDocumentIndex(DocumentIndex):
return UpsertResponse(succeeded=ok_ids, failed=[])
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> DeleteResponse:
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
"""Delete by ID."""
if ids is None:
raise ValueError("IDs must be provided for deletion")
@ -59,7 +60,7 @@ class InMemoryDocumentIndex(DocumentIndex):
succeeded=ok_ids, num_deleted=len(ok_ids), num_failed=0, failed=[]
)
def get(self, ids: Sequence[str], /, **kwargs: Any) -> List[Document]:
def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
"""Get by ids."""
found_documents = []
@ -71,7 +72,7 @@ class InMemoryDocumentIndex(DocumentIndex):
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
) -> list[Document]:
counts_by_doc = []
for document in self.store.values():

View File

@ -1,16 +1,14 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from functools import lru_cache
from collections.abc import Mapping, Sequence
from functools import cache
from typing import (
TYPE_CHECKING,
Any,
Callable,
List,
Literal,
Mapping,
Optional,
Sequence,
TypeVar,
Union,
)
@ -52,7 +50,7 @@ class LangSmithParams(TypedDict, total=False):
"""Stop words for generation."""
@lru_cache(maxsize=None) # Cache the tokenizer
@cache # Cache the tokenizer
def get_tokenizer() -> Any:
"""Get a GPT-2 tokenizer instance.
@ -158,7 +156,7 @@ class BaseLanguageModel(
return Union[
str,
Union[StringPromptValue, ChatPromptValueConcrete],
List[AnyMessage],
list[AnyMessage],
]
@abstractmethod

View File

@ -3,21 +3,19 @@ from __future__ import annotations
import asyncio
import inspect
import json
import typing
import uuid
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from functools import cached_property
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
Literal,
Optional,
Sequence,
Union,
cast,
)
@ -1121,18 +1119,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006
tools: Sequence[Union[typing.Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError()
def with_structured_output(
self,
schema: Union[Dict, type], # noqa: UP006
schema: Union[typing.Dict, type], # noqa: UP006
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: # noqa: UP006
) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]: # noqa: UP006
"""Model wrapper that returns outputs formatted to match the given schema.
Args:

View File

@ -1,6 +1,7 @@
import asyncio
import time
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import Any, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
@ -14,7 +15,7 @@ from langchain_core.runnables import RunnableConfig
class FakeListLLM(LLM):
"""Fake LLM for testing purposes."""
responses: List[str]
responses: list[str]
"""List of responses to return in order."""
# This parameter should be removed from FakeListLLM since
# it's only used by sub-classes.
@ -37,7 +38,7 @@ class FakeListLLM(LLM):
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -52,7 +53,7 @@ class FakeListLLM(LLM):
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -90,7 +91,7 @@ class FakeStreamingListLLM(FakeListLLM):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
result = self.invoke(input, config)
@ -110,7 +111,7 @@ class FakeStreamingListLLM(FakeListLLM):
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
result = await self.ainvoke(input, config)

View File

@ -3,7 +3,8 @@
import asyncio
import re
import time
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast
from collections.abc import AsyncIterator, Iterator
from typing import Any, Optional, Union, cast
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
@ -17,7 +18,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
class FakeMessagesListChatModel(BaseChatModel):
"""Fake ChatModel for testing purposes."""
responses: List[BaseMessage]
responses: list[BaseMessage]
"""List of responses to **cycle** through in order."""
sleep: Optional[float] = None
"""Sleep time in seconds between responses."""
@ -26,8 +27,8 @@ class FakeMessagesListChatModel(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -51,7 +52,7 @@ class FakeListChatModelError(Exception):
class FakeListChatModel(SimpleChatModel):
"""Fake ChatModel for testing purposes."""
responses: List[str]
responses: list[str]
"""List of responses to **cycle** through in order."""
sleep: Optional[float] = None
i: int = 0
@ -65,8 +66,8 @@ class FakeListChatModel(SimpleChatModel):
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -80,8 +81,8 @@ class FakeListChatModel(SimpleChatModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Union[List[str], None] = None,
messages: list[BaseMessage],
stop: Union[list[str], None] = None,
run_manager: Union[CallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@ -103,8 +104,8 @@ class FakeListChatModel(SimpleChatModel):
async def _astream(
self,
messages: List[BaseMessage],
stop: Union[List[str], None] = None,
messages: list[BaseMessage],
stop: Union[list[str], None] = None,
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
@ -124,7 +125,7 @@ class FakeListChatModel(SimpleChatModel):
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
@property
def _identifying_params(self) -> Dict[str, Any]:
def _identifying_params(self) -> dict[str, Any]:
return {"responses": self.responses}
@ -133,8 +134,8 @@ class FakeChatModel(SimpleChatModel):
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
@ -142,8 +143,8 @@ class FakeChatModel(SimpleChatModel):
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -157,7 +158,7 @@ class FakeChatModel(SimpleChatModel):
return "fake-chat-model"
@property
def _identifying_params(self) -> Dict[str, Any]:
def _identifying_params(self) -> dict[str, Any]:
return {"key": "fake"}
@ -186,8 +187,8 @@ class GenericFakeChatModel(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
@ -202,8 +203,8 @@ class GenericFakeChatModel(BaseChatModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
@ -231,7 +232,7 @@ class GenericFakeChatModel(BaseChatModel):
# Use a regular expression to split on whitespace with a capture group
# so that we can preserve the whitespace in the output.
assert isinstance(content, str)
content_chunks = cast(List[str], re.split(r"(\s)", content))
content_chunks = cast(list[str], re.split(r"(\s)", content))
for token in content_chunks:
chunk = ChatGenerationChunk(
@ -249,7 +250,7 @@ class GenericFakeChatModel(BaseChatModel):
for fkey, fvalue in value.items():
if isinstance(fvalue, str):
# Break function call by `,`
fvalue_chunks = cast(List[str], re.split(r"(,)", fvalue))
fvalue_chunks = cast(list[str], re.split(r"(,)", fvalue))
for fvalue_chunk in fvalue_chunks:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
@ -306,8 +307,8 @@ class ParrotFakeChatModel(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:

View File

@ -10,16 +10,12 @@ import logging
import uuid
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from pathlib import Path
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Union,
cast,
)
@ -448,7 +444,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return [g[0].text for g in llm_result.generations]
except Exception as e:
if return_exceptions:
return cast(List[str], [e for _ in inputs])
return cast(list[str], [e for _ in inputs])
else:
raise e
else:
@ -494,7 +490,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return [g[0].text for g in llm_result.generations]
except Exception as e:
if return_exceptions:
return cast(List[str], [e for _ in inputs])
return cast(list[str], [e for _ in inputs])
else:
raise e
else:
@ -883,13 +879,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
assert run_name is None or (
isinstance(run_name, list) and len(run_name) == len(prompts)
)
callbacks = cast(List[Callbacks], callbacks)
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
callbacks = cast(list[Callbacks], callbacks)
tags_list = cast(list[Optional[list[str]]], tags or ([None] * len(prompts)))
metadata_list = cast(
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
list[Optional[dict[str, Any]]], metadata or ([{}] * len(prompts))
)
run_name_list = run_name or cast(
List[Optional[str]], ([None] * len(prompts))
list[Optional[str]], ([None] * len(prompts))
)
callback_managers = [
CallbackManager.configure(
@ -910,9 +906,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
cast(Callbacks, callbacks),
self.callbacks,
self.verbose,
cast(List[str], tags),
cast(list[str], tags),
self.tags,
cast(Dict[str, Any], metadata),
cast(dict[str, Any], metadata),
self.metadata,
)
] * len(prompts)
@ -1116,13 +1112,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
assert run_name is None or (
isinstance(run_name, list) and len(run_name) == len(prompts)
)
callbacks = cast(List[Callbacks], callbacks)
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
callbacks = cast(list[Callbacks], callbacks)
tags_list = cast(list[Optional[list[str]]], tags or ([None] * len(prompts)))
metadata_list = cast(
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
list[Optional[dict[str, Any]]], metadata or ([{}] * len(prompts))
)
run_name_list = run_name or cast(
List[Optional[str]], ([None] * len(prompts))
list[Optional[str]], ([None] * len(prompts))
)
callback_managers = [
AsyncCallbackManager.configure(
@ -1143,9 +1139,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
cast(Callbacks, callbacks),
self.callbacks,
self.verbose,
cast(List[str], tags),
cast(list[str], tags),
self.tags,
cast(Dict[str, Any], metadata),
cast(dict[str, Any], metadata),
self.metadata,
)
] * len(prompts)

View File

@ -1,7 +1,7 @@
import importlib
import json
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from langchain_core._api import beta
from langchain_core.load.mapping import (
@ -34,11 +34,11 @@ class Reviver:
def __init__(
self,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
secrets_map: Optional[dict[str, str]] = None,
valid_namespaces: Optional[list[str]] = None,
secrets_from_env: bool = True,
additional_import_mappings: Optional[
Dict[Tuple[str, ...], Tuple[str, ...]]
dict[tuple[str, ...], tuple[str, ...]]
] = None,
) -> None:
"""Initialize the reviver.
@ -73,7 +73,7 @@ class Reviver:
else ALL_SERIALIZABLE_MAPPINGS
)
def __call__(self, value: Dict[str, Any]) -> Any:
def __call__(self, value: dict[str, Any]) -> Any:
if (
value.get("lc", None) == 1
and value.get("type", None) == "secret"
@ -154,10 +154,10 @@ class Reviver:
def loads(
text: str,
*,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
secrets_map: Optional[dict[str, str]] = None,
valid_namespaces: Optional[list[str]] = None,
secrets_from_env: bool = True,
additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None,
additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None,
) -> Any:
"""Revive a LangChain class from a JSON string.
Equivalent to `load(json.loads(text))`.
@ -190,10 +190,10 @@ def loads(
def load(
obj: Any,
*,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
secrets_map: Optional[dict[str, str]] = None,
valid_namespaces: Optional[list[str]] = None,
secrets_from_env: bool = True,
additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None,
additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None,
) -> Any:
"""Revive a LangChain class from a JSON object. Use this if you already
have a parsed JSON object, eg. from `json.load` or `orjson.loads`.

View File

@ -18,11 +18,9 @@ The mapping allows us to deserialize an AIMessage created with an older
version of LangChain where the code was in a different location.
"""
from typing import Dict, Tuple
# First value is the value that it is serialized as
# Second value is the path to load it from
SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
("langchain", "schema", "messages", "AIMessage"): (
"langchain_core",
"messages",
@ -535,7 +533,7 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
# Needed for backwards compatibility for old versions of LangChain where things
# Were in different place
_OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
_OG_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
("langchain", "schema", "AIMessage"): (
"langchain_core",
"messages",
@ -583,7 +581,7 @@ _OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
# Needed for backwards compatibility for a few versions where we serialized
# with langchain_core paths.
OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
OLD_CORE_NAMESPACES_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
("langchain_core", "messages", "ai", "AIMessage"): (
"langchain_core",
"messages",
@ -937,7 +935,7 @@ OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
),
}
_JS_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
_JS_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
("langchain_core", "messages", "AIMessage"): (
"langchain_core",
"messages",

View File

@ -1,8 +1,6 @@
from abc import ABC
from typing import (
Any,
Dict,
List,
Literal,
Optional,
TypedDict,
@ -25,9 +23,9 @@ class BaseSerialized(TypedDict):
"""
lc: int
id: List[str]
id: list[str]
name: NotRequired[str]
graph: NotRequired[Dict[str, Any]]
graph: NotRequired[dict[str, Any]]
class SerializedConstructor(BaseSerialized):
@ -39,7 +37,7 @@ class SerializedConstructor(BaseSerialized):
"""
type: Literal["constructor"]
kwargs: Dict[str, Any]
kwargs: dict[str, Any]
class SerializedSecret(BaseSerialized):
@ -125,7 +123,7 @@ class Serializable(BaseModel, ABC):
return False
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
For example, if the class is `langchain.llms.openai.OpenAI`, then the
@ -134,7 +132,7 @@ class Serializable(BaseModel, ABC):
return cls.__module__.split(".")
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
"""A map of constructor argument names to secret ids.
For example,
@ -143,7 +141,7 @@ class Serializable(BaseModel, ABC):
return dict()
@property
def lc_attributes(self) -> Dict:
def lc_attributes(self) -> dict:
"""List of attribute names that should be included in the serialized kwargs.
These attributes must be accepted by the constructor.
@ -152,7 +150,7 @@ class Serializable(BaseModel, ABC):
return {}
@classmethod
def lc_id(cls) -> List[str]:
def lc_id(cls) -> list[str]:
"""A unique identifier for this class for serialization purposes.
The unique identifier is a list of strings that describes the path
@ -315,8 +313,8 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
def _replace_secrets(
root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
root: dict[Any, Any], secrets_map: dict[str, str]
) -> dict[Any, Any]:
result = root.copy()
for path, secret_id in secrets_map.items():
[*parts, last] = path.split(".")
@ -344,7 +342,7 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
Returns:
SerializedNotImplemented
"""
_id: List[str] = []
_id: list[str] = []
try:
if hasattr(obj, "__name__"):
_id = [*obj.__module__.split("."), obj.__name__]

View File

@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Literal, Optional, Union
from pydantic import model_validator
from typing_extensions import Self, TypedDict
@ -69,9 +69,9 @@ class AIMessage(BaseMessage):
At the moment, this is ignored by most models. Usage is discouraged.
"""
tool_calls: List[ToolCall] = []
tool_calls: list[ToolCall] = []
"""If provided, tool calls associated with the message."""
invalid_tool_calls: List[InvalidToolCall] = []
invalid_tool_calls: list[InvalidToolCall] = []
"""If provided, tool calls with parsing errors associated with the message."""
usage_metadata: Optional[UsageMetadata] = None
"""If provided, usage metadata for a message, such as token counts.
@ -83,7 +83,7 @@ class AIMessage(BaseMessage):
"""The type of the message (used for deserialization). Defaults to "ai"."""
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:
"""Pass in content as positional arg.
@ -94,7 +94,7 @@ class AIMessage(BaseMessage):
super().__init__(content=content, **kwargs)
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Returns:
@ -104,7 +104,7 @@ class AIMessage(BaseMessage):
return ["langchain", "schema", "messages"]
@property
def lc_attributes(self) -> Dict:
def lc_attributes(self) -> dict:
"""Attrs to be serialized even if they are derived from other init args."""
return {
"tool_calls": self.tool_calls,
@ -137,7 +137,7 @@ class AIMessage(BaseMessage):
# Ensure "type" is properly set on all tool call-like dicts.
if tool_calls := values.get("tool_calls"):
updated: List = []
updated: list = []
for tc in tool_calls:
updated.append(
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
@ -178,7 +178,7 @@ class AIMessage(BaseMessage):
base = super().pretty_repr(html=html)
lines = []
def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> List[str]:
def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> list[str]:
lines = [
f" {tc.get('name', 'Tool')} ({tc.get('id')})",
f" Call ID: {tc.get('id')}",
@ -218,11 +218,11 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
"""The type of the message (used for deserialization).
Defaults to "AIMessageChunk"."""
tool_call_chunks: List[ToolCallChunk] = []
tool_call_chunks: list[ToolCallChunk] = []
"""If provided, tool call chunks associated with the message."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Returns:
@ -232,7 +232,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
return ["langchain", "schema", "messages"]
@property
def lc_attributes(self) -> Dict:
def lc_attributes(self) -> dict:
"""Attrs to be serialized even if they are derived from other init args."""
return {
"tool_calls": self.tool_calls,

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from pydantic import ConfigDict, Field, field_validator
@ -143,7 +144,7 @@ def merge_content(
merged = [merged] + content # type: ignore
elif isinstance(content, list):
# If both are lists
merged = merge_lists(cast(List, merged), content) # type: ignore
merged = merge_lists(cast(list, merged), content) # type: ignore
# If the first content is a list, and the second content is a string
else:
# If the last element of the first content is a string

View File

@ -1,4 +1,4 @@
from typing import Any, List, Literal
from typing import Any, Literal
from langchain_core.messages.base import (
BaseMessage,
@ -18,7 +18,7 @@ class ChatMessage(BaseMessage):
"""The type of the message (used during serialization). Defaults to "chat"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].
"""
@ -39,7 +39,7 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
Defaults to "ChatMessageChunk"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].
"""

View File

@ -1,4 +1,4 @@
from typing import Any, List, Literal
from typing import Any, Literal
from langchain_core.messages.base import (
BaseMessage,
@ -26,7 +26,7 @@ class FunctionMessage(BaseMessage):
"""The type of the message (used for serialization). Defaults to "function"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]
@ -46,7 +46,7 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
Defaults to "FunctionMessageChunk"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Literal, Union
from typing import Any, Literal, Union
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
@ -39,13 +39,13 @@ class HumanMessage(BaseMessage):
"""The type of the message (used for serialization). Defaults to "human"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:
"""Pass in content as positional arg.
@ -70,7 +70,7 @@ class HumanMessageChunk(HumanMessage, BaseMessageChunk):
Defaults to "HumanMessageChunk"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]

View File

@ -1,4 +1,4 @@
from typing import Any, List, Literal
from typing import Any, Literal
from langchain_core.messages.base import BaseMessage
@ -25,7 +25,7 @@ class RemoveMessage(BaseMessage):
return super().__init__("", id=id, **kwargs)
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Literal, Union
from typing import Any, Literal, Union
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
@ -33,13 +33,13 @@ class SystemMessage(BaseMessage):
"""The type of the message (used for serialization). Defaults to "system"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:
"""Pass in content as positional arg.
@ -64,7 +64,7 @@ class SystemMessageChunk(SystemMessage, BaseMessageChunk):
Defaults to "SystemMessageChunk"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]

View File

@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Literal, Optional, Union
from uuid import UUID
from pydantic import Field, model_validator
@ -78,7 +78,7 @@ class ToolMessage(BaseMessage):
"""Currently inherited from BaseMessage, but not used."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
return ["langchain", "schema", "messages"]
@ -123,7 +123,7 @@ class ToolMessage(BaseMessage):
return values
def __init__(
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
) -> None:
super().__init__(content=content, **kwargs)
@ -140,7 +140,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
@ -187,7 +187,7 @@ class ToolCall(TypedDict):
name: str
"""The name of the tool to be called."""
args: Dict[str, Any]
args: dict[str, Any]
"""The arguments to the tool call."""
id: Optional[str]
"""An identifier associated with the tool call.
@ -198,7 +198,7 @@ class ToolCall(TypedDict):
type: NotRequired[Literal["tool_call"]]
def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall:
def tool_call(*, name: str, args: dict[str, Any], id: Optional[str]) -> ToolCall:
return ToolCall(name=name, args=args, id=id, type="tool_call")
@ -276,8 +276,8 @@ def invalid_tool_call(
def default_tool_parser(
raw_tool_calls: List[dict],
) -> Tuple[List[ToolCall], List[InvalidToolCall]]:
raw_tool_calls: list[dict],
) -> tuple[list[ToolCall], list[InvalidToolCall]]:
"""Best-effort parsing of tools."""
tool_calls = []
invalid_tool_calls = []
@ -306,7 +306,7 @@ def default_tool_parser(
return tool_calls, invalid_tool_calls
def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]:
def default_tool_chunk_parser(raw_tool_calls: list[dict]) -> list[ToolCallChunk]:
"""Best-effort parsing of tool chunks."""
tool_call_chunks = []
for tool_call in raw_tool_calls:

View File

@ -11,25 +11,21 @@ from __future__ import annotations
import inspect
import json
from collections.abc import Iterable, Sequence
from functools import partial
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
cast,
overload,
)
from pydantic import Discriminator, Field, Tag
from typing_extensions import Annotated
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
@ -198,7 +194,7 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
MessageLikeRepresentation = Union[
BaseMessage, List[str], Tuple[str, str], str, Dict[str, Any]
BaseMessage, list[str], tuple[str, str], str, dict[str, Any]
]

View File

@ -2,12 +2,11 @@ from __future__ import annotations
import json
from json import JSONDecodeError
from typing import Any, Optional, TypeVar, Union
from typing import Annotated, Any, Optional, TypeVar, Union
import jsonpatch # type: ignore[import]
import pydantic
from pydantic import SkipValidation
from typing_extensions import Annotated
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS

View File

@ -3,8 +3,9 @@ from __future__ import annotations
import re
from abc import abstractmethod
from collections import deque
from typing import AsyncIterator, Iterator, List, TypeVar, Union
from collections.abc import AsyncIterator, Iterator
from typing import Optional as Optional
from typing import TypeVar, Union
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@ -29,7 +30,7 @@ def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
yield buffer.popleft()
class ListOutputParser(BaseTransformOutputParser[List[str]]):
class ListOutputParser(BaseTransformOutputParser[list[str]]):
"""Parse the output of an LLM call to a list."""
@property

View File

@ -1,6 +1,7 @@
import copy
import json
from typing import Any, Dict, List, Optional, Type, Union
from types import GenericAlias
from typing import Any, Optional, Union
import jsonpatch # type: ignore[import]
from pydantic import BaseModel, model_validator
@ -20,7 +21,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
args_only: bool = True
"""Whether to only return the arguments to the function call."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@ -72,7 +73,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@ -166,7 +167,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str
"""The name of the key to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@ -223,7 +224,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
result = parser.parse_result([chat_generation])
"""
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
pydantic_schema: Union[type[BaseModel], dict[str, type[BaseModel]]]
"""The pydantic schema to parse the output with.
If multiple schemas are provided, then the function name will be used to
@ -232,7 +233,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
@model_validator(mode="before")
@classmethod
def validate_schema(cls, values: Dict) -> Any:
def validate_schema(cls, values: dict) -> Any:
"""Validate the pydantic schema.
Args:
@ -246,17 +247,19 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
"""
schema = values["pydantic_schema"]
if "args_only" not in values:
values["args_only"] = isinstance(schema, type) and issubclass(
schema, BaseModel
values["args_only"] = (
isinstance(schema, type)
and not isinstance(schema, GenericAlias)
and issubclass(schema, BaseModel)
)
elif values["args_only"] and isinstance(schema, Dict):
elif values["args_only"] and isinstance(schema, dict):
raise ValueError(
"If multiple pydantic schemas are provided then args_only should be"
" False."
)
return values
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
@ -292,7 +295,7 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
attr_name: str
"""The name of the attribute to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:

View File

@ -1,10 +1,9 @@
import copy
import json
from json import JSONDecodeError
from typing import Any, Dict, List, Optional
from typing import Annotated, Any, Optional
from pydantic import SkipValidation, ValidationError
from typing_extensions import Annotated
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
@ -17,12 +16,12 @@ from langchain_core.utils.pydantic import TypeBaseModel
def parse_tool_call(
raw_tool_call: Dict[str, Any],
raw_tool_call: dict[str, Any],
*,
partial: bool = False,
strict: bool = False,
return_id: bool = True,
) -> Optional[Dict[str, Any]]:
) -> Optional[dict[str, Any]]:
"""Parse a single tool call.
Args:
@ -69,7 +68,7 @@ def parse_tool_call(
def make_invalid_tool_call(
raw_tool_call: Dict[str, Any],
raw_tool_call: dict[str, Any],
error_msg: Optional[str],
) -> InvalidToolCall:
"""Create an InvalidToolCall from a raw tool call.
@ -90,12 +89,12 @@ def make_invalid_tool_call(
def parse_tool_calls(
raw_tool_calls: List[dict],
raw_tool_calls: list[dict],
*,
partial: bool = False,
strict: bool = False,
return_id: bool = True,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Parse a list of tool calls.
Args:
@ -111,7 +110,7 @@ def parse_tool_calls(
Raises:
OutputParserException: If any of the tool calls are not valid JSON.
"""
final_tools: List[Dict[str, Any]] = []
final_tools: list[dict[str, Any]] = []
exceptions = []
for tool_call in raw_tool_calls:
try:
@ -151,7 +150,7 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
If no tool calls are found, None will be returned.
"""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a list of tool calls.
Args:
@ -217,7 +216,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
key_name: str
"""The type of tools to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a list of tool calls.
Args:
@ -254,12 +253,12 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response."""
tools: Annotated[List[TypeBaseModel], SkipValidation()]
tools: Annotated[list[TypeBaseModel], SkipValidation()]
"""The tools to parse."""
# TODO: Support more granular streaming of objects. Currently only streams once all
# Pydantic object fields are present.
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a list of Pydantic objects.
Args:

View File

@ -1,9 +1,8 @@
import json
from typing import Generic, List, Optional, Type
from typing import Annotated, Generic, Optional
import pydantic
from pydantic import SkipValidation
from typing_extensions import Annotated
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
@ -18,7 +17,7 @@ from langchain_core.utils.pydantic import (
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
"""Parse an output using a pydantic model."""
pydantic_object: Annotated[Type[TBaseModel], SkipValidation()] # type: ignore
pydantic_object: Annotated[type[TBaseModel], SkipValidation()] # type: ignore
"""The pydantic model to parse."""
def _parse_obj(self, obj: dict) -> TBaseModel:
@ -50,7 +49,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return OutputParserException(msg, llm_output=json_string)
def parse_result(
self, result: List[Generation], *, partial: bool = False
self, result: list[Generation], *, partial: bool = False
) -> Optional[TBaseModel]:
"""Parse the result of an LLM call to a pydantic object.
@ -108,7 +107,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return "pydantic"
@property
def OutputType(self) -> Type[TBaseModel]:
def OutputType(self) -> type[TBaseModel]:
"""Return the pydantic model."""
return self.pydantic_object

View File

@ -1,4 +1,3 @@
from typing import List
from typing import Optional as Optional
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@ -13,7 +12,7 @@ class StrOutputParser(BaseTransformOutputParser[str]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "output_parser"]

View File

@ -1,10 +1,9 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Iterator,
Optional,
Union,
)

View File

@ -1,7 +1,8 @@
import re
import xml
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union
from collections.abc import AsyncIterator, Iterator
from typing import Any, Literal, Optional, Union
from xml.etree.ElementTree import TreeBuilder
from langchain_core.exceptions import OutputParserException
@ -57,7 +58,7 @@ class _StreamingParser:
_parser = None
self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
self.xml_start_re = re.compile(r"<[a-zA-Z:_]")
self.current_path: List[str] = []
self.current_path: list[str] = []
self.current_path_has_children = False
self.buffer = ""
self.xml_started = False
@ -140,7 +141,7 @@ class _StreamingParser:
class XMLOutputParser(BaseTransformOutputParser):
"""Parse an output using xml format."""
tags: Optional[List[str]] = None
tags: Optional[list[str]] = None
encoding_matcher: re.Pattern = re.compile(
r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
)
@ -169,7 +170,7 @@ class XMLOutputParser(BaseTransformOutputParser):
"""Return the format instructions for the XML output."""
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]:
def parse(self, text: str) -> dict[str, Union[str, list[Any]]]:
"""Parse the output of an LLM call.
Args:
@ -234,13 +235,13 @@ class XMLOutputParser(BaseTransformOutputParser):
yield output
streaming_parser.close()
def _root_to_dict(self, root: ET.Element) -> Dict[str, Union[str, List[Any]]]:
def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]:
"""Converts xml tree to python dictionary."""
if root.text and bool(re.search(r"\S", root.text)):
# If root text contains any non-whitespace character it
# returns {root.tag: root.text}
return {root.tag: root.text}
result: Dict = {root.tag: []}
result: dict = {root.tag: []}
for child in root:
if len(child) == 0:
result[root.tag].append({child.tag: child.text})
@ -253,7 +254,7 @@ class XMLOutputParser(BaseTransformOutputParser):
return "xml"
def nested_element(path: List[str], elem: ET.Element) -> Any:
def nested_element(path: list[str], elem: ET.Element) -> Any:
"""Get nested element from path.
Args:

View File

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional
from pydantic import BaseModel
@ -18,7 +18,7 @@ class ChatResult(BaseModel):
for more information.
"""
generations: List[ChatGeneration]
generations: list[ChatGeneration]
"""List of the chat generations.
Generations is a list to allow for multiple candidate generations for a single

View File

@ -7,7 +7,8 @@ They can be used to represent text, images, or chat message pieces.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Literal, Sequence, cast
from collections.abc import Sequence
from typing import Literal, cast
from typing_extensions import TypedDict

View File

@ -1,16 +1,16 @@
from __future__ import annotations
import json
import typing
from abc import ABC, abstractmethod
from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Mapping,
Optional,
TypeVar,
Union,
@ -39,7 +39,7 @@ FormatOutputType = TypeVar("FormatOutputType")
class BasePromptTemplate(
RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC
RunnableSerializable[dict, PromptValue], Generic[FormatOutputType], ABC
):
"""Base class for all prompt templates, returning a prompt."""
@ -50,7 +50,7 @@ class BasePromptTemplate(
"""optional_variables: A list of the names of the variables for placeholder
or MessagePlaceholder that are optional. These variables are auto inferred
from the prompt and user need not provide them."""
input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006
input_types: typing.Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006
"""A dictionary of the types of the variables the prompt template expects.
If not provided, all variables are assumed to be strings."""
output_parser: Optional[BaseOutputParser] = None
@ -60,7 +60,7 @@ class BasePromptTemplate(
Partial variables populate the template so that you don't need to
pass them in every time you call the prompt."""
metadata: Optional[Dict[str, Any]] = None # noqa: UP006
metadata: Optional[typing.Dict[str, Any]] = None # noqa: UP006
"""Metadata to be used for tracing."""
tags: Optional[list[str]] = None
"""Tags to be used for tracing."""

View File

@ -3,15 +3,13 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from pathlib import Path
from typing import (
Annotated,
Any,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
@ -25,7 +23,6 @@ from pydantic import (
SkipValidation,
model_validator,
)
from typing_extensions import Annotated
from langchain_core._api import deprecated
from langchain_core.load import Serializable
@ -816,9 +813,9 @@ MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTempla
MessageLikeRepresentation = Union[
MessageLike,
Tuple[
Union[str, Type],
Union[str, List[dict], List[object]],
tuple[
Union[str, type],
Union[str, list[dict], list[object]],
],
str,
]
@ -1017,7 +1014,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
),
**kwargs,
}
cast(Type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)
cast(type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)
@classmethod
def get_lc_namespace(cls) -> list[str]:
@ -1083,7 +1080,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
values["partial_variables"][message.variable_name] = []
optional_variables.add(message.variable_name)
if message.variable_name not in input_types:
input_types[message.variable_name] = List[AnyMessage]
input_types[message.variable_name] = list[AnyMessage]
if "partial_variables" in values:
input_vars = input_vars - set(values["partial_variables"])
if optional_variables:

View File

@ -1,7 +1,7 @@
"""Prompt template that contains few shot examples."""
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from pydantic import ConfigDict, model_validator
from typing_extensions import Self
@ -16,7 +16,7 @@ from langchain_core.prompts.string import (
class FewShotPromptWithTemplates(StringPromptTemplate):
"""Prompt template that contains few shot examples."""
examples: Optional[List[dict]] = None
examples: Optional[list[dict]] = None
"""Examples to format into the prompt.
Either this or example_selector should be provided."""
@ -43,13 +43,13 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
"""Whether or not to try validating the template."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "few_shot_with_templates"]
@model_validator(mode="before")
@classmethod
def check_examples_and_selector(cls, values: Dict) -> Any:
def check_examples_and_selector(cls, values: dict) -> Any:
"""Check that one and only one of examples/example_selector are provided."""
examples = values.get("examples", None)
example_selector = values.get("example_selector", None)
@ -93,7 +93,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
extra="forbid",
)
def _get_examples(self, **kwargs: Any) -> List[dict]:
def _get_examples(self, **kwargs: Any) -> list[dict]:
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
@ -101,7 +101,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
else:
raise ValueError
async def _aget_examples(self, **kwargs: Any) -> List[dict]:
async def _aget_examples(self, **kwargs: Any) -> list[dict]:
if self.examples is not None:
return self.examples
elif self.example_selector is not None:

View File

@ -1,4 +1,4 @@
from typing import Any, List
from typing import Any
from pydantic import Field
@ -33,7 +33,7 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
return "image-prompt"
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "image"]

View File

@ -3,7 +3,7 @@
import json
import logging
from pathlib import Path
from typing import Callable, Dict, Optional, Union
from typing import Callable, Optional, Union
import yaml
@ -181,7 +181,7 @@ def _load_prompt_from_file(
return load_prompt_from_config(config)
def _load_chat_prompt(config: Dict) -> ChatPromptTemplate:
def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
"""Load chat prompt from config"""
messages = config.pop("messages")
@ -194,7 +194,7 @@ def _load_chat_prompt(config: Dict) -> ChatPromptTemplate:
return ChatPromptTemplate.from_template(template=template, **config)
type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = {
type_to_loader_dict: dict[str, Callable[[dict], BasePromptTemplate]] = {
"prompt": _load_prompt,
"few_shot": _load_few_shot_prompt,
"chat": _load_chat_prompt,

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple
from typing import Any
from typing import Optional as Optional
from pydantic import model_validator
@ -8,7 +8,7 @@ from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.chat import BaseChatPromptTemplate
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
def _get_inputs(inputs: dict, input_variables: list[str]) -> dict:
return {k: inputs[k] for k in input_variables}
@ -28,17 +28,17 @@ class PipelinePromptTemplate(BasePromptTemplate):
final_prompt: BasePromptTemplate
"""The final prompt that is returned."""
pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
pipeline_prompts: list[tuple[str, BasePromptTemplate]]
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "pipeline"]
@model_validator(mode="before")
@classmethod
def get_input_variables(cls, values: Dict) -> Any:
def get_input_variables(cls, values: dict) -> Any:
"""Get input variables."""
created_variables = set()
all_variables = set()

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict
from typing import Any, Callable
from pydantic import BaseModel, create_model
@ -139,7 +139,7 @@ def mustache_template_vars(
return vars
Defs = Dict[str, "Defs"]
Defs = dict[str, "Defs"]
def mustache_schema(

View File

@ -1,13 +1,8 @@
from collections.abc import Iterator, Mapping, Sequence
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
)
@ -32,16 +27,16 @@ from langchain_core.utils import get_pydantic_field_names
class StructuredPrompt(ChatPromptTemplate):
"""Structured prompt template for a language model."""
schema_: Union[Dict, Type[BaseModel]]
schema_: Union[dict, type[BaseModel]]
"""Schema for the structured prompt."""
structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict)
structured_output_kwargs: dict[str, Any] = Field(default_factory=dict)
def __init__(
self,
messages: Sequence[MessageLikeRepresentation],
schema_: Optional[Union[Dict, Type[BaseModel]]] = None,
schema_: Optional[Union[dict, type[BaseModel]]] = None,
*,
structured_output_kwargs: Optional[Dict[str, Any]] = None,
structured_output_kwargs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
schema_ = schema_ or kwargs.pop("schema")
@ -56,7 +51,7 @@ class StructuredPrompt(ChatPromptTemplate):
)
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
For example, if the class is `langchain.llms.openai.OpenAI`, then the
@ -68,7 +63,7 @@ class StructuredPrompt(ChatPromptTemplate):
def from_messages_and_schema(
cls,
messages: Sequence[MessageLikeRepresentation],
schema: Union[Dict, Type[BaseModel]],
schema: Union[dict, type[BaseModel]],
**kwargs: Any,
) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats.
@ -118,7 +113,7 @@ class StructuredPrompt(ChatPromptTemplate):
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSerializable[Dict, Other]:
) -> RunnableSerializable[dict, Other]:
return self.pipe(other)
def pipe(
@ -130,7 +125,7 @@ class StructuredPrompt(ChatPromptTemplate):
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
name: Optional[str] = None,
) -> RunnableSerializable[Dict, Other]:
) -> RunnableSerializable[dict, Other]:
"""Pipe the structured prompt to a language model.
Args:

View File

@ -24,7 +24,7 @@ from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, Optional
from pydantic import ConfigDict
from typing_extensions import TypedDict
@ -47,7 +47,7 @@ if TYPE_CHECKING:
)
RetrieverInput = str
RetrieverOutput = List[Document]
RetrieverOutput = list[Document]
RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
RetrieverOutputLike = Runnable[Any, RetrieverOutput]

View File

@ -6,36 +6,37 @@ import functools
import inspect
import threading
from abc import ABC, abstractmethod
from collections.abc import (
AsyncGenerator,
AsyncIterator,
Awaitable,
Coroutine,
Iterator,
Mapping,
Sequence,
)
from concurrent.futures import FIRST_COMPLETED, wait
from contextvars import copy_context
from functools import wraps
from itertools import groupby, tee
from operator import itemgetter
from types import GenericAlias
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
Generic,
Iterator,
List,
Mapping,
Optional,
Protocol,
Sequence,
Type,
TypeVar,
Union,
cast,
get_type_hints,
overload,
)
from pydantic import BaseModel, ConfigDict, Field, RootModel
from typing_extensions import Literal, get_args, get_type_hints
from typing_extensions import Literal, get_args
from langchain_core._api import beta_decorator
from langchain_core.load.serializable import (
@ -340,7 +341,11 @@ class Runnable(Generic[Input, Output], ABC):
"""
root_type = self.InputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type
return create_model_v2(
@ -408,7 +413,11 @@ class Runnable(Generic[Input, Output], ABC):
"""
root_type = self.OutputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type
return create_model_v2(
@ -771,10 +780,10 @@ class Runnable(Generic[Input, Output], ABC):
# If there's only one input, don't bother with the executor
if len(inputs) == 1:
return cast(List[Output], [invoke(inputs[0], configs[0])])
return cast(list[Output], [invoke(inputs[0], configs[0])])
with get_executor_for_config(configs[0]) as executor:
return cast(List[Output], list(executor.map(invoke, inputs, configs)))
return cast(list[Output], list(executor.map(invoke, inputs, configs)))
@overload
def batch_as_completed(
@ -2024,7 +2033,7 @@ class Runnable(Generic[Input, Output], ABC):
for run_manager in run_managers:
run_manager.on_chain_error(e)
if return_exceptions:
return cast(List[Output], [e for _ in input])
return cast(list[Output], [e for _ in input])
else:
raise
else:
@ -2036,7 +2045,7 @@ class Runnable(Generic[Input, Output], ABC):
else:
run_manager.on_chain_end(out)
if return_exceptions or first_exception is None:
return cast(List[Output], output)
return cast(list[Output], output)
else:
raise first_exception
@ -2099,7 +2108,7 @@ class Runnable(Generic[Input, Output], ABC):
*(run_manager.on_chain_error(e) for run_manager in run_managers)
)
if return_exceptions:
return cast(List[Output], [e for _ in input])
return cast(list[Output], [e for _ in input])
else:
raise
else:
@ -2113,7 +2122,7 @@ class Runnable(Generic[Input, Output], ABC):
coros.append(run_manager.on_chain_end(out))
await asyncio.gather(*coros)
if return_exceptions or first_exception is None:
return cast(List[Output], output)
return cast(list[Output], output)
else:
raise first_exception
@ -3171,7 +3180,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
for rm in run_managers:
rm.on_chain_error(e)
if return_exceptions:
return cast(List[Output], [e for _ in inputs])
return cast(list[Output], [e for _ in inputs])
else:
raise
else:
@ -3183,7 +3192,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
else:
run_manager.on_chain_end(out)
if return_exceptions or first_exception is None:
return cast(List[Output], inputs)
return cast(list[Output], inputs)
else:
raise first_exception
@ -3298,7 +3307,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
except BaseException as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
if return_exceptions:
return cast(List[Output], [e for _ in inputs])
return cast(list[Output], [e for _ in inputs])
else:
raise
else:
@ -3312,7 +3321,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
coros.append(run_manager.on_chain_end(out))
await asyncio.gather(*coros)
if return_exceptions or first_exception is None:
return cast(List[Output], inputs)
return cast(list[Output], inputs)
else:
raise first_exception
@ -3420,7 +3429,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
yield chunk
class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
"""Runnable that runs a mapping of Runnables in parallel, and returns a mapping
of their outputs.
@ -4071,7 +4080,11 @@ class RunnableGenerator(Runnable[Input, Output]):
func = getattr(self, "_transform", None) or self._atransform
module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type
return create_model_v2(
@ -4106,7 +4119,11 @@ class RunnableGenerator(Runnable[Input, Output]):
func = getattr(self, "_transform", None) or self._atransform
module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type
return create_model_v2(
@ -4369,7 +4386,7 @@ class RunnableLambda(Runnable[Input, Output]):
module = getattr(func, "__module__", None)
return create_model_v2(
self.get_name("Input"),
root=List[Any],
root=list[Any],
# To create the schema, we need to provide the module
# where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately.
@ -4420,7 +4437,11 @@ class RunnableLambda(Runnable[Input, Output]):
func = getattr(self, "func", None) or self.afunc
module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type
return create_model_v2(
@ -4921,7 +4942,7 @@ class RunnableLambda(Runnable[Input, Output]):
yield chunk
class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
"""Runnable that delegates calls to another Runnable
with each element of the input sequence.
@ -4938,7 +4959,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
@property
def InputType(self) -> Any:
return List[self.bound.InputType] # type: ignore[name-defined]
return list[self.bound.InputType] # type: ignore[name-defined]
def get_input_schema(
self, config: Optional[RunnableConfig] = None
@ -4946,7 +4967,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
return create_model_v2(
self.get_name("Input"),
root=(
List[self.bound.get_input_schema(config)], # type: ignore
list[self.bound.get_input_schema(config)], # type: ignore
None,
),
# create model needs access to appropriate type annotations to be
@ -4961,7 +4982,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
@property
def OutputType(self) -> type[list[Output]]:
return List[self.bound.OutputType] # type: ignore[name-defined]
return list[self.bound.OutputType] # type: ignore[name-defined]
def get_output_schema(
self, config: Optional[RunnableConfig] = None
@ -4969,7 +4990,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
schema = self.bound.get_output_schema(config)
return create_model_v2(
self.get_name("Output"),
root=List[schema], # type: ignore[valid-type]
root=list[schema], # type: ignore[valid-type]
# create model needs access to appropriate type annotations to be
# able to construct the pydantic model.
# When we create the model, we pass information about the namespace
@ -5255,7 +5276,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
@property
def InputType(self) -> type[Input]:
return (
cast(Type[Input], self.custom_input_type)
cast(type[Input], self.custom_input_type)
if self.custom_input_type is not None
else self.bound.InputType
)
@ -5263,7 +5284,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
@property
def OutputType(self) -> type[Output]:
return (
cast(Type[Output], self.custom_output_type)
cast(type[Output], self.custom_output_type)
if self.custom_output_type is not None
else self.bound.OutputType
)
@ -5336,7 +5357,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) -> list[Output]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
list[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:
@ -5358,7 +5379,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) -> list[Output]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
list[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:
@ -5400,7 +5421,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) -> Iterator[tuple[int, Union[Output, Exception]]]:
if isinstance(config, Sequence):
configs = cast(
List[RunnableConfig],
list[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:
@ -5451,7 +5472,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]:
if isinstance(config, Sequence):
configs = cast(
List[RunnableConfig],
list[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:

View File

@ -1,15 +1,8 @@
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@ -69,13 +62,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
branch.invoke(None) # "goodbye"
"""
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
branches: Sequence[tuple[Runnable[Input, bool], Runnable[Input, Output]]]
default: Runnable[Input, Output]
def __init__(
self,
*branches: Union[
Tuple[
tuple[
Union[
Runnable[Input, bool],
Callable[[Input], bool],
@ -149,13 +142,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
runnables = (
[self.default]
+ [r for _, r in self.branches]
@ -172,7 +165,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
return super().get_input_schema(config)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
from langchain_core.beta.runnables.context import (
CONTEXT_CONFIG_PREFIX,
CONTEXT_CONFIG_SUFFIX_SET,

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import uuid
import warnings
from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager
from contextvars import ContextVar, copy_context
@ -10,14 +11,8 @@ from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Generator,
Iterable,
Iterator,
List,
Optional,
Sequence,
TypeVar,
Union,
cast,
@ -43,7 +38,7 @@ if TYPE_CHECKING:
else:
# Pydantic validates through typed dicts, but
# the callbacks need forward refs updated
Callbacks = Optional[Union[List, Any]]
Callbacks = Optional[Union[list, Any]]
class EmptyDict(TypedDict, total=False):

View File

@ -3,20 +3,16 @@ from __future__ import annotations
import enum
import threading
from abc import abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from collections.abc import Mapping as Mapping
from functools import wraps
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
List,
Optional,
Sequence,
Type,
Union,
cast,
)
from typing import Mapping as Mapping
from weakref import WeakValueDictionary
from pydantic import BaseModel, ConfigDict
@ -176,10 +172,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
# If there's only one input, don't bother with the executor
if len(inputs) == 1:
return cast(List[Output], [invoke(prepared[0], inputs[0])])
return cast(list[Output], [invoke(prepared[0], inputs[0])])
with get_executor_for_config(configs[0]) as executor:
return cast(List[Output], list(executor.map(invoke, prepared, inputs)))
return cast(list[Output], list(executor.map(invoke, prepared, inputs)))
async def abatch(
self,
@ -562,7 +558,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
for v in list(self.alternatives.keys()) + [self.default_key]
),
)
_enums_for_spec[self.which] = cast(Type[StrEnum], which_enum)
_enums_for_spec[self.which] = cast(type[StrEnum], which_enum)
return get_unique_config_specs(
# which alternative
[
@ -694,7 +690,7 @@ def make_options_spec(
spec.name or spec.id,
((v, v) for v in list(spec.options.keys())),
)
_enums_for_spec[spec] = cast(Type[StrEnum], enum)
_enums_for_spec[spec] = cast(type[StrEnum], enum)
if isinstance(spec, ConfigurableFieldSingleOption):
return ConfigurableFieldSpec(
id=spec.id,

View File

@ -1,19 +1,13 @@
import asyncio
import inspect
import typing
from collections.abc import AsyncIterator, Iterator, Sequence
from contextvars import copy_context
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@ -96,7 +90,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
"""The Runnable to run first."""
fallbacks: Sequence[Runnable[Input, Output]]
"""A sequence of fallbacks to try."""
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
exceptions_to_handle: tuple[type[BaseException], ...] = (Exception,)
"""The exceptions on which fallbacks should be tried.
Any exception that is not a subclass of these exceptions will be raised immediately.
@ -112,25 +106,25 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
)
@property
def InputType(self) -> Type[Input]:
def InputType(self) -> type[Input]:
return self.runnable.InputType
@property
def OutputType(self) -> Type[Output]:
def OutputType(self) -> type[Output]:
return self.runnable.OutputType
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return self.runnable.get_input_schema(config)
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return self.runnable.get_output_schema(config)
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec
for step in [self.runnable, *self.fallbacks]
@ -142,7 +136,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@ -252,12 +246,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
from langchain_core.callbacks.manager import CallbackManager
if self.exception_key is not None and not all(
@ -296,9 +290,9 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
for cm, input, config in zip(callback_managers, inputs, configs)
]
to_return: Dict[int, Any] = {}
to_return: dict[int, Any] = {}
run_again = {i: input for i, input in enumerate(inputs)}
handled_exceptions: Dict[int, BaseException] = {}
handled_exceptions: dict[int, BaseException] = {}
first_to_raise = None
for runnable in self.runnables:
outputs = runnable.batch(
@ -344,12 +338,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
) -> list[Output]:
from langchain_core.callbacks.manager import AsyncCallbackManager
if self.exception_key is not None and not all(
@ -378,7 +372,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
for config in configs
]
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
None,
@ -392,7 +386,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
to_return = {}
run_again = {i: input for i, input in enumerate(inputs)}
handled_exceptions: Dict[int, BaseException] = {}
handled_exceptions: dict[int, BaseException] = {}
first_to_raise = None
for runnable in self.runnables:
outputs = await runnable.abatch(

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import inspect
from collections import Counter
from collections.abc import Sequence
from dataclasses import dataclass, field
from enum import Enum
from typing import (
@ -11,7 +12,6 @@ from typing import (
NamedTuple,
Optional,
Protocol,
Sequence,
TypedDict,
Union,
overload,

View File

@ -3,7 +3,8 @@ Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py"""
import math
import os
from typing import Any, Mapping, Sequence
from collections.abc import Mapping, Sequence
from typing import Any
from langchain_core.runnables.graph import Edge as LangEdge

View File

@ -1,7 +1,7 @@
import base64
import re
from dataclasses import asdict
from typing import Dict, List, Optional
from typing import Optional
from langchain_core.runnables.graph import (
CurveStyle,
@ -15,8 +15,8 @@ MARKDOWN_SPECIAL_CHARS = "*_`"
def draw_mermaid(
nodes: Dict[str, Node],
edges: List[Edge],
nodes: dict[str, Node],
edges: list[Edge],
*,
first_node: Optional[str] = None,
last_node: Optional[str] = None,
@ -87,7 +87,7 @@ def draw_mermaid(
mermaid_graph += f"\t{node_label}\n"
# Group edges by their common prefixes
edge_groups: Dict[str, List[Edge]] = {}
edge_groups: dict[str, list[Edge]] = {}
for edge in edges:
src_parts = edge.source.split(":")
tgt_parts = edge.target.split(":")
@ -98,7 +98,7 @@ def draw_mermaid(
seen_subgraphs = set()
def add_subgraph(edges: List[Edge], prefix: str) -> None:
def add_subgraph(edges: list[Edge], prefix: str) -> None:
nonlocal mermaid_graph
self_loop = len(edges) == 1 and edges[0].source == edges[0].target
if prefix and not self_loop:

View File

@ -1,13 +1,13 @@
from __future__ import annotations
import inspect
from collections.abc import Sequence
from types import GenericAlias
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Sequence,
Union,
)
@ -31,7 +31,7 @@ if TYPE_CHECKING:
from langchain_core.tracers.schemas import Run
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
@ -419,7 +419,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
"""
root_type = self.OutputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
if (
inspect.isclass(root_type)
and not isinstance(root_type, GenericAlias)
and issubclass(root_type, BaseModel)
):
return root_type
return create_model_v2(

View File

@ -5,15 +5,11 @@ from __future__ import annotations
import asyncio
import inspect
import threading
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterator,
Mapping,
Optional,
Union,
cast,
@ -349,7 +345,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
_graph_passthrough: RunnablePassthrough = RunnablePassthrough()
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
"""Runnable that assigns key-value pairs to Dict[str, Any] inputs.
The `RunnableAssign` class takes input dictionaries and, through a
@ -564,7 +560,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
if filtered:
yield filtered
# yield map output
yield cast(Dict[str, Any], first_map_chunk_future.result())
yield cast(dict[str, Any], first_map_chunk_future.result())
for chunk in map_output:
yield chunk
@ -650,7 +646,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
yield chunk
class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
"""Runnable that picks keys from Dict[str, Any] inputs.
RunnablePick class represents a Runnable that selectively picks keys from a

View File

@ -1,11 +1,7 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
@ -98,7 +94,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
retryable_chain = chain.with_retry()
"""
retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,)
retry_exception_types: tuple[type[BaseException], ...] = (Exception,)
"""The exception types to retry on. By default all exceptions are retried.
In general you should only retry on exceptions that are likely to be
@ -115,13 +111,13 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
"""The maximum number of attempts to retry the Runnable."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property
def _kwargs_retrying(self) -> Dict[str, Any]:
kwargs: Dict[str, Any] = dict()
def _kwargs_retrying(self) -> dict[str, Any]:
kwargs: dict[str, Any] = dict()
if self.max_attempt_number:
kwargs["stop"] = stop_after_attempt(self.max_attempt_number)
@ -152,10 +148,10 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def _patch_config_list(
self,
config: List[RunnableConfig],
run_manager: List["T"],
config: list[RunnableConfig],
run_manager: list["T"],
retry_state: RetryCallState,
) -> List[RunnableConfig]:
) -> list[RunnableConfig]:
return [
self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
]
@ -208,17 +204,17 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def _batch(
self,
inputs: List[Input],
run_manager: List["CallbackManagerForChainRun"],
config: List[RunnableConfig],
inputs: list[Input],
run_manager: list["CallbackManagerForChainRun"],
config: list[RunnableConfig],
**kwargs: Any,
) -> List[Union[Output, Exception]]:
results_map: Dict[int, Output] = {}
) -> list[Union[Output, Exception]]:
results_map: dict[int, Output] = {}
def pending(iterable: List[U]) -> List[U]:
def pending(iterable: list[U]) -> list[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map]
not_set: List[Output] = []
not_set: list[Output] = []
result = not_set
try:
for attempt in self._sync_retrying():
@ -250,9 +246,9 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
attempt.retry_state.set_result(result)
except RetryError as e:
if result is not_set:
result = cast(List[Output], [e] * len(inputs))
result = cast(list[Output], [e] * len(inputs))
outputs: List[Union[Output, Exception]] = []
outputs: list[Union[Output, Exception]] = []
for idx, _ in enumerate(inputs):
if idx in results_map:
outputs.append(results_map[idx])
@ -262,29 +258,29 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[Output]:
) -> list[Output]:
return self._batch_with_config(
self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
)
async def _abatch(
self,
inputs: List[Input],
run_manager: List["AsyncCallbackManagerForChainRun"],
config: List[RunnableConfig],
inputs: list[Input],
run_manager: list["AsyncCallbackManagerForChainRun"],
config: list[RunnableConfig],
**kwargs: Any,
) -> List[Union[Output, Exception]]:
results_map: Dict[int, Output] = {}
) -> list[Union[Output, Exception]]:
results_map: dict[int, Output] = {}
def pending(iterable: List[U]) -> List[U]:
def pending(iterable: list[U]) -> list[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map]
not_set: List[Output] = []
not_set: list[Output] = []
result = not_set
try:
async for attempt in self._async_retrying():
@ -316,9 +312,9 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
attempt.retry_state.set_result(result)
except RetryError as e:
if result is not_set:
result = cast(List[Output], [e] * len(inputs))
result = cast(list[Output], [e] * len(inputs))
outputs: List[Union[Output, Exception]] = []
outputs: list[Union[Output, Exception]] = []
for idx, _ in enumerate(inputs):
if idx in results_map:
outputs.append(results_map[idx])
@ -328,12 +324,12 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[Input],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[Output]:
) -> list[Output]:
return await self._abatch_with_config(
self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
)

View File

@ -1,12 +1,9 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
List,
Mapping,
Optional,
Union,
cast,
@ -154,7 +151,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
configs = get_config_list(config, len(inputs))
with get_executor_for_config(configs[0]) as executor:
return cast(
List[Output],
list[Output],
list(executor.map(invoke, runnables, actual_inputs, configs)),
)

View File

@ -2,7 +2,8 @@
from __future__ import annotations
from typing import Any, Literal, Sequence, Union
from collections.abc import Sequence
from typing import Any, Literal, Union
from typing_extensions import NotRequired, TypedDict

View File

@ -6,23 +6,24 @@ import ast
import asyncio
import inspect
import textwrap
from collections.abc import (
AsyncIterable,
AsyncIterator,
Awaitable,
Coroutine,
Iterable,
Mapping,
Sequence,
)
from functools import lru_cache
from inspect import signature
from itertools import groupby
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
Iterable,
Mapping,
NamedTuple,
Optional,
Protocol,
Sequence,
TypeVar,
Union,
)
@ -430,7 +431,7 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
class AddableDict(Dict[str, Any]):
class AddableDict(dict[str, Any]):
"""
Dictionary that can be added to another dictionary.
"""

View File

@ -7,16 +7,11 @@ The primary goal of these storages is to support implementation of caching.
"""
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from typing import (
Any,
AsyncIterator,
Dict,
Generic,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
@ -84,7 +79,7 @@ class BaseStore(Generic[K, V], ABC):
"""
@abstractmethod
def mget(self, keys: Sequence[K]) -> List[Optional[V]]:
def mget(self, keys: Sequence[K]) -> list[Optional[V]]:
"""Get the values associated with the given keys.
Args:
@ -95,7 +90,7 @@ class BaseStore(Generic[K, V], ABC):
If a key is not found, the corresponding value will be None.
"""
async def amget(self, keys: Sequence[K]) -> List[Optional[V]]:
async def amget(self, keys: Sequence[K]) -> list[Optional[V]]:
"""Async get the values associated with the given keys.
Args:
@ -108,14 +103,14 @@ class BaseStore(Generic[K, V], ABC):
return await run_in_executor(None, self.mget, keys)
@abstractmethod
def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
def mset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
"""Set the values for the given keys.
Args:
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
"""
async def amset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
async def amset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
"""Async set the values for the given keys.
Args:
@ -184,9 +179,9 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
def __init__(self) -> None:
"""Initialize an empty store."""
self.store: Dict[str, V] = {}
self.store: dict[str, V] = {}
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
def mget(self, keys: Sequence[str]) -> list[Optional[V]]:
"""Get the values associated with the given keys.
Args:
@ -198,7 +193,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
"""
return [self.store.get(key) for key in keys]
async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
async def amget(self, keys: Sequence[str]) -> list[Optional[V]]:
"""Async get the values associated with the given keys.
Args:
@ -210,7 +205,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
"""
return self.mget(keys)
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
def mset(self, key_value_pairs: Sequence[tuple[str, V]]) -> None:
"""Set the values for the given keys.
Args:
@ -222,7 +217,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
for key, value in key_value_pairs:
self.store[key] = value
async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
async def amset(self, key_value_pairs: Sequence[tuple[str, V]]) -> None:
"""Async set the values for the given keys.
Args:

View File

@ -3,8 +3,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from enum import Enum
from typing import Any, Optional, Sequence, Union
from typing import Any, Optional, Union
from pydantic import BaseModel

View File

@ -2,10 +2,10 @@
for debugging purposes.
"""
from typing import List, Sequence
from collections.abc import Sequence
def _get_sub_deps(packages: Sequence[str]) -> List[str]:
def _get_sub_deps(packages: Sequence[str]) -> list[str]:
"""Get any specified sub-dependencies."""
from importlib import metadata

View File

@ -7,15 +7,15 @@ import json
import uuid
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextvars import copy_context
from inspect import signature
from typing import (
Annotated,
Any,
Callable,
Dict,
Literal,
Optional,
Sequence,
TypeVar,
Union,
cast,
@ -36,7 +36,6 @@ from pydantic import (
)
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import validate_arguments as validate_arguments_v1
from typing_extensions import Annotated
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -324,7 +323,7 @@ class ToolException(Exception):
pass
class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
"""Interface LangChain tools must implement."""
def __init_subclass__(cls, **kwargs: Any) -> None:

View File

@ -1,5 +1,5 @@
import inspect
from typing import Any, Callable, Dict, Literal, Optional, Type, Union, get_type_hints
from typing import Any, Callable, Literal, Optional, Union, get_type_hints
from pydantic import BaseModel, Field, create_model
@ -13,7 +13,7 @@ from langchain_core.tools.structured import StructuredTool
def tool(
*args: Union[str, Callable, Runnable],
return_direct: bool = False,
args_schema: Optional[Type] = None,
args_schema: Optional[type] = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
@ -160,7 +160,7 @@ def tool(
coroutine = ainvoke_wrapper
func = invoke_wrapper
schema: Optional[Type[BaseModel]] = runnable.input_schema
schema: Optional[type[BaseModel]] = runnable.input_schema
description = repr(runnable)
elif inspect.iscoroutinefunction(dec_func):
coroutine = dec_func
@ -234,8 +234,8 @@ def _get_description_from_runnable(runnable: Runnable) -> str:
def _get_schema_from_runnable_and_arg_types(
runnable: Runnable,
name: str,
arg_types: Optional[Dict[str, Type]] = None,
) -> Type[BaseModel]:
arg_types: Optional[dict[str, type]] = None,
) -> type[BaseModel]:
"""Infer args_schema for tool."""
if arg_types is None:
try:
@ -252,11 +252,11 @@ def _get_schema_from_runnable_and_arg_types(
def convert_runnable_to_tool(
runnable: Runnable,
args_schema: Optional[Type[BaseModel]] = None,
args_schema: Optional[type[BaseModel]] = None,
*,
name: Optional[str] = None,
description: Optional[str] = None,
arg_types: Optional[Dict[str, Type]] = None,
arg_types: Optional[dict[str, type]] = None,
) -> BaseTool:
"""Convert a Runnable into a BaseTool.

View File

@ -1,11 +1,11 @@
from __future__ import annotations
from inspect import signature
from typing import Callable, List
from typing import Callable
from langchain_core.tools.base import BaseTool
ToolsRenderer = Callable[[List[BaseTool]], str]
ToolsRenderer = Callable[[list[BaseTool]], str]
def render_text_description(tools: list[BaseTool]) -> str:

View File

@ -1,9 +1,9 @@
from __future__ import annotations
from collections.abc import Awaitable
from inspect import signature
from typing import (
Any,
Awaitable,
Callable,
Optional,
Union,

View File

@ -1,10 +1,11 @@
from __future__ import annotations
import textwrap
from collections.abc import Awaitable
from inspect import signature
from typing import (
Annotated,
Any,
Awaitable,
Callable,
Literal,
Optional,
@ -12,7 +13,6 @@ from typing import (
)
from pydantic import BaseModel, Field, SkipValidation
from typing_extensions import Annotated
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,

View File

@ -1,7 +1,8 @@
"""Internal tracers used for stream_log and astream events implementations."""
import abc
from typing import AsyncIterator, Iterator, TypeVar
from collections.abc import AsyncIterator, Iterator
from typing import TypeVar
from uuid import UUID
T = TypeVar("T")

View File

@ -5,11 +5,11 @@ from __future__ import annotations
import asyncio
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import (
TYPE_CHECKING,
Any,
Optional,
Sequence,
Union,
)
from uuid import UUID

View File

@ -1,11 +1,11 @@
from __future__ import annotations
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Generator,
Optional,
Union,
cast,

View File

@ -6,14 +6,13 @@ import logging
import sys
import traceback
from abc import ABC, abstractmethod
from collections.abc import Coroutine, Sequence
from datetime import datetime, timezone
from typing import (
TYPE_CHECKING,
Any,
Coroutine,
Literal,
Optional,
Sequence,
Union,
cast,
)

View File

@ -5,8 +5,9 @@ from __future__ import annotations
import logging
import threading
import weakref
from collections.abc import Sequence
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, List, Optional, Sequence, Union, cast
from typing import Any, Optional, Union, cast
from uuid import UUID
import langsmith
@ -156,7 +157,7 @@ class EvaluatorCallbackHandler(BaseTracer):
if isinstance(results, EvaluationResult):
results_ = [results]
elif isinstance(results, dict) and "results" in results:
results_ = cast(List[EvaluationResult], results["results"])
results_ = cast(list[EvaluationResult], results["results"])
else:
raise TypeError(
f"Invalid evaluation result type {type(results)}."

View File

@ -4,14 +4,11 @@ from __future__ import annotations
import asyncio
import logging
from collections.abc import AsyncIterator, Iterator, Sequence
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Iterator,
List,
Optional,
Sequence,
TypeVar,
Union,
cast,
@ -459,7 +456,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
output: Union[dict, BaseMessage] = {}
if run_info["run_type"] == "chat_model":
generations = cast(List[List[ChatGenerationChunk]], response.generations)
generations = cast(list[list[ChatGenerationChunk]], response.generations)
for gen in generations:
if output != {}:
break
@ -469,7 +466,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
event = "on_chat_model_end"
elif run_info["run_type"] == "llm":
generations = cast(List[List[GenerationChunk]], response.generations)
generations = cast(list[list[GenerationChunk]], response.generations)
output = {
"generations": [
[

View File

@ -4,13 +4,11 @@ import asyncio
import copy
import threading
from collections import defaultdict
from collections.abc import AsyncIterator, Iterator, Sequence
from typing import (
Any,
AsyncIterator,
Iterator,
Literal,
Optional,
Sequence,
TypeVar,
Union,
overload,

View File

@ -11,7 +11,8 @@ used in the code.
import asyncio
from asyncio import AbstractEventLoop, Queue
from typing import AsyncIterator, Generic, TypeVar
from collections.abc import AsyncIterator
from typing import Generic, TypeVar
T = TypeVar("T")

View File

@ -1,4 +1,5 @@
from typing import Awaitable, Callable, Optional, Union
from collections.abc import Awaitable
from typing import Callable, Optional, Union
from uuid import UUID
from langchain_core.runnables.config import (

View File

@ -1,6 +1,6 @@
"""A tracer that collects all nested runs in a list."""
from typing import Any, List, Optional, Union
from typing import Any, Optional, Union
from uuid import UUID
from langchain_core.tracers.base import BaseTracer
@ -38,7 +38,7 @@ class RunCollectorCallbackHandler(BaseTracer):
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
self.traced_runs: List[Run] = []
self.traced_runs: list[Run] = []
def _persist_run(self, run: Run) -> None:
"""

View File

@ -1,5 +1,5 @@
import json
from typing import Any, Callable, List
from typing import Any, Callable
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
@ -54,7 +54,7 @@ class FunctionCallbackHandler(BaseTracer):
def _persist_run(self, run: Run) -> None:
pass
def get_parents(self, run: Run) -> List[Run]:
def get_parents(self, run: Run) -> list[Run]:
"""Get the parents of a run.
Args:

View File

@ -5,23 +5,20 @@ MIT License
"""
from collections import deque
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import (
Any,
AsyncContextManager,
from collections.abc import (
AsyncGenerator,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Deque,
Generic,
Iterator,
List,
)
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import (
Any,
Callable,
Generic,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
@ -95,10 +92,10 @@ class NoLock:
async def tee_peer(
iterator: AsyncIterator[T],
# the buffer specific to this peer
buffer: Deque[T],
buffer: deque[T],
# the buffers of all peers, including our own
peers: List[Deque[T]],
lock: AsyncContextManager[Any],
peers: list[deque[T]],
lock: AbstractAsyncContextManager[Any],
) -> AsyncGenerator[T, None]:
"""An individual iterator of a :py:func:`~.tee`.
@ -191,10 +188,10 @@ class Tee(Generic[T]):
iterable: AsyncIterator[T],
n: int = 2,
*,
lock: Optional[AsyncContextManager[Any]] = None,
lock: Optional[AbstractAsyncContextManager[Any]] = None,
):
self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
self._buffers: list[deque[T]] = [deque() for _ in range(n)]
self._children = tuple(
tee_peer(
iterator=self._iterator,
@ -212,11 +209,11 @@ class Tee(Generic[T]):
def __getitem__(self, item: int) -> AsyncIterator[T]: ...
@overload
def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]: ...
def __getitem__(self, item: slice) -> tuple[AsyncIterator[T], ...]: ...
def __getitem__(
self, item: Union[int, slice]
) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]:
) -> Union[AsyncIterator[T], tuple[AsyncIterator[T], ...]]:
return self._children[item]
def __iter__(self) -> Iterator[AsyncIterator[T]]:
@ -267,7 +264,7 @@ class aclosing(AbstractAsyncContextManager):
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
@ -277,7 +274,7 @@ class aclosing(AbstractAsyncContextManager):
async def abatch_iterate(
size: int, iterable: AsyncIterable[T]
) -> AsyncIterator[List[T]]:
) -> AsyncIterator[list[T]]:
"""Utility batching function for async iterables.
Args:
@ -287,7 +284,7 @@ async def abatch_iterate(
Returns:
An async iterator over the batches.
"""
batch: List[T] = []
batch: list[T] = []
async for element in iterable:
if len(batch) < size:
batch.append(element)

View File

@ -1,7 +1,8 @@
"""Utilities for formatting strings."""
from collections.abc import Mapping, Sequence
from string import Formatter
from typing import Any, List, Mapping, Sequence
from typing import Any
class StrictFormatter(Formatter):
@ -31,7 +32,7 @@ class StrictFormatter(Formatter):
return super().vformat(format_string, args, kwargs)
def validate_input_variables(
self, format_string: str, input_variables: List[str]
self, format_string: str, input_variables: list[str]
) -> None:
"""Check that all input variables are used in the format string.

View File

@ -10,21 +10,17 @@ import typing
import uuid
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Set,
Tuple,
Type,
Union,
cast,
)
from pydantic import BaseModel
from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
from langchain_core._api import deprecated
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
@ -201,7 +197,7 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript
from pydantic.v1 import BaseModel
model = cast(
Type[BaseModel],
type[BaseModel],
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
)
return convert_pydantic_to_openai_function(model) # type: ignore
@ -383,15 +379,15 @@ def convert_to_openai_function(
"parameters": function,
}
elif isinstance(function, type) and is_basemodel_subclass(function):
oai_function = cast(Dict, convert_pydantic_to_openai_function(function))
oai_function = cast(dict, convert_pydantic_to_openai_function(function))
elif is_typeddict(function):
oai_function = cast(
Dict, _convert_typed_dict_to_openai_function(cast(Type, function))
dict, _convert_typed_dict_to_openai_function(cast(type, function))
)
elif isinstance(function, BaseTool):
oai_function = cast(Dict, format_tool_to_openai_function(function))
oai_function = cast(dict, format_tool_to_openai_function(function))
elif callable(function):
oai_function = cast(Dict, convert_python_function_to_openai_function(function))
oai_function = cast(dict, convert_python_function_to_openai_function(function))
else:
raise ValueError(
f"Unsupported function\n\n{function}\n\nFunctions must be passed in"
@ -598,17 +594,17 @@ def _py_38_safe_origin(origin: type) -> type:
)
origin_map: dict[type, Any] = {
dict: Dict,
list: List,
tuple: Tuple,
set: Set,
dict: dict,
list: list,
tuple: tuple,
set: set,
collections.abc.Iterable: typing.Iterable,
collections.abc.Mapping: typing.Mapping,
collections.abc.Sequence: typing.Sequence,
collections.abc.MutableMapping: typing.MutableMapping,
**origin_union_type_map,
}
return cast(Type, origin_map.get(origin, origin))
return cast(type, origin_map.get(origin, origin))
def _recursive_set_additional_properties_false(

View File

@ -1,6 +1,7 @@
import logging
import re
from typing import List, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Optional, Union
from urllib.parse import urljoin, urlparse
logger = logging.getLogger(__name__)
@ -33,7 +34,7 @@ DEFAULT_LINK_REGEX = (
def find_all_links(
raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
) -> List[str]:
) -> list[str]:
"""Extract all links from a raw HTML string.
Args:
@ -56,7 +57,7 @@ def extract_sub_links(
prevent_outside: bool = True,
exclude_prefixes: Sequence[str] = (),
continue_on_failure: bool = False,
) -> List[str]:
) -> list[str]:
"""Extract all links from a raw HTML string and convert into absolute paths.
Args:

View File

@ -1,6 +1,6 @@
"""Handle chained inputs."""
from typing import Dict, List, Optional, TextIO
from typing import Optional, TextIO
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
@ -12,8 +12,8 @@ _TEXT_COLOR_MAPPING = {
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
items: list[str], excluded_colors: Optional[list] = None
) -> dict[str, str]:
"""Get mapping for items to a support color.
Args:

View File

@ -1,16 +1,11 @@
from collections import deque
from collections.abc import Generator, Iterable, Iterator
from contextlib import AbstractContextManager
from itertools import islice
from typing import (
Any,
ContextManager,
Deque,
Generator,
Generic,
Iterable,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
overload,
@ -34,10 +29,10 @@ class NoLock:
def tee_peer(
iterator: Iterator[T],
# the buffer specific to this peer
buffer: Deque[T],
buffer: deque[T],
# the buffers of all peers, including our own
peers: List[Deque[T]],
lock: ContextManager[Any],
peers: list[deque[T]],
lock: AbstractContextManager[Any],
) -> Generator[T, None, None]:
"""An individual iterator of a :py:func:`~.tee`.
@ -130,7 +125,7 @@ class Tee(Generic[T]):
iterable: Iterator[T],
n: int = 2,
*,
lock: Optional[ContextManager[Any]] = None,
lock: Optional[AbstractContextManager[Any]] = None,
):
"""Create a new ``tee``.
@ -141,7 +136,7 @@ class Tee(Generic[T]):
Defaults to None.
"""
self._iterator = iter(iterable)
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
self._buffers: list[deque[T]] = [deque() for _ in range(n)]
self._children = tuple(
tee_peer(
iterator=self._iterator,
@ -159,11 +154,11 @@ class Tee(Generic[T]):
def __getitem__(self, item: int) -> Iterator[T]: ...
@overload
def __getitem__(self, item: slice) -> Tuple[Iterator[T], ...]: ...
def __getitem__(self, item: slice) -> tuple[Iterator[T], ...]: ...
def __getitem__(
self, item: Union[int, slice]
) -> Union[Iterator[T], Tuple[Iterator[T], ...]]:
) -> Union[Iterator[T], tuple[Iterator[T], ...]]:
return self._children[item]
def __iter__(self) -> Iterator[Iterator[T]]:
@ -185,7 +180,7 @@ class Tee(Generic[T]):
safetee = Tee
def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]:
def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[list[T]]:
"""Utility batching function.
Args:

Some files were not shown because too many files have changed in this diff Show More