mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 12:31:49 +00:00
core: Put Python version as a project requirement so it is considered by ruff (#26608)
Ruff doesn't know about the python version in `[tool.poetry.dependencies]`. It can get it from `project.requires-python`. Notes: * poetry seems to have issues getting the python constraints from `requires-python` and using `python` in per dependency constraints. So I had to duplicate the info. I will open an issue on poetry. * `inspect.isclass()` doesn't work correctly with `GenericAlias` (`list[...]`, `dict[..., ...]`) on Python <3.11 so I added some `not isinstance(type, GenericAlias)` checks: Python 3.11 ```pycon >>> import inspect >>> inspect.isclass(list) True >>> inspect.isclass(list[str]) False ``` Python 3.9 ```pycon >>> import inspect >>> inspect.isclass(list) True >>> inspect.isclass(list[str]) True ``` Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
0f07cf61da
commit
a47b332841
@ -14,7 +14,8 @@ import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Callable, TypeVar, Union, cast
|
||||
|
||||
from langchain_core._api.internal import is_caller_internal
|
||||
|
||||
@ -26,7 +27,7 @@ class LangChainBetaWarning(DeprecationWarning):
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Union[Callable[..., Any], Type])
|
||||
T = TypeVar("T", bound=Union[Callable[..., Any], type])
|
||||
|
||||
|
||||
def beta(
|
||||
|
@ -14,11 +14,10 @@ import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from collections.abc import Generator
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generator,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@ -41,7 +40,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
|
||||
|
||||
|
||||
# Last Any should be FieldInfoV1 but this leads to circular imports
|
||||
T = TypeVar("T", bound=Union[Type, Callable[..., Any], Any])
|
||||
T = TypeVar("T", bound=Union[type, Callable[..., Any], Any])
|
||||
|
||||
|
||||
def _validate_deprecation_params(
|
||||
@ -262,7 +261,7 @@ def deprecated(
|
||||
if not _obj_type:
|
||||
_obj_type = "attribute"
|
||||
wrapped = None
|
||||
_name = _name or cast(Union[Type, Callable], obj.fget).__qualname__
|
||||
_name = _name or cast(Union[type, Callable], obj.fget).__qualname__
|
||||
old_doc = obj.__doc__
|
||||
|
||||
class _deprecated_property(property):
|
||||
@ -304,7 +303,7 @@ def deprecated(
|
||||
)
|
||||
|
||||
else:
|
||||
_name = _name or cast(Union[Type, Callable], obj).__qualname__
|
||||
_name = _name or cast(Union[type, Callable], obj).__qualname__
|
||||
if not _obj_type:
|
||||
# edge case: when a function is within another function
|
||||
# within a test, this will call it a "method" not a "function"
|
||||
|
@ -25,7 +25,8 @@ The schemas for the agents themselves are defined in langchain.agents.agent.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Literal, Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages import (
|
||||
|
@ -1,19 +1,13 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from itertools import groupby
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -30,7 +24,7 @@ from langchain_core.runnables.config import RunnableConfig, ensure_config, patch
|
||||
from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
|
||||
|
||||
T = TypeVar("T")
|
||||
Values = Dict[Union[asyncio.Event, threading.Event], Any]
|
||||
Values = dict[Union[asyncio.Event, threading.Event], Any]
|
||||
CONTEXT_CONFIG_PREFIX = "__context__/"
|
||||
CONTEXT_CONFIG_SUFFIX_GET = "/get"
|
||||
CONTEXT_CONFIG_SUFFIX_SET = "/set"
|
||||
@ -70,10 +64,10 @@ def _key_from_id(id_: str) -> str:
|
||||
|
||||
def _config_with_context(
|
||||
config: RunnableConfig,
|
||||
steps: List[Runnable],
|
||||
steps: list[Runnable],
|
||||
setter: Callable,
|
||||
getter: Callable,
|
||||
event_cls: Union[Type[threading.Event], Type[asyncio.Event]],
|
||||
event_cls: Union[type[threading.Event], type[asyncio.Event]],
|
||||
) -> RunnableConfig:
|
||||
if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
|
||||
return config
|
||||
@ -99,10 +93,10 @@ def _config_with_context(
|
||||
}
|
||||
|
||||
values: Values = {}
|
||||
events: DefaultDict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
|
||||
events: defaultdict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
|
||||
event_cls
|
||||
)
|
||||
context_funcs: Dict[str, Callable[[], Any]] = {}
|
||||
context_funcs: dict[str, Callable[[], Any]] = {}
|
||||
for key, group in grouped_by_key.items():
|
||||
getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
|
||||
setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]
|
||||
@ -129,7 +123,7 @@ def _config_with_context(
|
||||
|
||||
def aconfig_with_context(
|
||||
config: RunnableConfig,
|
||||
steps: List[Runnable],
|
||||
steps: list[Runnable],
|
||||
) -> RunnableConfig:
|
||||
"""Asynchronously patch a runnable config with context getters and setters.
|
||||
|
||||
@ -145,7 +139,7 @@ def aconfig_with_context(
|
||||
|
||||
def config_with_context(
|
||||
config: RunnableConfig,
|
||||
steps: List[Runnable],
|
||||
steps: list[Runnable],
|
||||
) -> RunnableConfig:
|
||||
"""Patch a runnable config with context getters and setters.
|
||||
|
||||
@ -165,13 +159,13 @@ class ContextGet(RunnableSerializable):
|
||||
|
||||
prefix: str = ""
|
||||
|
||||
key: Union[str, List[str]]
|
||||
key: Union[str, list[str]]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ContextGet({_print_keys(self.key)})"
|
||||
|
||||
@property
|
||||
def ids(self) -> List[str]:
|
||||
def ids(self) -> list[str]:
|
||||
prefix = self.prefix + "/" if self.prefix else ""
|
||||
keys = self.key if isinstance(self.key, list) else [self.key]
|
||||
return [
|
||||
@ -180,7 +174,7 @@ class ContextGet(RunnableSerializable):
|
||||
]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
return super().config_specs + [
|
||||
ConfigurableFieldSpec(
|
||||
id=id_,
|
||||
@ -256,7 +250,7 @@ class ContextSet(RunnableSerializable):
|
||||
return f"ContextSet({_print_keys(list(self.keys.keys()))})"
|
||||
|
||||
@property
|
||||
def ids(self) -> List[str]:
|
||||
def ids(self) -> list[str]:
|
||||
prefix = self.prefix + "/" if self.prefix else ""
|
||||
return [
|
||||
f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
|
||||
@ -264,7 +258,7 @@ class ContextSet(RunnableSerializable):
|
||||
]
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
mapper_config_specs = [
|
||||
s
|
||||
for mapper in self.keys.values()
|
||||
@ -364,7 +358,7 @@ class Context:
|
||||
return PrefixContext(prefix=scope)
|
||||
|
||||
@staticmethod
|
||||
def getter(key: Union[str, List[str]], /) -> ContextGet:
|
||||
def getter(key: Union[str, list[str]], /) -> ContextGet:
|
||||
return ContextGet(key=key)
|
||||
|
||||
@staticmethod
|
||||
@ -385,7 +379,7 @@ class PrefixContext:
|
||||
def __init__(self, prefix: str = ""):
|
||||
self.prefix = prefix
|
||||
|
||||
def getter(self, key: Union[str, List[str]], /) -> ContextGet:
|
||||
def getter(self, key: Union[str, list[str]], /) -> ContextGet:
|
||||
return ContextGet(key=key, prefix=self.prefix)
|
||||
|
||||
def setter(
|
||||
|
@ -23,7 +23,8 @@ Cache directly competes with Memory. See documentation for Pros and Cons.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
@ -3,7 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
from tenacity import RetryCallState
|
||||
@ -1070,4 +1071,4 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
self.inheritable_metadata.pop(key)
|
||||
|
||||
|
||||
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
||||
Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]
|
||||
|
@ -5,19 +5,15 @@ import functools
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Coroutine, Generator, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextvars import copy_context
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Generator,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@ -2352,7 +2348,7 @@ def _configure(
|
||||
and handler_class is not None
|
||||
)
|
||||
if var.get() is not None or create_one:
|
||||
var_handler = var.get() or cast(Type[BaseCallbackHandler], handler_class)()
|
||||
var_handler = var.get() or cast(type[BaseCallbackHandler], handler_class)()
|
||||
if handler_class is None:
|
||||
if not any(
|
||||
handler is var_handler # direct pointer comparison
|
||||
|
@ -18,7 +18,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterator, List
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langchain_core.chat_sessions import ChatSession
|
||||
|
||||
@ -15,7 +15,7 @@ class BaseChatLoader(ABC):
|
||||
An iterator of chat sessions.
|
||||
"""
|
||||
|
||||
def load(self) -> List[ChatSession]:
|
||||
def load(self) -> list[ChatSession]:
|
||||
"""Eagerly load the chat sessions into memory.
|
||||
|
||||
Returns:
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""**Chat Sessions** are a collection of messages and function calls."""
|
||||
|
||||
from typing import Sequence, TypedDict
|
||||
from collections.abc import Sequence
|
||||
from typing import TypedDict
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
@ -3,7 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
@ -8,7 +8,7 @@ In addition, content loading code should provide a lazy loading interface by def
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterable
|
||||
from collections.abc import Iterable
|
||||
|
||||
# Re-export Blob and PathLike for backwards compatibility
|
||||
from langchain_core.documents.base import Blob as Blob
|
||||
|
@ -1,7 +1,8 @@
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Callable, Iterator, Optional, Sequence, Union
|
||||
from collections.abc import Iterator, Sequence
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from langsmith import Client as LangSmithClient
|
||||
|
||||
|
@ -2,9 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
from io import BufferedReader, BytesIO
|
||||
from pathlib import PurePath
|
||||
from typing import Any, Generator, Literal, Optional, Union, cast
|
||||
from typing import Any, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""**Embeddings** interface."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
@ -35,7 +34,7 @@ class Embeddings(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed search docs.
|
||||
|
||||
Args:
|
||||
@ -46,7 +45,7 @@ class Embeddings(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text.
|
||||
|
||||
Args:
|
||||
@ -56,7 +55,7 @@ class Embeddings(ABC):
|
||||
Embedding.
|
||||
"""
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Asynchronous Embed search docs.
|
||||
|
||||
Args:
|
||||
@ -67,7 +66,7 @@ class Embeddings(ABC):
|
||||
"""
|
||||
return await run_in_executor(None, self.embed_documents, texts)
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Asynchronous Embed query text.
|
||||
|
||||
Args:
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
# Please do not add additional fake embedding model implementations here.
|
||||
import hashlib
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -51,15 +50,15 @@ class FakeEmbeddings(Embeddings, BaseModel):
|
||||
size: int
|
||||
"""The size of the embedding vector."""
|
||||
|
||||
def _get_embedding(self) -> List[float]:
|
||||
def _get_embedding(self) -> list[float]:
|
||||
import numpy as np # type: ignore[import-not-found, import-untyped]
|
||||
|
||||
return list(np.random.normal(size=self.size))
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return [self._get_embedding() for _ in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
return self._get_embedding()
|
||||
|
||||
|
||||
@ -106,7 +105,7 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
|
||||
size: int
|
||||
"""The size of the embedding vector."""
|
||||
|
||||
def _get_embedding(self, seed: int) -> List[float]:
|
||||
def _get_embedding(self, seed: int) -> list[float]:
|
||||
import numpy as np # type: ignore[import-not-found, import-untyped]
|
||||
|
||||
# set the seed for the random generator
|
||||
@ -117,8 +116,8 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
|
||||
"""Get a seed for the random generator, using the hash of the text."""
|
||||
return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return [self._get_embedding(seed=self._get_seed(_)) for _ in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
return self._get_embedding(seed=self._get_seed(text))
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Interface for selecting examples to include in prompts."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
@ -10,14 +10,14 @@ class BaseExampleSelector(ABC):
|
||||
"""Interface for selecting examples to include in prompts."""
|
||||
|
||||
@abstractmethod
|
||||
def add_example(self, example: Dict[str, str]) -> Any:
|
||||
def add_example(self, example: dict[str, str]) -> Any:
|
||||
"""Add new example to store.
|
||||
|
||||
Args:
|
||||
example: A dictionary with keys as input variables
|
||||
and values as their values."""
|
||||
|
||||
async def aadd_example(self, example: Dict[str, str]) -> Any:
|
||||
async def aadd_example(self, example: dict[str, str]) -> Any:
|
||||
"""Async add new example to store.
|
||||
|
||||
Args:
|
||||
@ -27,14 +27,14 @@ class BaseExampleSelector(ABC):
|
||||
return await run_in_executor(None, self.add_example, example)
|
||||
|
||||
@abstractmethod
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Select which examples to use based on the inputs.
|
||||
|
||||
Args:
|
||||
input_variables: A dictionary with keys as input variables
|
||||
and values as their values."""
|
||||
|
||||
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Async select which examples to use based on the inputs.
|
||||
|
||||
Args:
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Select examples based on length."""
|
||||
|
||||
import re
|
||||
from typing import Callable, Dict, List
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
@ -17,7 +17,7 @@ def _get_length_based(text: str) -> int:
|
||||
class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
"""Select examples based on length."""
|
||||
|
||||
examples: List[dict]
|
||||
examples: list[dict]
|
||||
"""A list of the examples that the prompt template expects."""
|
||||
|
||||
example_prompt: PromptTemplate
|
||||
@ -29,10 +29,10 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
max_length: int = 2048
|
||||
"""Max length for the prompt, beyond which examples are cut."""
|
||||
|
||||
example_text_lengths: List[int] = Field(default_factory=list) # :meta private:
|
||||
example_text_lengths: list[int] = Field(default_factory=list) # :meta private:
|
||||
"""Length of each example."""
|
||||
|
||||
def add_example(self, example: Dict[str, str]) -> None:
|
||||
def add_example(self, example: dict[str, str]) -> None:
|
||||
"""Add new example to list.
|
||||
|
||||
Args:
|
||||
@ -43,7 +43,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
string_example = self.example_prompt.format(**example)
|
||||
self.example_text_lengths.append(self.get_text_length(string_example))
|
||||
|
||||
async def aadd_example(self, example: Dict[str, str]) -> None:
|
||||
async def aadd_example(self, example: dict[str, str]) -> None:
|
||||
"""Async add new example to list.
|
||||
|
||||
Args:
|
||||
@ -62,7 +62,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
self.example_text_lengths = [self.get_text_length(eg) for eg in string_examples]
|
||||
return self
|
||||
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Select which examples to use based on the input lengths.
|
||||
|
||||
Args:
|
||||
@ -86,7 +86,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
|
||||
i += 1
|
||||
return examples
|
||||
|
||||
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
|
||||
"""Async select which examples to use based on the input lengths.
|
||||
|
||||
Args:
|
||||
|
@ -1,13 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterable, Collection, Iterable, Iterator
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
ClassVar,
|
||||
Collection,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Optional,
|
||||
)
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Literal, Union
|
||||
from typing import Literal, Union
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.documents import Document
|
||||
@ -41,7 +42,7 @@ METADATA_LINKS_KEY = "links"
|
||||
|
||||
|
||||
@beta()
|
||||
def get_links(doc: Document) -> List[Link]:
|
||||
def get_links(doc: Document) -> list[Link]:
|
||||
"""Get the links from a document.
|
||||
|
||||
Args:
|
||||
|
@ -5,17 +5,13 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator, Sequence
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
|
@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
import abc
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Sequence, TypedDict
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, TypedDict
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.documents import Document
|
||||
|
@ -1,5 +1,6 @@
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Sequence, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@ -22,7 +23,7 @@ class InMemoryDocumentIndex(DocumentIndex):
|
||||
.. versionadded:: 0.2.29
|
||||
"""
|
||||
|
||||
store: Dict[str, Document] = Field(default_factory=dict)
|
||||
store: dict[str, Document] = Field(default_factory=dict)
|
||||
top_k: int = 4
|
||||
|
||||
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
|
||||
@ -43,7 +44,7 @@ class InMemoryDocumentIndex(DocumentIndex):
|
||||
|
||||
return UpsertResponse(succeeded=ok_ids, failed=[])
|
||||
|
||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> DeleteResponse:
|
||||
def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
|
||||
"""Delete by ID."""
|
||||
if ids is None:
|
||||
raise ValueError("IDs must be provided for deletion")
|
||||
@ -59,7 +60,7 @@ class InMemoryDocumentIndex(DocumentIndex):
|
||||
succeeded=ok_ids, num_deleted=len(ok_ids), num_failed=0, failed=[]
|
||||
)
|
||||
|
||||
def get(self, ids: Sequence[str], /, **kwargs: Any) -> List[Document]:
|
||||
def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
|
||||
"""Get by ids."""
|
||||
found_documents = []
|
||||
|
||||
@ -71,7 +72,7 @@ class InMemoryDocumentIndex(DocumentIndex):
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
counts_by_doc = []
|
||||
|
||||
for document in self.store.values():
|
||||
|
@ -1,16 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import lru_cache
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import cache
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -52,7 +50,7 @@ class LangSmithParams(TypedDict, total=False):
|
||||
"""Stop words for generation."""
|
||||
|
||||
|
||||
@lru_cache(maxsize=None) # Cache the tokenizer
|
||||
@cache # Cache the tokenizer
|
||||
def get_tokenizer() -> Any:
|
||||
"""Get a GPT-2 tokenizer instance.
|
||||
|
||||
@ -158,7 +156,7 @@ class BaseLanguageModel(
|
||||
return Union[
|
||||
str,
|
||||
Union[StringPromptValue, ChatPromptValueConcrete],
|
||||
List[AnyMessage],
|
||||
list[AnyMessage],
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
|
@ -3,21 +3,19 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import typing
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from functools import cached_property
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -1121,18 +1119,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006
|
||||
tools: Sequence[Union[typing.Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[Dict, type], # noqa: UP006
|
||||
schema: Union[typing.Dict, type], # noqa: UP006
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: # noqa: UP006
|
||||
) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]: # noqa: UP006
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -14,7 +15,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
class FakeListLLM(LLM):
|
||||
"""Fake LLM for testing purposes."""
|
||||
|
||||
responses: List[str]
|
||||
responses: list[str]
|
||||
"""List of responses to return in order."""
|
||||
# This parameter should be removed from FakeListLLM since
|
||||
# it's only used by sub-classes.
|
||||
@ -37,7 +38,7 @@ class FakeListLLM(LLM):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@ -52,7 +53,7 @@ class FakeListLLM(LLM):
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@ -90,7 +91,7 @@ class FakeStreamingListLLM(FakeListLLM):
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[str]:
|
||||
result = self.invoke(input, config)
|
||||
@ -110,7 +111,7 @@ class FakeStreamingListLLM(FakeListLLM):
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[str]:
|
||||
result = await self.ainvoke(input, config)
|
||||
|
@ -3,7 +3,8 @@
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -17,7 +18,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
|
||||
class FakeMessagesListChatModel(BaseChatModel):
|
||||
"""Fake ChatModel for testing purposes."""
|
||||
|
||||
responses: List[BaseMessage]
|
||||
responses: list[BaseMessage]
|
||||
"""List of responses to **cycle** through in order."""
|
||||
sleep: Optional[float] = None
|
||||
"""Sleep time in seconds between responses."""
|
||||
@ -26,8 +27,8 @@ class FakeMessagesListChatModel(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -51,7 +52,7 @@ class FakeListChatModelError(Exception):
|
||||
class FakeListChatModel(SimpleChatModel):
|
||||
"""Fake ChatModel for testing purposes."""
|
||||
|
||||
responses: List[str]
|
||||
responses: list[str]
|
||||
"""List of responses to **cycle** through in order."""
|
||||
sleep: Optional[float] = None
|
||||
i: int = 0
|
||||
@ -65,8 +66,8 @@ class FakeListChatModel(SimpleChatModel):
|
||||
|
||||
def _call(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@ -80,8 +81,8 @@ class FakeListChatModel(SimpleChatModel):
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Union[List[str], None] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Union[list[str], None] = None,
|
||||
run_manager: Union[CallbackManagerForLLMRun, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
@ -103,8 +104,8 @@ class FakeListChatModel(SimpleChatModel):
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Union[List[str], None] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Union[list[str], None] = None,
|
||||
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
@ -124,7 +125,7 @@ class FakeListChatModel(SimpleChatModel):
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
return {"responses": self.responses}
|
||||
|
||||
|
||||
@ -133,8 +134,8 @@ class FakeChatModel(SimpleChatModel):
|
||||
|
||||
def _call(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
@ -142,8 +143,8 @@ class FakeChatModel(SimpleChatModel):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -157,7 +158,7 @@ class FakeChatModel(SimpleChatModel):
|
||||
return "fake-chat-model"
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
return {"key": "fake"}
|
||||
|
||||
|
||||
@ -186,8 +187,8 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
@ -202,8 +203,8 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
@ -231,7 +232,7 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
# Use a regular expression to split on whitespace with a capture group
|
||||
# so that we can preserve the whitespace in the output.
|
||||
assert isinstance(content, str)
|
||||
content_chunks = cast(List[str], re.split(r"(\s)", content))
|
||||
content_chunks = cast(list[str], re.split(r"(\s)", content))
|
||||
|
||||
for token in content_chunks:
|
||||
chunk = ChatGenerationChunk(
|
||||
@ -249,7 +250,7 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
for fkey, fvalue in value.items():
|
||||
if isinstance(fvalue, str):
|
||||
# Break function call by `,`
|
||||
fvalue_chunks = cast(List[str], re.split(r"(,)", fvalue))
|
||||
fvalue_chunks = cast(list[str], re.split(r"(,)", fvalue))
|
||||
for fvalue_chunk in fvalue_chunks:
|
||||
chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
@ -306,8 +307,8 @@ class ParrotFakeChatModel(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
|
@ -10,16 +10,12 @@ import logging
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -448,7 +444,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
return [g[0].text for g in llm_result.generations]
|
||||
except Exception as e:
|
||||
if return_exceptions:
|
||||
return cast(List[str], [e for _ in inputs])
|
||||
return cast(list[str], [e for _ in inputs])
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
@ -494,7 +490,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
return [g[0].text for g in llm_result.generations]
|
||||
except Exception as e:
|
||||
if return_exceptions:
|
||||
return cast(List[str], [e for _ in inputs])
|
||||
return cast(list[str], [e for _ in inputs])
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
@ -883,13 +879,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
assert run_name is None or (
|
||||
isinstance(run_name, list) and len(run_name) == len(prompts)
|
||||
)
|
||||
callbacks = cast(List[Callbacks], callbacks)
|
||||
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
|
||||
callbacks = cast(list[Callbacks], callbacks)
|
||||
tags_list = cast(list[Optional[list[str]]], tags or ([None] * len(prompts)))
|
||||
metadata_list = cast(
|
||||
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
|
||||
list[Optional[dict[str, Any]]], metadata or ([{}] * len(prompts))
|
||||
)
|
||||
run_name_list = run_name or cast(
|
||||
List[Optional[str]], ([None] * len(prompts))
|
||||
list[Optional[str]], ([None] * len(prompts))
|
||||
)
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
@ -910,9 +906,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
cast(Callbacks, callbacks),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
cast(List[str], tags),
|
||||
cast(list[str], tags),
|
||||
self.tags,
|
||||
cast(Dict[str, Any], metadata),
|
||||
cast(dict[str, Any], metadata),
|
||||
self.metadata,
|
||||
)
|
||||
] * len(prompts)
|
||||
@ -1116,13 +1112,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
assert run_name is None or (
|
||||
isinstance(run_name, list) and len(run_name) == len(prompts)
|
||||
)
|
||||
callbacks = cast(List[Callbacks], callbacks)
|
||||
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
|
||||
callbacks = cast(list[Callbacks], callbacks)
|
||||
tags_list = cast(list[Optional[list[str]]], tags or ([None] * len(prompts)))
|
||||
metadata_list = cast(
|
||||
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
|
||||
list[Optional[dict[str, Any]]], metadata or ([{}] * len(prompts))
|
||||
)
|
||||
run_name_list = run_name or cast(
|
||||
List[Optional[str]], ([None] * len(prompts))
|
||||
list[Optional[str]], ([None] * len(prompts))
|
||||
)
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
@ -1143,9 +1139,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
cast(Callbacks, callbacks),
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
cast(List[str], tags),
|
||||
cast(list[str], tags),
|
||||
self.tags,
|
||||
cast(Dict[str, Any], metadata),
|
||||
cast(dict[str, Any], metadata),
|
||||
self.metadata,
|
||||
)
|
||||
] * len(prompts)
|
||||
|
@ -1,7 +1,7 @@
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.load.mapping import (
|
||||
@ -34,11 +34,11 @@ class Reviver:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
secrets_map: Optional[Dict[str, str]] = None,
|
||||
valid_namespaces: Optional[List[str]] = None,
|
||||
secrets_map: Optional[dict[str, str]] = None,
|
||||
valid_namespaces: Optional[list[str]] = None,
|
||||
secrets_from_env: bool = True,
|
||||
additional_import_mappings: Optional[
|
||||
Dict[Tuple[str, ...], Tuple[str, ...]]
|
||||
dict[tuple[str, ...], tuple[str, ...]]
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Initialize the reviver.
|
||||
@ -73,7 +73,7 @@ class Reviver:
|
||||
else ALL_SERIALIZABLE_MAPPINGS
|
||||
)
|
||||
|
||||
def __call__(self, value: Dict[str, Any]) -> Any:
|
||||
def __call__(self, value: dict[str, Any]) -> Any:
|
||||
if (
|
||||
value.get("lc", None) == 1
|
||||
and value.get("type", None) == "secret"
|
||||
@ -154,10 +154,10 @@ class Reviver:
|
||||
def loads(
|
||||
text: str,
|
||||
*,
|
||||
secrets_map: Optional[Dict[str, str]] = None,
|
||||
valid_namespaces: Optional[List[str]] = None,
|
||||
secrets_map: Optional[dict[str, str]] = None,
|
||||
valid_namespaces: Optional[list[str]] = None,
|
||||
secrets_from_env: bool = True,
|
||||
additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None,
|
||||
additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None,
|
||||
) -> Any:
|
||||
"""Revive a LangChain class from a JSON string.
|
||||
Equivalent to `load(json.loads(text))`.
|
||||
@ -190,10 +190,10 @@ def loads(
|
||||
def load(
|
||||
obj: Any,
|
||||
*,
|
||||
secrets_map: Optional[Dict[str, str]] = None,
|
||||
valid_namespaces: Optional[List[str]] = None,
|
||||
secrets_map: Optional[dict[str, str]] = None,
|
||||
valid_namespaces: Optional[list[str]] = None,
|
||||
secrets_from_env: bool = True,
|
||||
additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None,
|
||||
additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None,
|
||||
) -> Any:
|
||||
"""Revive a LangChain class from a JSON object. Use this if you already
|
||||
have a parsed JSON object, eg. from `json.load` or `orjson.loads`.
|
||||
|
@ -18,11 +18,9 @@ The mapping allows us to deserialize an AIMessage created with an older
|
||||
version of LangChain where the code was in a different location.
|
||||
"""
|
||||
|
||||
from typing import Dict, Tuple
|
||||
|
||||
# First value is the value that it is serialized as
|
||||
# Second value is the path to load it from
|
||||
SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
|
||||
("langchain", "schema", "messages", "AIMessage"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
@ -535,7 +533,7 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
|
||||
# Needed for backwards compatibility for old versions of LangChain where things
|
||||
# Were in different place
|
||||
_OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
_OG_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
|
||||
("langchain", "schema", "AIMessage"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
@ -583,7 +581,7 @@ _OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
|
||||
# Needed for backwards compatibility for a few versions where we serialized
|
||||
# with langchain_core paths.
|
||||
OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
OLD_CORE_NAMESPACES_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
|
||||
("langchain_core", "messages", "ai", "AIMessage"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
@ -937,7 +935,7 @@ OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
),
|
||||
}
|
||||
|
||||
_JS_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
_JS_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
|
||||
("langchain_core", "messages", "AIMessage"): (
|
||||
"langchain_core",
|
||||
"messages",
|
||||
|
@ -1,8 +1,6 @@
|
||||
from abc import ABC
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
TypedDict,
|
||||
@ -25,9 +23,9 @@ class BaseSerialized(TypedDict):
|
||||
"""
|
||||
|
||||
lc: int
|
||||
id: List[str]
|
||||
id: list[str]
|
||||
name: NotRequired[str]
|
||||
graph: NotRequired[Dict[str, Any]]
|
||||
graph: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
class SerializedConstructor(BaseSerialized):
|
||||
@ -39,7 +37,7 @@ class SerializedConstructor(BaseSerialized):
|
||||
"""
|
||||
|
||||
type: Literal["constructor"]
|
||||
kwargs: Dict[str, Any]
|
||||
kwargs: dict[str, Any]
|
||||
|
||||
|
||||
class SerializedSecret(BaseSerialized):
|
||||
@ -125,7 +123,7 @@ class Serializable(BaseModel, ABC):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
|
||||
For example, if the class is `langchain.llms.openai.OpenAI`, then the
|
||||
@ -134,7 +132,7 @@ class Serializable(BaseModel, ABC):
|
||||
return cls.__module__.split(".")
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
"""A map of constructor argument names to secret ids.
|
||||
|
||||
For example,
|
||||
@ -143,7 +141,7 @@ class Serializable(BaseModel, ABC):
|
||||
return dict()
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict:
|
||||
def lc_attributes(self) -> dict:
|
||||
"""List of attribute names that should be included in the serialized kwargs.
|
||||
|
||||
These attributes must be accepted by the constructor.
|
||||
@ -152,7 +150,7 @@ class Serializable(BaseModel, ABC):
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def lc_id(cls) -> List[str]:
|
||||
def lc_id(cls) -> list[str]:
|
||||
"""A unique identifier for this class for serialization purposes.
|
||||
|
||||
The unique identifier is a list of strings that describes the path
|
||||
@ -315,8 +313,8 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
|
||||
|
||||
def _replace_secrets(
|
||||
root: Dict[Any, Any], secrets_map: Dict[str, str]
|
||||
) -> Dict[Any, Any]:
|
||||
root: dict[Any, Any], secrets_map: dict[str, str]
|
||||
) -> dict[Any, Any]:
|
||||
result = root.copy()
|
||||
for path, secret_id in secrets_map.items():
|
||||
[*parts, last] = path.split(".")
|
||||
@ -344,7 +342,7 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
|
||||
Returns:
|
||||
SerializedNotImplemented
|
||||
"""
|
||||
_id: List[str] = []
|
||||
_id: list[str] = []
|
||||
try:
|
||||
if hasattr(obj, "__name__"):
|
||||
_id = [*obj.__module__.split("."), obj.__name__]
|
||||
|
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self, TypedDict
|
||||
@ -69,9 +69,9 @@ class AIMessage(BaseMessage):
|
||||
At the moment, this is ignored by most models. Usage is discouraged.
|
||||
"""
|
||||
|
||||
tool_calls: List[ToolCall] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
"""If provided, tool calls associated with the message."""
|
||||
invalid_tool_calls: List[InvalidToolCall] = []
|
||||
invalid_tool_calls: list[InvalidToolCall] = []
|
||||
"""If provided, tool calls with parsing errors associated with the message."""
|
||||
usage_metadata: Optional[UsageMetadata] = None
|
||||
"""If provided, usage metadata for a message, such as token counts.
|
||||
@ -83,7 +83,7 @@ class AIMessage(BaseMessage):
|
||||
"""The type of the message (used for deserialization). Defaults to "ai"."""
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg.
|
||||
|
||||
@ -94,7 +94,7 @@ class AIMessage(BaseMessage):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
|
||||
Returns:
|
||||
@ -104,7 +104,7 @@ class AIMessage(BaseMessage):
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict:
|
||||
def lc_attributes(self) -> dict:
|
||||
"""Attrs to be serialized even if they are derived from other init args."""
|
||||
return {
|
||||
"tool_calls": self.tool_calls,
|
||||
@ -137,7 +137,7 @@ class AIMessage(BaseMessage):
|
||||
|
||||
# Ensure "type" is properly set on all tool call-like dicts.
|
||||
if tool_calls := values.get("tool_calls"):
|
||||
updated: List = []
|
||||
updated: list = []
|
||||
for tc in tool_calls:
|
||||
updated.append(
|
||||
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
|
||||
@ -178,7 +178,7 @@ class AIMessage(BaseMessage):
|
||||
base = super().pretty_repr(html=html)
|
||||
lines = []
|
||||
|
||||
def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> List[str]:
|
||||
def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> list[str]:
|
||||
lines = [
|
||||
f" {tc.get('name', 'Tool')} ({tc.get('id')})",
|
||||
f" Call ID: {tc.get('id')}",
|
||||
@ -218,11 +218,11 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
"""The type of the message (used for deserialization).
|
||||
Defaults to "AIMessageChunk"."""
|
||||
|
||||
tool_call_chunks: List[ToolCallChunk] = []
|
||||
tool_call_chunks: list[ToolCallChunk] = []
|
||||
"""If provided, tool call chunks associated with the message."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
|
||||
Returns:
|
||||
@ -232,7 +232,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict:
|
||||
def lc_attributes(self) -> dict:
|
||||
"""Attrs to be serialized even if they are derived from other init args."""
|
||||
return {
|
||||
"tool_calls": self.tool_calls,
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator
|
||||
|
||||
@ -143,7 +144,7 @@ def merge_content(
|
||||
merged = [merged] + content # type: ignore
|
||||
elif isinstance(content, list):
|
||||
# If both are lists
|
||||
merged = merge_lists(cast(List, merged), content) # type: ignore
|
||||
merged = merge_lists(cast(list, merged), content) # type: ignore
|
||||
# If the first content is a list, and the second content is a string
|
||||
else:
|
||||
# If the last element of the first content is a string
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
@ -18,7 +18,7 @@ class ChatMessage(BaseMessage):
|
||||
"""The type of the message (used during serialization). Defaults to "chat"."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "messages"].
|
||||
"""
|
||||
@ -39,7 +39,7 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
||||
Defaults to "ChatMessageChunk"."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "messages"].
|
||||
"""
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
@ -26,7 +26,7 @@ class FunctionMessage(BaseMessage):
|
||||
"""The type of the message (used for serialization). Defaults to "function"."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "messages"]."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
@ -46,7 +46,7 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
||||
Defaults to "FunctionMessageChunk"."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "messages"]."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Literal, Union
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
|
||||
@ -39,13 +39,13 @@ class HumanMessage(BaseMessage):
|
||||
"""The type of the message (used for serialization). Defaults to "human"."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "messages"]."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg.
|
||||
|
||||
@ -70,7 +70,7 @@ class HumanMessageChunk(HumanMessage, BaseMessageChunk):
|
||||
Defaults to "HumanMessageChunk"."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "messages"]."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.messages.base import BaseMessage
|
||||
|
||||
@ -25,7 +25,7 @@ class RemoveMessage(BaseMessage):
|
||||
return super().__init__("", id=id, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "messages"]."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Literal, Union
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
|
||||
@ -33,13 +33,13 @@ class SystemMessage(BaseMessage):
|
||||
"""The type of the message (used for serialization). Defaults to "system"."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "messages"]."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Pass in content as positional arg.
|
||||
|
||||
@ -64,7 +64,7 @@ class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
||||
Defaults to "SystemMessageChunk"."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "messages"]."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
@ -78,7 +78,7 @@ class ToolMessage(BaseMessage):
|
||||
"""Currently inherited from BaseMessage, but not used."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
Default is ["langchain", "schema", "messages"]."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
@ -123,7 +123,7 @@ class ToolMessage(BaseMessage):
|
||||
return values
|
||||
|
||||
def __init__(
|
||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
||||
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
@ -140,7 +140,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
@ -187,7 +187,7 @@ class ToolCall(TypedDict):
|
||||
|
||||
name: str
|
||||
"""The name of the tool to be called."""
|
||||
args: Dict[str, Any]
|
||||
args: dict[str, Any]
|
||||
"""The arguments to the tool call."""
|
||||
id: Optional[str]
|
||||
"""An identifier associated with the tool call.
|
||||
@ -198,7 +198,7 @@ class ToolCall(TypedDict):
|
||||
type: NotRequired[Literal["tool_call"]]
|
||||
|
||||
|
||||
def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall:
|
||||
def tool_call(*, name: str, args: dict[str, Any], id: Optional[str]) -> ToolCall:
|
||||
return ToolCall(name=name, args=args, id=id, type="tool_call")
|
||||
|
||||
|
||||
@ -276,8 +276,8 @@ def invalid_tool_call(
|
||||
|
||||
|
||||
def default_tool_parser(
|
||||
raw_tool_calls: List[dict],
|
||||
) -> Tuple[List[ToolCall], List[InvalidToolCall]]:
|
||||
raw_tool_calls: list[dict],
|
||||
) -> tuple[list[ToolCall], list[InvalidToolCall]]:
|
||||
"""Best-effort parsing of tools."""
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
@ -306,7 +306,7 @@ def default_tool_parser(
|
||||
return tool_calls, invalid_tool_calls
|
||||
|
||||
|
||||
def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]:
|
||||
def default_tool_chunk_parser(raw_tool_calls: list[dict]) -> list[ToolCallChunk]:
|
||||
"""Best-effort parsing of tool chunks."""
|
||||
tool_call_chunks = []
|
||||
for tool_call in raw_tool_calls:
|
||||
|
@ -11,25 +11,21 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import Iterable, Sequence
|
||||
from functools import partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import Discriminator, Field, Tag
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core.messages.ai import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||
@ -198,7 +194,7 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
|
||||
|
||||
|
||||
MessageLikeRepresentation = Union[
|
||||
BaseMessage, List[str], Tuple[str, str], str, Dict[str, Any]
|
||||
BaseMessage, list[str], tuple[str, str], str, dict[str, Any]
|
||||
]
|
||||
|
||||
|
||||
|
@ -2,12 +2,11 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional, TypeVar, Union
|
||||
from typing import Annotated, Any, Optional, TypeVar, Union
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
import pydantic
|
||||
from pydantic import SkipValidation
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||
|
@ -3,8 +3,9 @@ from __future__ import annotations
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
from typing import AsyncIterator, Iterator, List, TypeVar, Union
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Optional as Optional
|
||||
from typing import TypeVar, Union
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
@ -29,7 +30,7 @@ def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
|
||||
yield buffer.popleft()
|
||||
|
||||
|
||||
class ListOutputParser(BaseTransformOutputParser[List[str]]):
|
||||
class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
||||
"""Parse the output of an LLM call to a list."""
|
||||
|
||||
@property
|
||||
|
@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from types import GenericAlias
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import jsonpatch # type: ignore[import]
|
||||
from pydantic import BaseModel, model_validator
|
||||
@ -20,7 +21,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
args_only: bool = True
|
||||
"""Whether to only return the arguments to the function call."""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@ -72,7 +73,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||
return jsonpatch.make_patch(prev, next).patch
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@ -166,7 +167,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
||||
key_name: str
|
||||
"""The name of the key to return."""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@ -223,7 +224,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
result = parser.parse_result([chat_generation])
|
||||
"""
|
||||
|
||||
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
|
||||
pydantic_schema: Union[type[BaseModel], dict[str, type[BaseModel]]]
|
||||
"""The pydantic schema to parse the output with.
|
||||
|
||||
If multiple schemas are provided, then the function name will be used to
|
||||
@ -232,7 +233,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_schema(cls, values: Dict) -> Any:
|
||||
def validate_schema(cls, values: dict) -> Any:
|
||||
"""Validate the pydantic schema.
|
||||
|
||||
Args:
|
||||
@ -246,17 +247,19 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
"""
|
||||
schema = values["pydantic_schema"]
|
||||
if "args_only" not in values:
|
||||
values["args_only"] = isinstance(schema, type) and issubclass(
|
||||
schema, BaseModel
|
||||
values["args_only"] = (
|
||||
isinstance(schema, type)
|
||||
and not isinstance(schema, GenericAlias)
|
||||
and issubclass(schema, BaseModel)
|
||||
)
|
||||
elif values["args_only"] and isinstance(schema, Dict):
|
||||
elif values["args_only"] and isinstance(schema, dict):
|
||||
raise ValueError(
|
||||
"If multiple pydantic schemas are provided then args_only should be"
|
||||
" False."
|
||||
)
|
||||
return values
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
@ -292,7 +295,7 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
||||
attr_name: str
|
||||
"""The name of the attribute to return."""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
"""Parse the result of an LLM call to a JSON object.
|
||||
|
||||
Args:
|
||||
|
@ -1,10 +1,9 @@
|
||||
import copy
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Annotated, Any, Optional
|
||||
|
||||
from pydantic import SkipValidation, ValidationError
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||
@ -17,12 +16,12 @@ from langchain_core.utils.pydantic import TypeBaseModel
|
||||
|
||||
|
||||
def parse_tool_call(
|
||||
raw_tool_call: Dict[str, Any],
|
||||
raw_tool_call: dict[str, Any],
|
||||
*,
|
||||
partial: bool = False,
|
||||
strict: bool = False,
|
||||
return_id: bool = True,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""Parse a single tool call.
|
||||
|
||||
Args:
|
||||
@ -69,7 +68,7 @@ def parse_tool_call(
|
||||
|
||||
|
||||
def make_invalid_tool_call(
|
||||
raw_tool_call: Dict[str, Any],
|
||||
raw_tool_call: dict[str, Any],
|
||||
error_msg: Optional[str],
|
||||
) -> InvalidToolCall:
|
||||
"""Create an InvalidToolCall from a raw tool call.
|
||||
@ -90,12 +89,12 @@ def make_invalid_tool_call(
|
||||
|
||||
|
||||
def parse_tool_calls(
|
||||
raw_tool_calls: List[dict],
|
||||
raw_tool_calls: list[dict],
|
||||
*,
|
||||
partial: bool = False,
|
||||
strict: bool = False,
|
||||
return_id: bool = True,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Parse a list of tool calls.
|
||||
|
||||
Args:
|
||||
@ -111,7 +110,7 @@ def parse_tool_calls(
|
||||
Raises:
|
||||
OutputParserException: If any of the tool calls are not valid JSON.
|
||||
"""
|
||||
final_tools: List[Dict[str, Any]] = []
|
||||
final_tools: list[dict[str, Any]] = []
|
||||
exceptions = []
|
||||
for tool_call in raw_tool_calls:
|
||||
try:
|
||||
@ -151,7 +150,7 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
If no tool calls are found, None will be returned.
|
||||
"""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
"""Parse the result of an LLM call to a list of tool calls.
|
||||
|
||||
Args:
|
||||
@ -217,7 +216,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
||||
key_name: str
|
||||
"""The type of tools to return."""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
"""Parse the result of an LLM call to a list of tool calls.
|
||||
|
||||
Args:
|
||||
@ -254,12 +253,12 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
||||
class PydanticToolsParser(JsonOutputToolsParser):
|
||||
"""Parse tools from OpenAI response."""
|
||||
|
||||
tools: Annotated[List[TypeBaseModel], SkipValidation()]
|
||||
tools: Annotated[list[TypeBaseModel], SkipValidation()]
|
||||
"""The tools to parse."""
|
||||
|
||||
# TODO: Support more granular streaming of objects. Currently only streams once all
|
||||
# Pydantic object fields are present.
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
"""Parse the result of an LLM call to a list of Pydantic objects.
|
||||
|
||||
Args:
|
||||
|
@ -1,9 +1,8 @@
|
||||
import json
|
||||
from typing import Generic, List, Optional, Type
|
||||
from typing import Annotated, Generic, Optional
|
||||
|
||||
import pydantic
|
||||
from pydantic import SkipValidation
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
@ -18,7 +17,7 @@ from langchain_core.utils.pydantic import (
|
||||
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
"""Parse an output using a pydantic model."""
|
||||
|
||||
pydantic_object: Annotated[Type[TBaseModel], SkipValidation()] # type: ignore
|
||||
pydantic_object: Annotated[type[TBaseModel], SkipValidation()] # type: ignore
|
||||
"""The pydantic model to parse."""
|
||||
|
||||
def _parse_obj(self, obj: dict) -> TBaseModel:
|
||||
@ -50,7 +49,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
return OutputParserException(msg, llm_output=json_string)
|
||||
|
||||
def parse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
) -> Optional[TBaseModel]:
|
||||
"""Parse the result of an LLM call to a pydantic object.
|
||||
|
||||
@ -108,7 +107,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
return "pydantic"
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[TBaseModel]:
|
||||
def OutputType(self) -> type[TBaseModel]:
|
||||
"""Return the pydantic model."""
|
||||
return self.pydantic_object
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
from typing import List
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||
@ -13,7 +12,7 @@ class StrOutputParser(BaseTransformOutputParser[str]):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "output_parser"]
|
||||
|
||||
|
@ -1,10 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Iterator,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
@ -1,7 +1,8 @@
|
||||
import re
|
||||
import xml
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from xml.etree.ElementTree import TreeBuilder
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
@ -57,7 +58,7 @@ class _StreamingParser:
|
||||
_parser = None
|
||||
self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
|
||||
self.xml_start_re = re.compile(r"<[a-zA-Z:_]")
|
||||
self.current_path: List[str] = []
|
||||
self.current_path: list[str] = []
|
||||
self.current_path_has_children = False
|
||||
self.buffer = ""
|
||||
self.xml_started = False
|
||||
@ -140,7 +141,7 @@ class _StreamingParser:
|
||||
class XMLOutputParser(BaseTransformOutputParser):
|
||||
"""Parse an output using xml format."""
|
||||
|
||||
tags: Optional[List[str]] = None
|
||||
tags: Optional[list[str]] = None
|
||||
encoding_matcher: re.Pattern = re.compile(
|
||||
r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
|
||||
)
|
||||
@ -169,7 +170,7 @@ class XMLOutputParser(BaseTransformOutputParser):
|
||||
"""Return the format instructions for the XML output."""
|
||||
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
|
||||
|
||||
def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]:
|
||||
def parse(self, text: str) -> dict[str, Union[str, list[Any]]]:
|
||||
"""Parse the output of an LLM call.
|
||||
|
||||
Args:
|
||||
@ -234,13 +235,13 @@ class XMLOutputParser(BaseTransformOutputParser):
|
||||
yield output
|
||||
streaming_parser.close()
|
||||
|
||||
def _root_to_dict(self, root: ET.Element) -> Dict[str, Union[str, List[Any]]]:
|
||||
def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]:
|
||||
"""Converts xml tree to python dictionary."""
|
||||
if root.text and bool(re.search(r"\S", root.text)):
|
||||
# If root text contains any non-whitespace character it
|
||||
# returns {root.tag: root.text}
|
||||
return {root.tag: root.text}
|
||||
result: Dict = {root.tag: []}
|
||||
result: dict = {root.tag: []}
|
||||
for child in root:
|
||||
if len(child) == 0:
|
||||
result[root.tag].append({child.tag: child.text})
|
||||
@ -253,7 +254,7 @@ class XMLOutputParser(BaseTransformOutputParser):
|
||||
return "xml"
|
||||
|
||||
|
||||
def nested_element(path: List[str], elem: ET.Element) -> Any:
|
||||
def nested_element(path: list[str], elem: ET.Element) -> Any:
|
||||
"""Get nested element from path.
|
||||
|
||||
Args:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -18,7 +18,7 @@ class ChatResult(BaseModel):
|
||||
for more information.
|
||||
"""
|
||||
|
||||
generations: List[ChatGeneration]
|
||||
generations: list[ChatGeneration]
|
||||
"""List of the chat generations.
|
||||
|
||||
Generations is a list to allow for multiple candidate generations for a single
|
||||
|
@ -7,7 +7,8 @@ They can be used to represent text, images, or chat message pieces.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal, Sequence, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, cast
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
@ -1,16 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
Mapping,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -39,7 +39,7 @@ FormatOutputType = TypeVar("FormatOutputType")
|
||||
|
||||
|
||||
class BasePromptTemplate(
|
||||
RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC
|
||||
RunnableSerializable[dict, PromptValue], Generic[FormatOutputType], ABC
|
||||
):
|
||||
"""Base class for all prompt templates, returning a prompt."""
|
||||
|
||||
@ -50,7 +50,7 @@ class BasePromptTemplate(
|
||||
"""optional_variables: A list of the names of the variables for placeholder
|
||||
or MessagePlaceholder that are optional. These variables are auto inferred
|
||||
from the prompt and user need not provide them."""
|
||||
input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006
|
||||
input_types: typing.Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006
|
||||
"""A dictionary of the types of the variables the prompt template expects.
|
||||
If not provided, all variables are assumed to be strings."""
|
||||
output_parser: Optional[BaseOutputParser] = None
|
||||
@ -60,7 +60,7 @@ class BasePromptTemplate(
|
||||
|
||||
Partial variables populate the template so that you don't need to
|
||||
pass them in every time you call the prompt."""
|
||||
metadata: Optional[Dict[str, Any]] = None # noqa: UP006
|
||||
metadata: Optional[typing.Dict[str, Any]] = None # noqa: UP006
|
||||
"""Metadata to be used for tracing."""
|
||||
tags: Optional[list[str]] = None
|
||||
"""Tags to be used for tracing."""
|
||||
|
@ -3,15 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -25,7 +23,6 @@ from pydantic import (
|
||||
SkipValidation,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.load import Serializable
|
||||
@ -816,9 +813,9 @@ MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTempla
|
||||
|
||||
MessageLikeRepresentation = Union[
|
||||
MessageLike,
|
||||
Tuple[
|
||||
Union[str, Type],
|
||||
Union[str, List[dict], List[object]],
|
||||
tuple[
|
||||
Union[str, type],
|
||||
Union[str, list[dict], list[object]],
|
||||
],
|
||||
str,
|
||||
]
|
||||
@ -1017,7 +1014,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
),
|
||||
**kwargs,
|
||||
}
|
||||
cast(Type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)
|
||||
cast(type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
@ -1083,7 +1080,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
values["partial_variables"][message.variable_name] = []
|
||||
optional_variables.add(message.variable_name)
|
||||
if message.variable_name not in input_types:
|
||||
input_types[message.variable_name] = List[AnyMessage]
|
||||
input_types[message.variable_name] = list[AnyMessage]
|
||||
if "partial_variables" in values:
|
||||
input_vars = input_vars - set(values["partial_variables"])
|
||||
if optional_variables:
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Prompt template that contains few shot examples."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import ConfigDict, model_validator
|
||||
from typing_extensions import Self
|
||||
@ -16,7 +16,7 @@ from langchain_core.prompts.string import (
|
||||
class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
"""Prompt template that contains few shot examples."""
|
||||
|
||||
examples: Optional[List[dict]] = None
|
||||
examples: Optional[list[dict]] = None
|
||||
"""Examples to format into the prompt.
|
||||
Either this or example_selector should be provided."""
|
||||
|
||||
@ -43,13 +43,13 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "few_shot_with_templates"]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_examples_and_selector(cls, values: Dict) -> Any:
|
||||
def check_examples_and_selector(cls, values: dict) -> Any:
|
||||
"""Check that one and only one of examples/example_selector are provided."""
|
||||
examples = values.get("examples", None)
|
||||
example_selector = values.get("example_selector", None)
|
||||
@ -93,7 +93,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
def _get_examples(self, **kwargs: Any) -> List[dict]:
|
||||
def _get_examples(self, **kwargs: Any) -> list[dict]:
|
||||
if self.examples is not None:
|
||||
return self.examples
|
||||
elif self.example_selector is not None:
|
||||
@ -101,7 +101,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
async def _aget_examples(self, **kwargs: Any) -> List[dict]:
|
||||
async def _aget_examples(self, **kwargs: Any) -> list[dict]:
|
||||
if self.examples is not None:
|
||||
return self.examples
|
||||
elif self.example_selector is not None:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@ -33,7 +33,7 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||
return "image-prompt"
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "image"]
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
@ -181,7 +181,7 @@ def _load_prompt_from_file(
|
||||
return load_prompt_from_config(config)
|
||||
|
||||
|
||||
def _load_chat_prompt(config: Dict) -> ChatPromptTemplate:
|
||||
def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
|
||||
"""Load chat prompt from config"""
|
||||
|
||||
messages = config.pop("messages")
|
||||
@ -194,7 +194,7 @@ def _load_chat_prompt(config: Dict) -> ChatPromptTemplate:
|
||||
return ChatPromptTemplate.from_template(template=template, **config)
|
||||
|
||||
|
||||
type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = {
|
||||
type_to_loader_dict: dict[str, Callable[[dict], BasePromptTemplate]] = {
|
||||
"prompt": _load_prompt,
|
||||
"few_shot": _load_few_shot_prompt,
|
||||
"chat": _load_chat_prompt,
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any
|
||||
from typing import Optional as Optional
|
||||
|
||||
from pydantic import model_validator
|
||||
@ -8,7 +8,7 @@ from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.chat import BaseChatPromptTemplate
|
||||
|
||||
|
||||
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
|
||||
def _get_inputs(inputs: dict, input_variables: list[str]) -> dict:
|
||||
return {k: inputs[k] for k in input_variables}
|
||||
|
||||
|
||||
@ -28,17 +28,17 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
||||
|
||||
final_prompt: BasePromptTemplate
|
||||
"""The final prompt that is returned."""
|
||||
pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
|
||||
pipeline_prompts: list[tuple[str, BasePromptTemplate]]
|
||||
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "pipeline"]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def get_input_variables(cls, values: Dict) -> Any:
|
||||
def get_input_variables(cls, values: dict) -> Any:
|
||||
"""Get input variables."""
|
||||
created_variables = set()
|
||||
all_variables = set()
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from string import Formatter
|
||||
from typing import Any, Callable, Dict
|
||||
from typing import Any, Callable
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
@ -139,7 +139,7 @@ def mustache_template_vars(
|
||||
return vars
|
||||
|
||||
|
||||
Defs = Dict[str, "Defs"]
|
||||
Defs = dict[str, "Defs"]
|
||||
|
||||
|
||||
def mustache_schema(
|
||||
|
@ -1,13 +1,8 @@
|
||||
from collections.abc import Iterator, Mapping, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -32,16 +27,16 @@ from langchain_core.utils import get_pydantic_field_names
|
||||
class StructuredPrompt(ChatPromptTemplate):
|
||||
"""Structured prompt template for a language model."""
|
||||
|
||||
schema_: Union[Dict, Type[BaseModel]]
|
||||
schema_: Union[dict, type[BaseModel]]
|
||||
"""Schema for the structured prompt."""
|
||||
structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
structured_output_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
schema_: Optional[Union[Dict, Type[BaseModel]]] = None,
|
||||
schema_: Optional[Union[dict, type[BaseModel]]] = None,
|
||||
*,
|
||||
structured_output_kwargs: Optional[Dict[str, Any]] = None,
|
||||
structured_output_kwargs: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
schema_ = schema_ or kwargs.pop("schema")
|
||||
@ -56,7 +51,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
|
||||
For example, if the class is `langchain.llms.openai.OpenAI`, then the
|
||||
@ -68,7 +63,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
def from_messages_and_schema(
|
||||
cls,
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
schema: Union[Dict, Type[BaseModel]],
|
||||
schema: Union[dict, type[BaseModel]],
|
||||
**kwargs: Any,
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a chat prompt template from a variety of message formats.
|
||||
@ -118,7 +113,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
Callable[[Iterator[Any]], Iterator[Other]],
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||
],
|
||||
) -> RunnableSerializable[Dict, Other]:
|
||||
) -> RunnableSerializable[dict, Other]:
|
||||
return self.pipe(other)
|
||||
|
||||
def pipe(
|
||||
@ -130,7 +125,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||
],
|
||||
name: Optional[str] = None,
|
||||
) -> RunnableSerializable[Dict, Other]:
|
||||
) -> RunnableSerializable[dict, Other]:
|
||||
"""Pipe the structured prompt to a language model.
|
||||
|
||||
Args:
|
||||
|
@ -24,7 +24,7 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import TypedDict
|
||||
@ -47,7 +47,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
RetrieverInput = str
|
||||
RetrieverOutput = List[Document]
|
||||
RetrieverOutput = list[Document]
|
||||
RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
|
||||
RetrieverOutputLike = Runnable[Any, RetrieverOutput]
|
||||
|
||||
|
@ -6,36 +6,37 @@ import functools
|
||||
import inspect
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import (
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Coroutine,
|
||||
Iterator,
|
||||
Mapping,
|
||||
Sequence,
|
||||
)
|
||||
from concurrent.futures import FIRST_COMPLETED, wait
|
||||
from contextvars import copy_context
|
||||
from functools import wraps
|
||||
from itertools import groupby, tee
|
||||
from operator import itemgetter
|
||||
from types import GenericAlias
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_type_hints,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel
|
||||
from typing_extensions import Literal, get_args, get_type_hints
|
||||
from typing_extensions import Literal, get_args
|
||||
|
||||
from langchain_core._api import beta_decorator
|
||||
from langchain_core.load.serializable import (
|
||||
@ -340,7 +341,11 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
root_type = self.InputType
|
||||
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
if (
|
||||
inspect.isclass(root_type)
|
||||
and not isinstance(root_type, GenericAlias)
|
||||
and issubclass(root_type, BaseModel)
|
||||
):
|
||||
return root_type
|
||||
|
||||
return create_model_v2(
|
||||
@ -408,7 +413,11 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
root_type = self.OutputType
|
||||
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
if (
|
||||
inspect.isclass(root_type)
|
||||
and not isinstance(root_type, GenericAlias)
|
||||
and issubclass(root_type, BaseModel)
|
||||
):
|
||||
return root_type
|
||||
|
||||
return create_model_v2(
|
||||
@ -771,10 +780,10 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
# If there's only one input, don't bother with the executor
|
||||
if len(inputs) == 1:
|
||||
return cast(List[Output], [invoke(inputs[0], configs[0])])
|
||||
return cast(list[Output], [invoke(inputs[0], configs[0])])
|
||||
|
||||
with get_executor_for_config(configs[0]) as executor:
|
||||
return cast(List[Output], list(executor.map(invoke, inputs, configs)))
|
||||
return cast(list[Output], list(executor.map(invoke, inputs, configs)))
|
||||
|
||||
@overload
|
||||
def batch_as_completed(
|
||||
@ -2024,7 +2033,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
for run_manager in run_managers:
|
||||
run_manager.on_chain_error(e)
|
||||
if return_exceptions:
|
||||
return cast(List[Output], [e for _ in input])
|
||||
return cast(list[Output], [e for _ in input])
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
@ -2036,7 +2045,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
else:
|
||||
run_manager.on_chain_end(out)
|
||||
if return_exceptions or first_exception is None:
|
||||
return cast(List[Output], output)
|
||||
return cast(list[Output], output)
|
||||
else:
|
||||
raise first_exception
|
||||
|
||||
@ -2099,7 +2108,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
*(run_manager.on_chain_error(e) for run_manager in run_managers)
|
||||
)
|
||||
if return_exceptions:
|
||||
return cast(List[Output], [e for _ in input])
|
||||
return cast(list[Output], [e for _ in input])
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
@ -2113,7 +2122,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
coros.append(run_manager.on_chain_end(out))
|
||||
await asyncio.gather(*coros)
|
||||
if return_exceptions or first_exception is None:
|
||||
return cast(List[Output], output)
|
||||
return cast(list[Output], output)
|
||||
else:
|
||||
raise first_exception
|
||||
|
||||
@ -3171,7 +3180,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
for rm in run_managers:
|
||||
rm.on_chain_error(e)
|
||||
if return_exceptions:
|
||||
return cast(List[Output], [e for _ in inputs])
|
||||
return cast(list[Output], [e for _ in inputs])
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
@ -3183,7 +3192,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
else:
|
||||
run_manager.on_chain_end(out)
|
||||
if return_exceptions or first_exception is None:
|
||||
return cast(List[Output], inputs)
|
||||
return cast(list[Output], inputs)
|
||||
else:
|
||||
raise first_exception
|
||||
|
||||
@ -3298,7 +3307,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
except BaseException as e:
|
||||
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
||||
if return_exceptions:
|
||||
return cast(List[Output], [e for _ in inputs])
|
||||
return cast(list[Output], [e for _ in inputs])
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
@ -3312,7 +3321,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
coros.append(run_manager.on_chain_end(out))
|
||||
await asyncio.gather(*coros)
|
||||
if return_exceptions or first_exception is None:
|
||||
return cast(List[Output], inputs)
|
||||
return cast(list[Output], inputs)
|
||||
else:
|
||||
raise first_exception
|
||||
|
||||
@ -3420,7 +3429,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
yield chunk
|
||||
|
||||
|
||||
class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
"""Runnable that runs a mapping of Runnables in parallel, and returns a mapping
|
||||
of their outputs.
|
||||
|
||||
@ -4071,7 +4080,11 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
func = getattr(self, "_transform", None) or self._atransform
|
||||
module = getattr(func, "__module__", None)
|
||||
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
if (
|
||||
inspect.isclass(root_type)
|
||||
and not isinstance(root_type, GenericAlias)
|
||||
and issubclass(root_type, BaseModel)
|
||||
):
|
||||
return root_type
|
||||
|
||||
return create_model_v2(
|
||||
@ -4106,7 +4119,11 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
func = getattr(self, "_transform", None) or self._atransform
|
||||
module = getattr(func, "__module__", None)
|
||||
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
if (
|
||||
inspect.isclass(root_type)
|
||||
and not isinstance(root_type, GenericAlias)
|
||||
and issubclass(root_type, BaseModel)
|
||||
):
|
||||
return root_type
|
||||
|
||||
return create_model_v2(
|
||||
@ -4369,7 +4386,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
module = getattr(func, "__module__", None)
|
||||
return create_model_v2(
|
||||
self.get_name("Input"),
|
||||
root=List[Any],
|
||||
root=list[Any],
|
||||
# To create the schema, we need to provide the module
|
||||
# where the underlying function is defined.
|
||||
# This allows pydantic to resolve type annotations appropriately.
|
||||
@ -4420,7 +4437,11 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
func = getattr(self, "func", None) or self.afunc
|
||||
module = getattr(func, "__module__", None)
|
||||
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
if (
|
||||
inspect.isclass(root_type)
|
||||
and not isinstance(root_type, GenericAlias)
|
||||
and issubclass(root_type, BaseModel)
|
||||
):
|
||||
return root_type
|
||||
|
||||
return create_model_v2(
|
||||
@ -4921,7 +4942,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
yield chunk
|
||||
|
||||
|
||||
class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
|
||||
"""Runnable that delegates calls to another Runnable
|
||||
with each element of the input sequence.
|
||||
|
||||
@ -4938,7 +4959,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
return List[self.bound.InputType] # type: ignore[name-defined]
|
||||
return list[self.bound.InputType] # type: ignore[name-defined]
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
@ -4946,7 +4967,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
return create_model_v2(
|
||||
self.get_name("Input"),
|
||||
root=(
|
||||
List[self.bound.get_input_schema(config)], # type: ignore
|
||||
list[self.bound.get_input_schema(config)], # type: ignore
|
||||
None,
|
||||
),
|
||||
# create model needs access to appropriate type annotations to be
|
||||
@ -4961,7 +4982,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
|
||||
@property
|
||||
def OutputType(self) -> type[list[Output]]:
|
||||
return List[self.bound.OutputType] # type: ignore[name-defined]
|
||||
return list[self.bound.OutputType] # type: ignore[name-defined]
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
@ -4969,7 +4990,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
schema = self.bound.get_output_schema(config)
|
||||
return create_model_v2(
|
||||
self.get_name("Output"),
|
||||
root=List[schema], # type: ignore[valid-type]
|
||||
root=list[schema], # type: ignore[valid-type]
|
||||
# create model needs access to appropriate type annotations to be
|
||||
# able to construct the pydantic model.
|
||||
# When we create the model, we pass information about the namespace
|
||||
@ -5255,7 +5276,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
@property
|
||||
def InputType(self) -> type[Input]:
|
||||
return (
|
||||
cast(Type[Input], self.custom_input_type)
|
||||
cast(type[Input], self.custom_input_type)
|
||||
if self.custom_input_type is not None
|
||||
else self.bound.InputType
|
||||
)
|
||||
@ -5263,7 +5284,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
@property
|
||||
def OutputType(self) -> type[Output]:
|
||||
return (
|
||||
cast(Type[Output], self.custom_output_type)
|
||||
cast(type[Output], self.custom_output_type)
|
||||
if self.custom_output_type is not None
|
||||
else self.bound.OutputType
|
||||
)
|
||||
@ -5336,7 +5357,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
) -> list[Output]:
|
||||
if isinstance(config, list):
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
list[RunnableConfig],
|
||||
[self._merge_configs(conf) for conf in config],
|
||||
)
|
||||
else:
|
||||
@ -5358,7 +5379,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
) -> list[Output]:
|
||||
if isinstance(config, list):
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
list[RunnableConfig],
|
||||
[self._merge_configs(conf) for conf in config],
|
||||
)
|
||||
else:
|
||||
@ -5400,7 +5421,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
) -> Iterator[tuple[int, Union[Output, Exception]]]:
|
||||
if isinstance(config, Sequence):
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
list[RunnableConfig],
|
||||
[self._merge_configs(conf) for conf in config],
|
||||
)
|
||||
else:
|
||||
@ -5451,7 +5472,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]:
|
||||
if isinstance(config, Sequence):
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
list[RunnableConfig],
|
||||
[self._merge_configs(conf) for conf in config],
|
||||
)
|
||||
else:
|
||||
|
@ -1,15 +1,8 @@
|
||||
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -69,13 +62,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
branch.invoke(None) # "goodbye"
|
||||
"""
|
||||
|
||||
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
||||
branches: Sequence[tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
||||
default: Runnable[Input, Output]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*branches: Union[
|
||||
Tuple[
|
||||
tuple[
|
||||
Union[
|
||||
Runnable[Input, bool],
|
||||
Callable[[Input], bool],
|
||||
@ -149,13 +142,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
runnables = (
|
||||
[self.default]
|
||||
+ [r for _, r in self.branches]
|
||||
@ -172,7 +165,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
return super().get_input_schema(config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
from langchain_core.beta.runnables.context import (
|
||||
CONTEXT_CONFIG_PREFIX,
|
||||
CONTEXT_CONFIG_SUFFIX_SET,
|
||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence
|
||||
from concurrent.futures import Executor, Future, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar, copy_context
|
||||
@ -10,14 +11,8 @@ from functools import partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@ -43,7 +38,7 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
# Pydantic validates through typed dicts, but
|
||||
# the callbacks need forward refs updated
|
||||
Callbacks = Optional[Union[List, Any]]
|
||||
Callbacks = Optional[Union[list, Any]]
|
||||
|
||||
|
||||
class EmptyDict(TypedDict, total=False):
|
||||
|
@ -3,20 +3,16 @@ from __future__ import annotations
|
||||
import enum
|
||||
import threading
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from collections.abc import Mapping as Mapping
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import Mapping as Mapping
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
@ -176,10 +172,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
||||
|
||||
# If there's only one input, don't bother with the executor
|
||||
if len(inputs) == 1:
|
||||
return cast(List[Output], [invoke(prepared[0], inputs[0])])
|
||||
return cast(list[Output], [invoke(prepared[0], inputs[0])])
|
||||
|
||||
with get_executor_for_config(configs[0]) as executor:
|
||||
return cast(List[Output], list(executor.map(invoke, prepared, inputs)))
|
||||
return cast(list[Output], list(executor.map(invoke, prepared, inputs)))
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
@ -562,7 +558,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
for v in list(self.alternatives.keys()) + [self.default_key]
|
||||
),
|
||||
)
|
||||
_enums_for_spec[self.which] = cast(Type[StrEnum], which_enum)
|
||||
_enums_for_spec[self.which] = cast(type[StrEnum], which_enum)
|
||||
return get_unique_config_specs(
|
||||
# which alternative
|
||||
[
|
||||
@ -694,7 +690,7 @@ def make_options_spec(
|
||||
spec.name or spec.id,
|
||||
((v, v) for v in list(spec.options.keys())),
|
||||
)
|
||||
_enums_for_spec[spec] = cast(Type[StrEnum], enum)
|
||||
_enums_for_spec[spec] = cast(type[StrEnum], enum)
|
||||
if isinstance(spec, ConfigurableFieldSingleOption):
|
||||
return ConfigurableFieldSpec(
|
||||
id=spec.id,
|
||||
|
@ -1,19 +1,13 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import typing
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from contextvars import copy_context
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -96,7 +90,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
"""The Runnable to run first."""
|
||||
fallbacks: Sequence[Runnable[Input, Output]]
|
||||
"""A sequence of fallbacks to try."""
|
||||
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
|
||||
exceptions_to_handle: tuple[type[BaseException], ...] = (Exception,)
|
||||
"""The exceptions on which fallbacks should be tried.
|
||||
|
||||
Any exception that is not a subclass of these exceptions will be raised immediately.
|
||||
@ -112,25 +106,25 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
def InputType(self) -> type[Input]:
|
||||
return self.runnable.InputType
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
def OutputType(self) -> type[Output]:
|
||||
return self.runnable.OutputType
|
||||
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
return self.runnable.get_input_schema(config)
|
||||
|
||||
def get_output_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
) -> type[BaseModel]:
|
||||
return self.runnable.get_output_schema(config)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
||||
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||
return get_unique_config_specs(
|
||||
spec
|
||||
for step in [self.runnable, *self.fallbacks]
|
||||
@ -142,7 +136,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@ -252,12 +246,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
if self.exception_key is not None and not all(
|
||||
@ -296,9 +290,9 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
]
|
||||
|
||||
to_return: Dict[int, Any] = {}
|
||||
to_return: dict[int, Any] = {}
|
||||
run_again = {i: input for i, input in enumerate(inputs)}
|
||||
handled_exceptions: Dict[int, BaseException] = {}
|
||||
handled_exceptions: dict[int, BaseException] = {}
|
||||
first_to_raise = None
|
||||
for runnable in self.runnables:
|
||||
outputs = runnable.batch(
|
||||
@ -344,12 +338,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
if self.exception_key is not None and not all(
|
||||
@ -378,7 +372,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
for config in configs
|
||||
]
|
||||
# start the root runs, one per input
|
||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
*(
|
||||
cm.on_chain_start(
|
||||
None,
|
||||
@ -392,7 +386,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
|
||||
to_return = {}
|
||||
run_again = {i: input for i, input in enumerate(inputs)}
|
||||
handled_exceptions: Dict[int, BaseException] = {}
|
||||
handled_exceptions: dict[int, BaseException] = {}
|
||||
first_to_raise = None
|
||||
for runnable in self.runnables:
|
||||
outputs = await runnable.abatch(
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections import Counter
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
@ -11,7 +12,6 @@ from typing import (
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
TypedDict,
|
||||
Union,
|
||||
overload,
|
||||
|
@ -3,7 +3,8 @@ Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Mapping, Sequence
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables.graph import Edge as LangEdge
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import base64
|
||||
import re
|
||||
from dataclasses import asdict
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.runnables.graph import (
|
||||
CurveStyle,
|
||||
@ -15,8 +15,8 @@ MARKDOWN_SPECIAL_CHARS = "*_`"
|
||||
|
||||
|
||||
def draw_mermaid(
|
||||
nodes: Dict[str, Node],
|
||||
edges: List[Edge],
|
||||
nodes: dict[str, Node],
|
||||
edges: list[Edge],
|
||||
*,
|
||||
first_node: Optional[str] = None,
|
||||
last_node: Optional[str] = None,
|
||||
@ -87,7 +87,7 @@ def draw_mermaid(
|
||||
mermaid_graph += f"\t{node_label}\n"
|
||||
|
||||
# Group edges by their common prefixes
|
||||
edge_groups: Dict[str, List[Edge]] = {}
|
||||
edge_groups: dict[str, list[Edge]] = {}
|
||||
for edge in edges:
|
||||
src_parts = edge.source.split(":")
|
||||
tgt_parts = edge.target.split(":")
|
||||
@ -98,7 +98,7 @@ def draw_mermaid(
|
||||
|
||||
seen_subgraphs = set()
|
||||
|
||||
def add_subgraph(edges: List[Edge], prefix: str) -> None:
|
||||
def add_subgraph(edges: list[Edge], prefix: str) -> None:
|
||||
nonlocal mermaid_graph
|
||||
self_loop = len(edges) == 1 and edges[0].source == edges[0].target
|
||||
if prefix and not self_loop:
|
||||
|
@ -1,13 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import Sequence
|
||||
from types import GenericAlias
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -31,7 +31,7 @@ if TYPE_CHECKING:
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
|
||||
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
|
||||
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], dict[str, Any]]
|
||||
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
|
||||
|
||||
|
||||
@ -419,7 +419,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
"""
|
||||
root_type = self.OutputType
|
||||
|
||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||
if (
|
||||
inspect.isclass(root_type)
|
||||
and not isinstance(root_type, GenericAlias)
|
||||
and issubclass(root_type, BaseModel)
|
||||
):
|
||||
return root_type
|
||||
|
||||
return create_model_v2(
|
||||
|
@ -5,15 +5,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import threading
|
||||
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
Mapping,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
@ -349,7 +345,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
_graph_passthrough: RunnablePassthrough = RunnablePassthrough()
|
||||
|
||||
|
||||
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
"""Runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
||||
|
||||
The `RunnableAssign` class takes input dictionaries and, through a
|
||||
@ -564,7 +560,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
if filtered:
|
||||
yield filtered
|
||||
# yield map output
|
||||
yield cast(Dict[str, Any], first_map_chunk_future.result())
|
||||
yield cast(dict[str, Any], first_map_chunk_future.result())
|
||||
for chunk in map_output:
|
||||
yield chunk
|
||||
|
||||
@ -650,7 +646,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
yield chunk
|
||||
|
||||
|
||||
class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
"""Runnable that picks keys from Dict[str, Any] inputs.
|
||||
|
||||
RunnablePick class represents a Runnable that selectively picks keys from a
|
||||
|
@ -1,11 +1,7 @@
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@ -98,7 +94,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
retryable_chain = chain.with_retry()
|
||||
"""
|
||||
|
||||
retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,)
|
||||
retry_exception_types: tuple[type[BaseException], ...] = (Exception,)
|
||||
"""The exception types to retry on. By default all exceptions are retried.
|
||||
|
||||
In general you should only retry on exceptions that are likely to be
|
||||
@ -115,13 +111,13 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
"""The maximum number of attempts to retry the Runnable."""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
@property
|
||||
def _kwargs_retrying(self) -> Dict[str, Any]:
|
||||
kwargs: Dict[str, Any] = dict()
|
||||
def _kwargs_retrying(self) -> dict[str, Any]:
|
||||
kwargs: dict[str, Any] = dict()
|
||||
|
||||
if self.max_attempt_number:
|
||||
kwargs["stop"] = stop_after_attempt(self.max_attempt_number)
|
||||
@ -152,10 +148,10 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
|
||||
def _patch_config_list(
|
||||
self,
|
||||
config: List[RunnableConfig],
|
||||
run_manager: List["T"],
|
||||
config: list[RunnableConfig],
|
||||
run_manager: list["T"],
|
||||
retry_state: RetryCallState,
|
||||
) -> List[RunnableConfig]:
|
||||
) -> list[RunnableConfig]:
|
||||
return [
|
||||
self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
|
||||
]
|
||||
@ -208,17 +204,17 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
|
||||
def _batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
run_manager: List["CallbackManagerForChainRun"],
|
||||
config: List[RunnableConfig],
|
||||
inputs: list[Input],
|
||||
run_manager: list["CallbackManagerForChainRun"],
|
||||
config: list[RunnableConfig],
|
||||
**kwargs: Any,
|
||||
) -> List[Union[Output, Exception]]:
|
||||
results_map: Dict[int, Output] = {}
|
||||
) -> list[Union[Output, Exception]]:
|
||||
results_map: dict[int, Output] = {}
|
||||
|
||||
def pending(iterable: List[U]) -> List[U]:
|
||||
def pending(iterable: list[U]) -> list[U]:
|
||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||
|
||||
not_set: List[Output] = []
|
||||
not_set: list[Output] = []
|
||||
result = not_set
|
||||
try:
|
||||
for attempt in self._sync_retrying():
|
||||
@ -250,9 +246,9 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
attempt.retry_state.set_result(result)
|
||||
except RetryError as e:
|
||||
if result is not_set:
|
||||
result = cast(List[Output], [e] * len(inputs))
|
||||
result = cast(list[Output], [e] * len(inputs))
|
||||
|
||||
outputs: List[Union[Output, Exception]] = []
|
||||
outputs: list[Union[Output, Exception]] = []
|
||||
for idx, _ in enumerate(inputs):
|
||||
if idx in results_map:
|
||||
outputs.append(results_map[idx])
|
||||
@ -262,29 +258,29 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
return self._batch_with_config(
|
||||
self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
|
||||
async def _abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
run_manager: List["AsyncCallbackManagerForChainRun"],
|
||||
config: List[RunnableConfig],
|
||||
inputs: list[Input],
|
||||
run_manager: list["AsyncCallbackManagerForChainRun"],
|
||||
config: list[RunnableConfig],
|
||||
**kwargs: Any,
|
||||
) -> List[Union[Output, Exception]]:
|
||||
results_map: Dict[int, Output] = {}
|
||||
) -> list[Union[Output, Exception]]:
|
||||
results_map: dict[int, Output] = {}
|
||||
|
||||
def pending(iterable: List[U]) -> List[U]:
|
||||
def pending(iterable: list[U]) -> list[U]:
|
||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||
|
||||
not_set: List[Output] = []
|
||||
not_set: list[Output] = []
|
||||
result = not_set
|
||||
try:
|
||||
async for attempt in self._async_retrying():
|
||||
@ -316,9 +312,9 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
attempt.retry_state.set_result(result)
|
||||
except RetryError as e:
|
||||
if result is not_set:
|
||||
result = cast(List[Output], [e] * len(inputs))
|
||||
result = cast(list[Output], [e] * len(inputs))
|
||||
|
||||
outputs: List[Union[Output, Exception]] = []
|
||||
outputs: list[Union[Output, Exception]] = []
|
||||
for idx, _ in enumerate(inputs):
|
||||
if idx in results_map:
|
||||
outputs.append(results_map[idx])
|
||||
@ -328,12 +324,12 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
inputs: list[Input],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[Output]:
|
||||
) -> list[Output]:
|
||||
return await self._abatch_with_config(
|
||||
self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
|
@ -1,12 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
@ -154,7 +151,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
configs = get_config_list(config, len(inputs))
|
||||
with get_executor_for_config(configs[0]) as executor:
|
||||
return cast(
|
||||
List[Output],
|
||||
list[Output],
|
||||
list(executor.map(invoke, runnables, actual_inputs, configs)),
|
||||
)
|
||||
|
||||
|
@ -2,7 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
|
@ -6,23 +6,24 @@ import ast
|
||||
import asyncio
|
||||
import inspect
|
||||
import textwrap
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Coroutine,
|
||||
Iterable,
|
||||
Mapping,
|
||||
Sequence,
|
||||
)
|
||||
from functools import lru_cache
|
||||
from inspect import signature
|
||||
from itertools import groupby
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Iterable,
|
||||
Mapping,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -430,7 +431,7 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
|
||||
return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
|
||||
|
||||
|
||||
class AddableDict(Dict[str, Any]):
|
||||
class AddableDict(dict[str, Any]):
|
||||
"""
|
||||
Dictionary that can be added to another dictionary.
|
||||
"""
|
||||
|
@ -7,16 +7,11 @@ The primary goal of these storages is to support implementation of caching.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -84,7 +79,7 @@ class BaseStore(Generic[K, V], ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def mget(self, keys: Sequence[K]) -> List[Optional[V]]:
|
||||
def mget(self, keys: Sequence[K]) -> list[Optional[V]]:
|
||||
"""Get the values associated with the given keys.
|
||||
|
||||
Args:
|
||||
@ -95,7 +90,7 @@ class BaseStore(Generic[K, V], ABC):
|
||||
If a key is not found, the corresponding value will be None.
|
||||
"""
|
||||
|
||||
async def amget(self, keys: Sequence[K]) -> List[Optional[V]]:
|
||||
async def amget(self, keys: Sequence[K]) -> list[Optional[V]]:
|
||||
"""Async get the values associated with the given keys.
|
||||
|
||||
Args:
|
||||
@ -108,14 +103,14 @@ class BaseStore(Generic[K, V], ABC):
|
||||
return await run_in_executor(None, self.mget, keys)
|
||||
|
||||
@abstractmethod
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
|
||||
def mset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
|
||||
"""Set the values for the given keys.
|
||||
|
||||
Args:
|
||||
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
|
||||
"""
|
||||
|
||||
async def amset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
|
||||
async def amset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
|
||||
"""Async set the values for the given keys.
|
||||
|
||||
Args:
|
||||
@ -184,9 +179,9 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize an empty store."""
|
||||
self.store: Dict[str, V] = {}
|
||||
self.store: dict[str, V] = {}
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
|
||||
def mget(self, keys: Sequence[str]) -> list[Optional[V]]:
|
||||
"""Get the values associated with the given keys.
|
||||
|
||||
Args:
|
||||
@ -198,7 +193,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
||||
"""
|
||||
return [self.store.get(key) for key in keys]
|
||||
|
||||
async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
|
||||
async def amget(self, keys: Sequence[str]) -> list[Optional[V]]:
|
||||
"""Async get the values associated with the given keys.
|
||||
|
||||
Args:
|
||||
@ -210,7 +205,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
||||
"""
|
||||
return self.mget(keys)
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
|
||||
def mset(self, key_value_pairs: Sequence[tuple[str, V]]) -> None:
|
||||
"""Set the values for the given keys.
|
||||
|
||||
Args:
|
||||
@ -222,7 +217,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
||||
for key, value in key_value_pairs:
|
||||
self.store[key] = value
|
||||
|
||||
async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
|
||||
async def amset(self, key_value_pairs: Sequence[tuple[str, V]]) -> None:
|
||||
"""Async set the values for the given keys.
|
||||
|
||||
Args:
|
||||
|
@ -3,8 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -2,10 +2,10 @@
|
||||
for debugging purposes.
|
||||
"""
|
||||
|
||||
from typing import List, Sequence
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
def _get_sub_deps(packages: Sequence[str]) -> List[str]:
|
||||
def _get_sub_deps(packages: Sequence[str]) -> list[str]:
|
||||
"""Get any specified sub-dependencies."""
|
||||
from importlib import metadata
|
||||
|
||||
|
@ -7,15 +7,15 @@ import json
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from contextvars import copy_context
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@ -36,7 +36,6 @@ from pydantic import (
|
||||
)
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1 import validate_arguments as validate_arguments_v1
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
@ -324,7 +323,7 @@ class ToolException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
|
||||
class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
||||
"""Interface LangChain tools must implement."""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
|
@ -1,5 +1,5 @@
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Type, Union, get_type_hints
|
||||
from typing import Any, Callable, Literal, Optional, Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
@ -13,7 +13,7 @@ from langchain_core.tools.structured import StructuredTool
|
||||
def tool(
|
||||
*args: Union[str, Callable, Runnable],
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type] = None,
|
||||
args_schema: Optional[type] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
@ -160,7 +160,7 @@ def tool(
|
||||
|
||||
coroutine = ainvoke_wrapper
|
||||
func = invoke_wrapper
|
||||
schema: Optional[Type[BaseModel]] = runnable.input_schema
|
||||
schema: Optional[type[BaseModel]] = runnable.input_schema
|
||||
description = repr(runnable)
|
||||
elif inspect.iscoroutinefunction(dec_func):
|
||||
coroutine = dec_func
|
||||
@ -234,8 +234,8 @@ def _get_description_from_runnable(runnable: Runnable) -> str:
|
||||
def _get_schema_from_runnable_and_arg_types(
|
||||
runnable: Runnable,
|
||||
name: str,
|
||||
arg_types: Optional[Dict[str, Type]] = None,
|
||||
) -> Type[BaseModel]:
|
||||
arg_types: Optional[dict[str, type]] = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Infer args_schema for tool."""
|
||||
if arg_types is None:
|
||||
try:
|
||||
@ -252,11 +252,11 @@ def _get_schema_from_runnable_and_arg_types(
|
||||
|
||||
def convert_runnable_to_tool(
|
||||
runnable: Runnable,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
args_schema: Optional[type[BaseModel]] = None,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
arg_types: Optional[Dict[str, Type]] = None,
|
||||
arg_types: Optional[dict[str, type]] = None,
|
||||
) -> BaseTool:
|
||||
"""Convert a Runnable into a BaseTool.
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import signature
|
||||
from typing import Callable, List
|
||||
from typing import Callable
|
||||
|
||||
from langchain_core.tools.base import BaseTool
|
||||
|
||||
ToolsRenderer = Callable[[List[BaseTool]], str]
|
||||
ToolsRenderer = Callable[[list[BaseTool]], str]
|
||||
|
||||
|
||||
def render_text_description(tools: list[BaseTool]) -> str:
|
||||
|
@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Optional,
|
||||
Union,
|
||||
|
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
from collections.abc import Awaitable
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
@ -12,7 +13,6 @@ from typing import (
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, SkipValidation
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
|
@ -1,7 +1,8 @@
|
||||
"""Internal tracers used for stream_log and astream events implementations."""
|
||||
|
||||
import abc
|
||||
from typing import AsyncIterator, Iterator, TypeVar
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
T = TypeVar("T")
|
||||
|
@ -5,11 +5,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generator,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
|
@ -6,14 +6,13 @@ import logging
|
||||
import sys
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Coroutine, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Coroutine,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
@ -5,8 +5,9 @@ from __future__ import annotations
|
||||
import logging
|
||||
import threading
|
||||
import weakref
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, wait
|
||||
from typing import Any, List, Optional, Sequence, Union, cast
|
||||
from typing import Any, Optional, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
import langsmith
|
||||
@ -156,7 +157,7 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
if isinstance(results, EvaluationResult):
|
||||
results_ = [results]
|
||||
elif isinstance(results, dict) and "results" in results:
|
||||
results_ = cast(List[EvaluationResult], results["results"])
|
||||
results_ = cast(list[EvaluationResult], results["results"])
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Invalid evaluation result type {type(results)}."
|
||||
|
@ -4,14 +4,11 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@ -459,7 +456,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
output: Union[dict, BaseMessage] = {}
|
||||
|
||||
if run_info["run_type"] == "chat_model":
|
||||
generations = cast(List[List[ChatGenerationChunk]], response.generations)
|
||||
generations = cast(list[list[ChatGenerationChunk]], response.generations)
|
||||
for gen in generations:
|
||||
if output != {}:
|
||||
break
|
||||
@ -469,7 +466,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
|
||||
event = "on_chat_model_end"
|
||||
elif run_info["run_type"] == "llm":
|
||||
generations = cast(List[List[GenerationChunk]], response.generations)
|
||||
generations = cast(list[list[GenerationChunk]], response.generations)
|
||||
output = {
|
||||
"generations": [
|
||||
[
|
||||
|
@ -4,13 +4,11 @@ import asyncio
|
||||
import copy
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Iterator,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
|
@ -11,7 +11,8 @@ used in the code.
|
||||
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop, Queue
|
||||
from typing import AsyncIterator, Generic, TypeVar
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Awaitable, Callable, Optional, Union
|
||||
from collections.abc import Awaitable
|
||||
from typing import Callable, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.runnables.config import (
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""A tracer that collects all nested runs in a list."""
|
||||
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.tracers.base import BaseTracer
|
||||
@ -38,7 +38,7 @@ class RunCollectorCallbackHandler(BaseTracer):
|
||||
self.example_id = (
|
||||
UUID(example_id) if isinstance(example_id, str) else example_id
|
||||
)
|
||||
self.traced_runs: List[Run] = []
|
||||
self.traced_runs: list[Run] = []
|
||||
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
"""
|
||||
|
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Callable, List
|
||||
from typing import Any, Callable
|
||||
|
||||
from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
@ -54,7 +54,7 @@ class FunctionCallbackHandler(BaseTracer):
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
pass
|
||||
|
||||
def get_parents(self, run: Run) -> List[Run]:
|
||||
def get_parents(self, run: Run) -> list[Run]:
|
||||
"""Get the parents of a run.
|
||||
|
||||
Args:
|
||||
|
@ -5,23 +5,20 @@ MIT License
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
from collections.abc import (
|
||||
AsyncGenerator,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Deque,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
)
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@ -95,10 +92,10 @@ class NoLock:
|
||||
async def tee_peer(
|
||||
iterator: AsyncIterator[T],
|
||||
# the buffer specific to this peer
|
||||
buffer: Deque[T],
|
||||
buffer: deque[T],
|
||||
# the buffers of all peers, including our own
|
||||
peers: List[Deque[T]],
|
||||
lock: AsyncContextManager[Any],
|
||||
peers: list[deque[T]],
|
||||
lock: AbstractAsyncContextManager[Any],
|
||||
) -> AsyncGenerator[T, None]:
|
||||
"""An individual iterator of a :py:func:`~.tee`.
|
||||
|
||||
@ -191,10 +188,10 @@ class Tee(Generic[T]):
|
||||
iterable: AsyncIterator[T],
|
||||
n: int = 2,
|
||||
*,
|
||||
lock: Optional[AsyncContextManager[Any]] = None,
|
||||
lock: Optional[AbstractAsyncContextManager[Any]] = None,
|
||||
):
|
||||
self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist
|
||||
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
|
||||
self._buffers: list[deque[T]] = [deque() for _ in range(n)]
|
||||
self._children = tuple(
|
||||
tee_peer(
|
||||
iterator=self._iterator,
|
||||
@ -212,11 +209,11 @@ class Tee(Generic[T]):
|
||||
def __getitem__(self, item: int) -> AsyncIterator[T]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]: ...
|
||||
def __getitem__(self, item: slice) -> tuple[AsyncIterator[T], ...]: ...
|
||||
|
||||
def __getitem__(
|
||||
self, item: Union[int, slice]
|
||||
) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]:
|
||||
) -> Union[AsyncIterator[T], tuple[AsyncIterator[T], ...]]:
|
||||
return self._children[item]
|
||||
|
||||
def __iter__(self) -> Iterator[AsyncIterator[T]]:
|
||||
@ -267,7 +264,7 @@ class aclosing(AbstractAsyncContextManager):
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
@ -277,7 +274,7 @@ class aclosing(AbstractAsyncContextManager):
|
||||
|
||||
async def abatch_iterate(
|
||||
size: int, iterable: AsyncIterable[T]
|
||||
) -> AsyncIterator[List[T]]:
|
||||
) -> AsyncIterator[list[T]]:
|
||||
"""Utility batching function for async iterables.
|
||||
|
||||
Args:
|
||||
@ -287,7 +284,7 @@ async def abatch_iterate(
|
||||
Returns:
|
||||
An async iterator over the batches.
|
||||
"""
|
||||
batch: List[T] = []
|
||||
batch: list[T] = []
|
||||
async for element in iterable:
|
||||
if len(batch) < size:
|
||||
batch.append(element)
|
||||
|
@ -1,7 +1,8 @@
|
||||
"""Utilities for formatting strings."""
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from string import Formatter
|
||||
from typing import Any, List, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
|
||||
class StrictFormatter(Formatter):
|
||||
@ -31,7 +32,7 @@ class StrictFormatter(Formatter):
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def validate_input_variables(
|
||||
self, format_string: str, input_variables: List[str]
|
||||
self, format_string: str, input_variables: list[str]
|
||||
) -> None:
|
||||
"""Check that all input variables are used in the format string.
|
||||
|
||||
|
@ -10,21 +10,17 @@ import typing
|
||||
import uuid
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict
|
||||
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
@ -201,7 +197,7 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
model = cast(
|
||||
Type[BaseModel],
|
||||
type[BaseModel],
|
||||
_convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
|
||||
)
|
||||
return convert_pydantic_to_openai_function(model) # type: ignore
|
||||
@ -383,15 +379,15 @@ def convert_to_openai_function(
|
||||
"parameters": function,
|
||||
}
|
||||
elif isinstance(function, type) and is_basemodel_subclass(function):
|
||||
oai_function = cast(Dict, convert_pydantic_to_openai_function(function))
|
||||
oai_function = cast(dict, convert_pydantic_to_openai_function(function))
|
||||
elif is_typeddict(function):
|
||||
oai_function = cast(
|
||||
Dict, _convert_typed_dict_to_openai_function(cast(Type, function))
|
||||
dict, _convert_typed_dict_to_openai_function(cast(type, function))
|
||||
)
|
||||
elif isinstance(function, BaseTool):
|
||||
oai_function = cast(Dict, format_tool_to_openai_function(function))
|
||||
oai_function = cast(dict, format_tool_to_openai_function(function))
|
||||
elif callable(function):
|
||||
oai_function = cast(Dict, convert_python_function_to_openai_function(function))
|
||||
oai_function = cast(dict, convert_python_function_to_openai_function(function))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported function\n\n{function}\n\nFunctions must be passed in"
|
||||
@ -598,17 +594,17 @@ def _py_38_safe_origin(origin: type) -> type:
|
||||
)
|
||||
|
||||
origin_map: dict[type, Any] = {
|
||||
dict: Dict,
|
||||
list: List,
|
||||
tuple: Tuple,
|
||||
set: Set,
|
||||
dict: dict,
|
||||
list: list,
|
||||
tuple: tuple,
|
||||
set: set,
|
||||
collections.abc.Iterable: typing.Iterable,
|
||||
collections.abc.Mapping: typing.Mapping,
|
||||
collections.abc.Sequence: typing.Sequence,
|
||||
collections.abc.MutableMapping: typing.MutableMapping,
|
||||
**origin_union_type_map,
|
||||
}
|
||||
return cast(Type, origin_map.get(origin, origin))
|
||||
return cast(type, origin_map.get(origin, origin))
|
||||
|
||||
|
||||
def _recursive_set_additional_properties_false(
|
||||
|
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional, Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -33,7 +34,7 @@ DEFAULT_LINK_REGEX = (
|
||||
|
||||
def find_all_links(
|
||||
raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Extract all links from a raw HTML string.
|
||||
|
||||
Args:
|
||||
@ -56,7 +57,7 @@ def extract_sub_links(
|
||||
prevent_outside: bool = True,
|
||||
exclude_prefixes: Sequence[str] = (),
|
||||
continue_on_failure: bool = False,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Extract all links from a raw HTML string and convert into absolute paths.
|
||||
|
||||
Args:
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Handle chained inputs."""
|
||||
|
||||
from typing import Dict, List, Optional, TextIO
|
||||
from typing import Optional, TextIO
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
@ -12,8 +12,8 @@ _TEXT_COLOR_MAPPING = {
|
||||
|
||||
|
||||
def get_color_mapping(
|
||||
items: List[str], excluded_colors: Optional[List] = None
|
||||
) -> Dict[str, str]:
|
||||
items: list[str], excluded_colors: Optional[list] = None
|
||||
) -> dict[str, str]:
|
||||
"""Get mapping for items to a support color.
|
||||
|
||||
Args:
|
||||
|
@ -1,16 +1,11 @@
|
||||
from collections import deque
|
||||
from collections.abc import Generator, Iterable, Iterator
|
||||
from contextlib import AbstractContextManager
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
Any,
|
||||
ContextManager,
|
||||
Deque,
|
||||
Generator,
|
||||
Generic,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
@ -34,10 +29,10 @@ class NoLock:
|
||||
def tee_peer(
|
||||
iterator: Iterator[T],
|
||||
# the buffer specific to this peer
|
||||
buffer: Deque[T],
|
||||
buffer: deque[T],
|
||||
# the buffers of all peers, including our own
|
||||
peers: List[Deque[T]],
|
||||
lock: ContextManager[Any],
|
||||
peers: list[deque[T]],
|
||||
lock: AbstractContextManager[Any],
|
||||
) -> Generator[T, None, None]:
|
||||
"""An individual iterator of a :py:func:`~.tee`.
|
||||
|
||||
@ -130,7 +125,7 @@ class Tee(Generic[T]):
|
||||
iterable: Iterator[T],
|
||||
n: int = 2,
|
||||
*,
|
||||
lock: Optional[ContextManager[Any]] = None,
|
||||
lock: Optional[AbstractContextManager[Any]] = None,
|
||||
):
|
||||
"""Create a new ``tee``.
|
||||
|
||||
@ -141,7 +136,7 @@ class Tee(Generic[T]):
|
||||
Defaults to None.
|
||||
"""
|
||||
self._iterator = iter(iterable)
|
||||
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
|
||||
self._buffers: list[deque[T]] = [deque() for _ in range(n)]
|
||||
self._children = tuple(
|
||||
tee_peer(
|
||||
iterator=self._iterator,
|
||||
@ -159,11 +154,11 @@ class Tee(Generic[T]):
|
||||
def __getitem__(self, item: int) -> Iterator[T]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: slice) -> Tuple[Iterator[T], ...]: ...
|
||||
def __getitem__(self, item: slice) -> tuple[Iterator[T], ...]: ...
|
||||
|
||||
def __getitem__(
|
||||
self, item: Union[int, slice]
|
||||
) -> Union[Iterator[T], Tuple[Iterator[T], ...]]:
|
||||
) -> Union[Iterator[T], tuple[Iterator[T], ...]]:
|
||||
return self._children[item]
|
||||
|
||||
def __iter__(self) -> Iterator[Iterator[T]]:
|
||||
@ -185,7 +180,7 @@ class Tee(Generic[T]):
|
||||
safetee = Tee
|
||||
|
||||
|
||||
def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]:
|
||||
def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[list[T]]:
|
||||
"""Utility batching function.
|
||||
|
||||
Args:
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user