Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-08 20:41:52 +00:00
core: Put Python version as a project requirement so it is considered by ruff (#26608)
Ruff doesn't know about the Python version in `[tool.poetry.dependencies]`. It can get it from `project.requires-python`.

Notes:
* poetry seems to have issues getting the python constraints from `requires-python` and using `python` in per-dependency constraints, so I had to duplicate the info. I will open an issue on poetry.
* `inspect.isclass()` doesn't work correctly with `GenericAlias` (`list[...]`, `dict[..., ...]`) on Python <3.11, so I added some `not isinstance(type, GenericAlias)` checks (see the sketch below):

Python 3.11
```pycon
>>> import inspect
>>> inspect.isclass(list)
True
>>> inspect.isclass(list[str])
False
```

Python 3.9
```pycon
>>> import inspect
>>> inspect.isclass(list)
True
>>> inspect.isclass(list[str])
True
```

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in: parent 0f07cf61da, commit a47b332841
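A minimal sketch of the `GenericAlias` guard described in the commit message above; the helper name `_is_real_class` is hypothetical and only illustrates the pattern, it is not part of the diff below.

```python
import inspect
from types import GenericAlias  # available since Python 3.9


def _is_real_class(obj: object) -> bool:
    # On Python < 3.11, inspect.isclass(list[str]) returns True, so a plain
    # isclass() check mistakes parametrized generics for real classes.
    # Excluding GenericAlias instances restores the intended behavior.
    return inspect.isclass(obj) and not isinstance(obj, GenericAlias)


# Example: _is_real_class(list) -> True, _is_real_class(list[str]) -> False
```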
@@ -14,7 +14,8 @@ import contextlib
 import functools
 import inspect
 import warnings
-from typing import Any, Callable, Generator, Type, TypeVar, Union, cast
+from collections.abc import Generator
+from typing import Any, Callable, TypeVar, Union, cast
 
 from langchain_core._api.internal import is_caller_internal
 
@@ -26,7 +27,7 @@ class LangChainBetaWarning(DeprecationWarning):
 # PUBLIC API
 
 
-T = TypeVar("T", bound=Union[Callable[..., Any], Type])
+T = TypeVar("T", bound=Union[Callable[..., Any], type])
 
 
 def beta(
@@ -14,11 +14,10 @@ import contextlib
 import functools
 import inspect
 import warnings
+from collections.abc import Generator
 from typing import (
 Any,
 Callable,
-Generator,
-Type,
 TypeVar,
 Union,
 cast,
@@ -41,7 +40,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
 
 
 # Last Any should be FieldInfoV1 but this leads to circular imports
-T = TypeVar("T", bound=Union[Type, Callable[..., Any], Any])
+T = TypeVar("T", bound=Union[type, Callable[..., Any], Any])
 
 
 def _validate_deprecation_params(
@@ -262,7 +261,7 @@ def deprecated(
 if not _obj_type:
 _obj_type = "attribute"
 wrapped = None
-_name = _name or cast(Union[Type, Callable], obj.fget).__qualname__
+_name = _name or cast(Union[type, Callable], obj.fget).__qualname__
 old_doc = obj.__doc__
 
 class _deprecated_property(property):
@@ -304,7 +303,7 @@ def deprecated(
 )
 
 else:
-_name = _name or cast(Union[Type, Callable], obj).__qualname__
+_name = _name or cast(Union[type, Callable], obj).__qualname__
 if not _obj_type:
 # edge case: when a function is within another function
 # within a test, this will call it a "method" not a "function"
@@ -25,7 +25,8 @@ The schemas for the agents themselves are defined in langchain.agents.agent.
 from __future__ import annotations
 
 import json
-from typing import Any, Literal, Sequence, Union
+from collections.abc import Sequence
+from typing import Any, Literal, Union
 
 from langchain_core.load.serializable import Serializable
 from langchain_core.messages import (
@@ -1,19 +1,13 @@
 import asyncio
 import threading
 from collections import defaultdict
+from collections.abc import Awaitable, Mapping, Sequence
 from functools import partial
 from itertools import groupby
 from typing import (
 Any,
-Awaitable,
 Callable,
-DefaultDict,
-Dict,
-List,
-Mapping,
 Optional,
-Sequence,
-Type,
 TypeVar,
 Union,
 )
@@ -30,7 +24,7 @@ from langchain_core.runnables.config import RunnableConfig, ensure_config, patch
 from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output
 
 T = TypeVar("T")
-Values = Dict[Union[asyncio.Event, threading.Event], Any]
+Values = dict[Union[asyncio.Event, threading.Event], Any]
 CONTEXT_CONFIG_PREFIX = "__context__/"
 CONTEXT_CONFIG_SUFFIX_GET = "/get"
 CONTEXT_CONFIG_SUFFIX_SET = "/set"
@@ -70,10 +64,10 @@ def _key_from_id(id_: str) -> str:
 
 def _config_with_context(
 config: RunnableConfig,
-steps: List[Runnable],
+steps: list[Runnable],
 setter: Callable,
 getter: Callable,
-event_cls: Union[Type[threading.Event], Type[asyncio.Event]],
+event_cls: Union[type[threading.Event], type[asyncio.Event]],
 ) -> RunnableConfig:
 if any(k.startswith(CONTEXT_CONFIG_PREFIX) for k in config.get("configurable", {})):
 return config
@@ -99,10 +93,10 @@ def _config_with_context(
 }
 
 values: Values = {}
-events: DefaultDict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
+events: defaultdict[str, Union[asyncio.Event, threading.Event]] = defaultdict(
 event_cls
 )
-context_funcs: Dict[str, Callable[[], Any]] = {}
+context_funcs: dict[str, Callable[[], Any]] = {}
 for key, group in grouped_by_key.items():
 getters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_GET)]
 setters = [s for s in group if s[0].id.endswith(CONTEXT_CONFIG_SUFFIX_SET)]
@@ -129,7 +123,7 @@ def _config_with_context(
 
 def aconfig_with_context(
 config: RunnableConfig,
-steps: List[Runnable],
+steps: list[Runnable],
 ) -> RunnableConfig:
 """Asynchronously patch a runnable config with context getters and setters.
 
@@ -145,7 +139,7 @@ def aconfig_with_context(
 
 def config_with_context(
 config: RunnableConfig,
-steps: List[Runnable],
+steps: list[Runnable],
 ) -> RunnableConfig:
 """Patch a runnable config with context getters and setters.
 
@@ -165,13 +159,13 @@ class ContextGet(RunnableSerializable):
 
 prefix: str = ""
 
-key: Union[str, List[str]]
+key: Union[str, list[str]]
 
 def __str__(self) -> str:
 return f"ContextGet({_print_keys(self.key)})"
 
 @property
-def ids(self) -> List[str]:
+def ids(self) -> list[str]:
 prefix = self.prefix + "/" if self.prefix else ""
 keys = self.key if isinstance(self.key, list) else [self.key]
 return [
@@ -180,7 +174,7 @@ class ContextGet(RunnableSerializable):
 ]
 
 @property
-def config_specs(self) -> List[ConfigurableFieldSpec]:
+def config_specs(self) -> list[ConfigurableFieldSpec]:
 return super().config_specs + [
 ConfigurableFieldSpec(
 id=id_,
@@ -256,7 +250,7 @@ class ContextSet(RunnableSerializable):
 return f"ContextSet({_print_keys(list(self.keys.keys()))})"
 
 @property
-def ids(self) -> List[str]:
+def ids(self) -> list[str]:
 prefix = self.prefix + "/" if self.prefix else ""
 return [
 f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
@@ -264,7 +258,7 @@ class ContextSet(RunnableSerializable):
 ]
 
 @property
-def config_specs(self) -> List[ConfigurableFieldSpec]:
+def config_specs(self) -> list[ConfigurableFieldSpec]:
 mapper_config_specs = [
 s
 for mapper in self.keys.values()
@@ -364,7 +358,7 @@ class Context:
 return PrefixContext(prefix=scope)
 
 @staticmethod
-def getter(key: Union[str, List[str]], /) -> ContextGet:
+def getter(key: Union[str, list[str]], /) -> ContextGet:
 return ContextGet(key=key)
 
 @staticmethod
@@ -385,7 +379,7 @@ class PrefixContext:
 def __init__(self, prefix: str = ""):
 self.prefix = prefix
 
-def getter(self, key: Union[str, List[str]], /) -> ContextGet:
+def getter(self, key: Union[str, list[str]], /) -> ContextGet:
 return ContextGet(key=key, prefix=self.prefix)
 
 def setter(
@@ -23,7 +23,8 @@ Cache directly competes with Memory. See documentation for Pros and Cons.
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Any, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Optional
 
 from langchain_core.outputs import Generation
 from langchain_core.runnables import run_in_executor
@@ -3,7 +3,8 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
 from uuid import UUID
 
 from tenacity import RetryCallState
@@ -1070,4 +1071,4 @@ class BaseCallbackManager(CallbackManagerMixin):
 self.inheritable_metadata.pop(key)
 
 
-Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
+Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]
@@ -5,19 +5,15 @@ import functools
 import logging
 import uuid
 from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator, Coroutine, Generator, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import copy_context
 from typing import (
 TYPE_CHECKING,
 Any,
-AsyncGenerator,
 Callable,
-Coroutine,
-Generator,
 Optional,
-Sequence,
-Type,
 TypeVar,
 Union,
 cast,
@@ -2352,7 +2348,7 @@ def _configure(
 and handler_class is not None
 )
 if var.get() is not None or create_one:
-var_handler = var.get() or cast(Type[BaseCallbackHandler], handler_class)()
+var_handler = var.get() or cast(type[BaseCallbackHandler], handler_class)()
 if handler_class is None:
 if not any(
 handler is var_handler # direct pointer comparison
@@ -18,7 +18,8 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Sequence, Union
+from collections.abc import Sequence
+from typing import Union
 
 from pydantic import BaseModel, Field
 
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Iterator, List
+from collections.abc import Iterator
 
 from langchain_core.chat_sessions import ChatSession
 
@@ -15,7 +15,7 @@ class BaseChatLoader(ABC):
 An iterator of chat sessions.
 """
 
-def load(self) -> List[ChatSession]:
+def load(self) -> list[ChatSession]:
 """Eagerly load the chat sessions into memory.
 
 Returns:
@@ -1,6 +1,7 @@
 """**Chat Sessions** are a collection of messages and function calls."""
 
-from typing import Sequence, TypedDict
+from collections.abc import Sequence
+from typing import TypedDict
 
 from langchain_core.messages import BaseMessage
 
@@ -3,7 +3,8 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional
+from collections.abc import AsyncIterator, Iterator
+from typing import TYPE_CHECKING, Optional
 
 from langchain_core.documents import Document
 from langchain_core.runnables import run_in_executor
@@ -8,7 +8,7 @@ In addition, content loading code should provide a lazy loading interface by def
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Iterable
+from collections.abc import Iterable
 
 # Re-export Blob and PathLike for backwards compatibility
 from langchain_core.documents.base import Blob as Blob
@@ -1,7 +1,8 @@
 import datetime
 import json
 import uuid
-from typing import Any, Callable, Iterator, Optional, Sequence, Union
+from collections.abc import Iterator, Sequence
+from typing import Any, Callable, Optional, Union
 
 from langsmith import Client as LangSmithClient
 
@@ -2,9 +2,10 @@ from __future__ import annotations
 
 import contextlib
 import mimetypes
+from collections.abc import Generator
 from io import BufferedReader, BytesIO
 from pathlib import PurePath
-from typing import Any, Generator, Literal, Optional, Union, cast
+from typing import Any, Literal, Optional, Union, cast
 
 from pydantic import ConfigDict, Field, field_validator, model_validator
 
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Optional, Sequence
+from collections.abc import Sequence
+from typing import Optional
 
 from pydantic import BaseModel
 
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Sequence
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any
 
 from langchain_core.runnables.config import run_in_executor
 
@@ -1,7 +1,6 @@
 """**Embeddings** interface."""
 
 from abc import ABC, abstractmethod
-from typing import List
 
 from langchain_core.runnables.config import run_in_executor
 
@@ -35,7 +34,7 @@ class Embeddings(ABC):
 """
 
 @abstractmethod
-def embed_documents(self, texts: List[str]) -> List[List[float]]:
+def embed_documents(self, texts: list[str]) -> list[list[float]]:
 """Embed search docs.
 
 Args:
@@ -46,7 +45,7 @@ class Embeddings(ABC):
 """
 
 @abstractmethod
-def embed_query(self, text: str) -> List[float]:
+def embed_query(self, text: str) -> list[float]:
 """Embed query text.
 
 Args:
@@ -56,7 +55,7 @@ class Embeddings(ABC):
 Embedding.
 """
 
-async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
 """Asynchronous Embed search docs.
 
 Args:
@@ -67,7 +66,7 @@ class Embeddings(ABC):
 """
 return await run_in_executor(None, self.embed_documents, texts)
 
-async def aembed_query(self, text: str) -> List[float]:
+async def aembed_query(self, text: str) -> list[float]:
 """Asynchronous Embed query text.
 
 Args:
@@ -2,7 +2,6 @@
 
 # Please do not add additional fake embedding model implementations here.
 import hashlib
-from typing import List
 
 from pydantic import BaseModel
 
@@ -51,15 +50,15 @@ class FakeEmbeddings(Embeddings, BaseModel):
 size: int
 """The size of the embedding vector."""
 
-def _get_embedding(self) -> List[float]:
+def _get_embedding(self) -> list[float]:
 import numpy as np # type: ignore[import-not-found, import-untyped]
 
 return list(np.random.normal(size=self.size))
 
-def embed_documents(self, texts: List[str]) -> List[List[float]]:
+def embed_documents(self, texts: list[str]) -> list[list[float]]:
 return [self._get_embedding() for _ in texts]
 
-def embed_query(self, text: str) -> List[float]:
+def embed_query(self, text: str) -> list[float]:
 return self._get_embedding()
 
 
@@ -106,7 +105,7 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
 size: int
 """The size of the embedding vector."""
 
-def _get_embedding(self, seed: int) -> List[float]:
+def _get_embedding(self, seed: int) -> list[float]:
 import numpy as np # type: ignore[import-not-found, import-untyped]
 
 # set the seed for the random generator
@@ -117,8 +116,8 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
 """Get a seed for the random generator, using the hash of the text."""
 return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8
 
-def embed_documents(self, texts: List[str]) -> List[List[float]]:
+def embed_documents(self, texts: list[str]) -> list[list[float]]:
 return [self._get_embedding(seed=self._get_seed(_)) for _ in texts]
 
-def embed_query(self, text: str) -> List[float]:
+def embed_query(self, text: str) -> list[float]:
 return self._get_embedding(seed=self._get_seed(text))
@@ -1,7 +1,7 @@
 """Interface for selecting examples to include in prompts."""
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any
 
 from langchain_core.runnables import run_in_executor
 
@@ -10,14 +10,14 @@ class BaseExampleSelector(ABC):
 """Interface for selecting examples to include in prompts."""
 
 @abstractmethod
-def add_example(self, example: Dict[str, str]) -> Any:
+def add_example(self, example: dict[str, str]) -> Any:
 """Add new example to store.
 
 Args:
 example: A dictionary with keys as input variables
 and values as their values."""
 
-async def aadd_example(self, example: Dict[str, str]) -> Any:
+async def aadd_example(self, example: dict[str, str]) -> Any:
 """Async add new example to store.
 
 Args:
@@ -27,14 +27,14 @@ class BaseExampleSelector(ABC):
 return await run_in_executor(None, self.add_example, example)
 
 @abstractmethod
-def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
 """Select which examples to use based on the inputs.
 
 Args:
 input_variables: A dictionary with keys as input variables
 and values as their values."""
 
-async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
 """Async select which examples to use based on the inputs.
 
 Args:
@@ -1,7 +1,7 @@
 """Select examples based on length."""
 
 import re
-from typing import Callable, Dict, List
+from typing import Callable
 
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Self
@@ -17,7 +17,7 @@ def _get_length_based(text: str) -> int:
 class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
 """Select examples based on length."""
 
-examples: List[dict]
+examples: list[dict]
 """A list of the examples that the prompt template expects."""
 
 example_prompt: PromptTemplate
@@ -29,10 +29,10 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
 max_length: int = 2048
 """Max length for the prompt, beyond which examples are cut."""
 
-example_text_lengths: List[int] = Field(default_factory=list) # :meta private:
+example_text_lengths: list[int] = Field(default_factory=list) # :meta private:
 """Length of each example."""
 
-def add_example(self, example: Dict[str, str]) -> None:
+def add_example(self, example: dict[str, str]) -> None:
 """Add new example to list.
 
 Args:
@@ -43,7 +43,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
 string_example = self.example_prompt.format(**example)
 self.example_text_lengths.append(self.get_text_length(string_example))
 
-async def aadd_example(self, example: Dict[str, str]) -> None:
+async def aadd_example(self, example: dict[str, str]) -> None:
 """Async add new example to list.
 
 Args:
@@ -62,7 +62,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
 self.example_text_lengths = [self.get_text_length(eg) for eg in string_examples]
 return self
 
-def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
 """Select which examples to use based on the input lengths.
 
 Args:
@@ -86,7 +86,7 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
 i += 1
 return examples
 
-async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
 """Async select which examples to use based on the input lengths.
 
 Args:
@@ -1,13 +1,10 @@
 from __future__ import annotations
 
 from abc import abstractmethod
+from collections.abc import AsyncIterable, Collection, Iterable, Iterator
 from typing import (
 Any,
-AsyncIterable,
 ClassVar,
-Collection,
-Iterable,
-Iterator,
 Optional,
 )
 
@@ -1,5 +1,6 @@
+from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Iterable, List, Literal, Union
+from typing import Literal, Union
 
 from langchain_core._api import beta
 from langchain_core.documents import Document
@@ -41,7 +42,7 @@ METADATA_LINKS_KEY = "links"
 
 
 @beta()
-def get_links(doc: Document) -> List[Link]:
+def get_links(doc: Document) -> list[Link]:
 """Get the links from a document.
 
 Args:
@@ -5,17 +5,13 @@ from __future__ import annotations
 import hashlib
 import json
 import uuid
+from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator, Sequence
 from itertools import islice
 from typing import (
 Any,
-AsyncIterable,
-AsyncIterator,
 Callable,
-Iterable,
-Iterator,
 Literal,
 Optional,
-Sequence,
 TypedDict,
 TypeVar,
 Union,
@@ -3,7 +3,8 @@ from __future__ import annotations
 import abc
 import time
 from abc import ABC, abstractmethod
-from typing import Any, Optional, Sequence, TypedDict
+from collections.abc import Sequence
+from typing import Any, Optional, TypedDict
 
 from langchain_core._api import beta
 from langchain_core.documents import Document
@@ -1,5 +1,6 @@
 import uuid
-from typing import Any, Dict, List, Optional, Sequence, cast
+from collections.abc import Sequence
+from typing import Any, Optional, cast
 
 from pydantic import Field
 
@@ -22,7 +23,7 @@ class InMemoryDocumentIndex(DocumentIndex):
 .. versionadded:: 0.2.29
 """
 
-store: Dict[str, Document] = Field(default_factory=dict)
+store: dict[str, Document] = Field(default_factory=dict)
 top_k: int = 4
 
 def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
@@ -43,7 +44,7 @@ class InMemoryDocumentIndex(DocumentIndex):
 
 return UpsertResponse(succeeded=ok_ids, failed=[])
 
-def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> DeleteResponse:
+def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse:
 """Delete by ID."""
 if ids is None:
 raise ValueError("IDs must be provided for deletion")
@@ -59,7 +60,7 @@ class InMemoryDocumentIndex(DocumentIndex):
 succeeded=ok_ids, num_deleted=len(ok_ids), num_failed=0, failed=[]
 )
 
-def get(self, ids: Sequence[str], /, **kwargs: Any) -> List[Document]:
+def get(self, ids: Sequence[str], /, **kwargs: Any) -> list[Document]:
 """Get by ids."""
 found_documents = []
 
@@ -71,7 +72,7 @@ class InMemoryDocumentIndex(DocumentIndex):
 
 def _get_relevant_documents(
 self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-) -> List[Document]:
+) -> list[Document]:
 counts_by_doc = []
 
 for document in self.store.values():
@@ -1,16 +1,14 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from functools import lru_cache
+from collections.abc import Mapping, Sequence
+from functools import cache
 from typing import (
 TYPE_CHECKING,
 Any,
 Callable,
-List,
 Literal,
-Mapping,
 Optional,
-Sequence,
 TypeVar,
 Union,
 )
@@ -52,7 +50,7 @@ class LangSmithParams(TypedDict, total=False):
 """Stop words for generation."""
 
 
-@lru_cache(maxsize=None) # Cache the tokenizer
+@cache # Cache the tokenizer
 def get_tokenizer() -> Any:
 """Get a GPT-2 tokenizer instance.
 
@@ -158,7 +156,7 @@ class BaseLanguageModel(
 return Union[
 str,
 Union[StringPromptValue, ChatPromptValueConcrete],
-List[AnyMessage],
+list[AnyMessage],
 ]
 
 @abstractmethod
@@ -3,21 +3,19 @@ from __future__ import annotations
 import asyncio
 import inspect
 import json
+import typing
 import uuid
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator, Iterator, Sequence
 from functools import cached_property
 from operator import itemgetter
 from typing import (
 TYPE_CHECKING,
 Any,
-AsyncIterator,
 Callable,
-Dict,
-Iterator,
 Literal,
 Optional,
-Sequence,
 Union,
 cast,
 )
@@ -1121,18 +1119,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
 
 def bind_tools(
 self,
-tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006
+tools: Sequence[Union[typing.Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006
 **kwargs: Any,
 ) -> Runnable[LanguageModelInput, BaseMessage]:
 raise NotImplementedError()
 
 def with_structured_output(
 self,
-schema: Union[Dict, type], # noqa: UP006
+schema: Union[typing.Dict, type], # noqa: UP006
 *,
 include_raw: bool = False,
 **kwargs: Any,
-) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: # noqa: UP006
+) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]: # noqa: UP006
 """Model wrapper that returns outputs formatted to match the given schema.
 
 Args:
@@ -1,6 +1,7 @@
 import asyncio
 import time
-from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional
+from collections.abc import AsyncIterator, Iterator, Mapping
+from typing import Any, Optional
 
 from langchain_core.callbacks import (
 AsyncCallbackManagerForLLMRun,
@@ -14,7 +15,7 @@ from langchain_core.runnables import RunnableConfig
 class FakeListLLM(LLM):
 """Fake LLM for testing purposes."""
 
-responses: List[str]
+responses: list[str]
 """List of responses to return in order."""
 # This parameter should be removed from FakeListLLM since
 # it's only used by sub-classes.
@@ -37,7 +38,7 @@ class FakeListLLM(LLM):
 def _call(
 self,
 prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> str:
@@ -52,7 +53,7 @@ class FakeListLLM(LLM):
 async def _acall(
 self,
 prompt: str,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> str:
@@ -90,7 +91,7 @@ class FakeStreamingListLLM(FakeListLLM):
 input: LanguageModelInput,
 config: Optional[RunnableConfig] = None,
 *,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
 **kwargs: Any,
 ) -> Iterator[str]:
 result = self.invoke(input, config)
@@ -110,7 +111,7 @@ class FakeStreamingListLLM(FakeListLLM):
 input: LanguageModelInput,
 config: Optional[RunnableConfig] = None,
 *,
-stop: Optional[List[str]] = None,
+stop: Optional[list[str]] = None,
 **kwargs: Any,
 ) -> AsyncIterator[str]:
 result = await self.ainvoke(input, config)
@@ -3,7 +3,8 @@
 import asyncio
 import re
 import time
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast
+from collections.abc import AsyncIterator, Iterator
+from typing import Any, Optional, Union, cast
 
 from langchain_core.callbacks import (
 AsyncCallbackManagerForLLMRun,
@@ -17,7 +18,7 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResu
 class FakeMessagesListChatModel(BaseChatModel):
 """Fake ChatModel for testing purposes."""
 
-responses: List[BaseMessage]
+responses: list[BaseMessage]
 """List of responses to **cycle** through in order."""
 sleep: Optional[float] = None
 """Sleep time in seconds between responses."""
@@ -26,8 +27,8 @@ class FakeMessagesListChatModel(BaseChatModel):
 
 def _generate(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> ChatResult:
@@ -51,7 +52,7 @@ class FakeListChatModelError(Exception):
 class FakeListChatModel(SimpleChatModel):
 """Fake ChatModel for testing purposes."""
 
-responses: List[str]
+responses: list[str]
 """List of responses to **cycle** through in order."""
 sleep: Optional[float] = None
 i: int = 0
@@ -65,8 +66,8 @@ class FakeListChatModel(SimpleChatModel):
 
 def _call(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> str:
@@ -80,8 +81,8 @@ class FakeListChatModel(SimpleChatModel):
 
 def _stream(
 self,
-messages: List[BaseMessage],
-stop: Union[List[str], None] = None,
+messages: list[BaseMessage],
+stop: Union[list[str], None] = None,
 run_manager: Union[CallbackManagerForLLMRun, None] = None,
 **kwargs: Any,
 ) -> Iterator[ChatGenerationChunk]:
@@ -103,8 +104,8 @@ class FakeListChatModel(SimpleChatModel):
 
 async def _astream(
 self,
-messages: List[BaseMessage],
-stop: Union[List[str], None] = None,
+messages: list[BaseMessage],
+stop: Union[list[str], None] = None,
 run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
 **kwargs: Any,
 ) -> AsyncIterator[ChatGenerationChunk]:
@@ -124,7 +125,7 @@ class FakeListChatModel(SimpleChatModel):
 yield ChatGenerationChunk(message=AIMessageChunk(content=c))
 
 @property
-def _identifying_params(self) -> Dict[str, Any]:
+def _identifying_params(self) -> dict[str, Any]:
 return {"responses": self.responses}
 
 
@@ -133,8 +134,8 @@ class FakeChatModel(SimpleChatModel):
 
 def _call(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> str:
@@ -142,8 +143,8 @@ class FakeChatModel(SimpleChatModel):
 
 async def _agenerate(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> ChatResult:
@@ -157,7 +158,7 @@ class FakeChatModel(SimpleChatModel):
 return "fake-chat-model"
 
 @property
-def _identifying_params(self) -> Dict[str, Any]:
+def _identifying_params(self) -> dict[str, Any]:
 return {"key": "fake"}
 
 
@@ -186,8 +187,8 @@ class GenericFakeChatModel(BaseChatModel):
 
 def _generate(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> ChatResult:
@@ -202,8 +203,8 @@ class GenericFakeChatModel(BaseChatModel):
 
 def _stream(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> Iterator[ChatGenerationChunk]:
@@ -231,7 +232,7 @@ class GenericFakeChatModel(BaseChatModel):
 # Use a regular expression to split on whitespace with a capture group
 # so that we can preserve the whitespace in the output.
 assert isinstance(content, str)
-content_chunks = cast(List[str], re.split(r"(\s)", content))
+content_chunks = cast(list[str], re.split(r"(\s)", content))
 
 for token in content_chunks:
 chunk = ChatGenerationChunk(
@@ -249,7 +250,7 @@ class GenericFakeChatModel(BaseChatModel):
 for fkey, fvalue in value.items():
 if isinstance(fvalue, str):
 # Break function call by `,`
-fvalue_chunks = cast(List[str], re.split(r"(,)", fvalue))
+fvalue_chunks = cast(list[str], re.split(r"(,)", fvalue))
 for fvalue_chunk in fvalue_chunks:
 chunk = ChatGenerationChunk(
 message=AIMessageChunk(
@@ -306,8 +307,8 @@ class ParrotFakeChatModel(BaseChatModel):
 
 def _generate(
 self,
-messages: List[BaseMessage],
-stop: Optional[List[str]] = None,
+messages: list[BaseMessage],
+stop: Optional[list[str]] = None,
 run_manager: Optional[CallbackManagerForLLMRun] = None,
 **kwargs: Any,
 ) -> ChatResult:
@@ -10,16 +10,12 @@ import logging
 import uuid
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator, Iterator, Sequence
 from pathlib import Path
 from typing import (
 Any,
-AsyncIterator,
 Callable,
-Dict,
-Iterator,
-List,
 Optional,
-Sequence,
 Union,
 cast,
 )
@@ -448,7 +444,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
 return [g[0].text for g in llm_result.generations]
 except Exception as e:
 if return_exceptions:
-return cast(List[str], [e for _ in inputs])
+return cast(list[str], [e for _ in inputs])
 else:
 raise e
 else:
@@ -494,7 +490,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
 return [g[0].text for g in llm_result.generations]
 except Exception as e:
 if return_exceptions:
-return cast(List[str], [e for _ in inputs])
+return cast(list[str], [e for _ in inputs])
 else:
 raise e
 else:
@@ -883,13 +879,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
 assert run_name is None or (
 isinstance(run_name, list) and len(run_name) == len(prompts)
 )
-callbacks = cast(List[Callbacks], callbacks)
-tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
+callbacks = cast(list[Callbacks], callbacks)
+tags_list = cast(list[Optional[list[str]]], tags or ([None] * len(prompts)))
 metadata_list = cast(
-List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
+list[Optional[dict[str, Any]]], metadata or ([{}] * len(prompts))
 )
 run_name_list = run_name or cast(
-List[Optional[str]], ([None] * len(prompts))
+list[Optional[str]], ([None] * len(prompts))
 )
 callback_managers = [
 CallbackManager.configure(
@@ -910,9 +906,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
 cast(Callbacks, callbacks),
 self.callbacks,
 self.verbose,
-cast(List[str], tags),
+cast(list[str], tags),
 self.tags,
-cast(Dict[str, Any], metadata),
+cast(dict[str, Any], metadata),
 self.metadata,
 )
 ] * len(prompts)
@@ -1116,13 +1112,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
 assert run_name is None or (
 isinstance(run_name, list) and len(run_name) == len(prompts)
 )
-callbacks = cast(List[Callbacks], callbacks)
-tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
+callbacks = cast(list[Callbacks], callbacks)
+tags_list = cast(list[Optional[list[str]]], tags or ([None] * len(prompts)))
 metadata_list = cast(
-List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
+list[Optional[dict[str, Any]]], metadata or ([{}] * len(prompts))
 )
 run_name_list = run_name or cast(
-List[Optional[str]], ([None] * len(prompts))
+list[Optional[str]], ([None] * len(prompts))
 )
 callback_managers = [
 AsyncCallbackManager.configure(
@@ -1143,9 +1139,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
 cast(Callbacks, callbacks),
 self.callbacks,
 self.verbose,
-cast(List[str], tags),
+cast(list[str], tags),
 self.tags,
-cast(Dict[str, Any], metadata),
+cast(dict[str, Any], metadata),
 self.metadata,
 )
 ] * len(prompts)
@@ -1,7 +1,7 @@
 import importlib
 import json
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional
 
 from langchain_core._api import beta
 from langchain_core.load.mapping import (
@@ -34,11 +34,11 @@ class Reviver:
 
 def __init__(
 self,
-secrets_map: Optional[Dict[str, str]] = None,
-valid_namespaces: Optional[List[str]] = None,
+secrets_map: Optional[dict[str, str]] = None,
+valid_namespaces: Optional[list[str]] = None,
 secrets_from_env: bool = True,
 additional_import_mappings: Optional[
-Dict[Tuple[str, ...], Tuple[str, ...]]
+dict[tuple[str, ...], tuple[str, ...]]
 ] = None,
 ) -> None:
 """Initialize the reviver.
@@ -73,7 +73,7 @@ class Reviver:
 else ALL_SERIALIZABLE_MAPPINGS
 )
 
-def __call__(self, value: Dict[str, Any]) -> Any:
+def __call__(self, value: dict[str, Any]) -> Any:
 if (
 value.get("lc", None) == 1
 and value.get("type", None) == "secret"
@@ -154,10 +154,10 @@ class Reviver:
 def loads(
 text: str,
 *,
-secrets_map: Optional[Dict[str, str]] = None,
-valid_namespaces: Optional[List[str]] = None,
+secrets_map: Optional[dict[str, str]] = None,
+valid_namespaces: Optional[list[str]] = None,
 secrets_from_env: bool = True,
-additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None,
+additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None,
 ) -> Any:
 """Revive a LangChain class from a JSON string.
 Equivalent to `load(json.loads(text))`.
@@ -190,10 +190,10 @@ def loads(
 def load(
 obj: Any,
 *,
-secrets_map: Optional[Dict[str, str]] = None,
-valid_namespaces: Optional[List[str]] = None,
+secrets_map: Optional[dict[str, str]] = None,
+valid_namespaces: Optional[list[str]] = None,
 secrets_from_env: bool = True,
-additional_import_mappings: Optional[Dict[Tuple[str, ...], Tuple[str, ...]]] = None,
+additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None,
 ) -> Any:
 """Revive a LangChain class from a JSON object. Use this if you already
 have a parsed JSON object, eg. from `json.load` or `orjson.loads`.
@@ -18,11 +18,9 @@ The mapping allows us to deserialize an AIMessage created with an older
 version of LangChain where the code was in a different location.
 """

-from typing import Dict, Tuple
-
 # First value is the value that it is serialized as
 # Second value is the path to load it from
-SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
+SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
     ("langchain", "schema", "messages", "AIMessage"): (
         "langchain_core",
         "messages",
@@ -535,7 +533,7 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {

 # Needed for backwards compatibility for old versions of LangChain where things
 # Were in different place
-_OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
+_OG_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
     ("langchain", "schema", "AIMessage"): (
         "langchain_core",
         "messages",
@@ -583,7 +581,7 @@ _OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {

 # Needed for backwards compatibility for a few versions where we serialized
 # with langchain_core paths.
-OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
+OLD_CORE_NAMESPACES_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
     ("langchain_core", "messages", "ai", "AIMessage"): (
         "langchain_core",
         "messages",
@@ -937,7 +935,7 @@ OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
     ),
 }

-_JS_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
+_JS_SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
    ("langchain_core", "messages", "AIMessage"): (
        "langchain_core",
        "messages",
|
@ -1,8 +1,6 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
@ -25,9 +23,9 @@ class BaseSerialized(TypedDict):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
lc: int
|
lc: int
|
||||||
id: List[str]
|
id: list[str]
|
||||||
name: NotRequired[str]
|
name: NotRequired[str]
|
||||||
graph: NotRequired[Dict[str, Any]]
|
graph: NotRequired[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class SerializedConstructor(BaseSerialized):
|
class SerializedConstructor(BaseSerialized):
|
||||||
@ -39,7 +37,7 @@ class SerializedConstructor(BaseSerialized):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
type: Literal["constructor"]
|
type: Literal["constructor"]
|
||||||
kwargs: Dict[str, Any]
|
kwargs: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class SerializedSecret(BaseSerialized):
|
class SerializedSecret(BaseSerialized):
|
||||||
@ -125,7 +123,7 @@ class Serializable(BaseModel, ABC):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
|
|
||||||
For example, if the class is `langchain.llms.openai.OpenAI`, then the
|
For example, if the class is `langchain.llms.openai.OpenAI`, then the
|
||||||
@ -134,7 +132,7 @@ class Serializable(BaseModel, ABC):
|
|||||||
return cls.__module__.split(".")
|
return cls.__module__.split(".")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lc_secrets(self) -> Dict[str, str]:
|
def lc_secrets(self) -> dict[str, str]:
|
||||||
"""A map of constructor argument names to secret ids.
|
"""A map of constructor argument names to secret ids.
|
||||||
|
|
||||||
For example,
|
For example,
|
||||||
@ -143,7 +141,7 @@ class Serializable(BaseModel, ABC):
|
|||||||
return dict()
|
return dict()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lc_attributes(self) -> Dict:
|
def lc_attributes(self) -> dict:
|
||||||
"""List of attribute names that should be included in the serialized kwargs.
|
"""List of attribute names that should be included in the serialized kwargs.
|
||||||
|
|
||||||
These attributes must be accepted by the constructor.
|
These attributes must be accepted by the constructor.
|
||||||
@ -152,7 +150,7 @@ class Serializable(BaseModel, ABC):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def lc_id(cls) -> List[str]:
|
def lc_id(cls) -> list[str]:
|
||||||
"""A unique identifier for this class for serialization purposes.
|
"""A unique identifier for this class for serialization purposes.
|
||||||
|
|
||||||
The unique identifier is a list of strings that describes the path
|
The unique identifier is a list of strings that describes the path
|
||||||
@ -315,8 +313,8 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _replace_secrets(
|
def _replace_secrets(
|
||||||
root: Dict[Any, Any], secrets_map: Dict[str, str]
|
root: dict[Any, Any], secrets_map: dict[str, str]
|
||||||
) -> Dict[Any, Any]:
|
) -> dict[Any, Any]:
|
||||||
result = root.copy()
|
result = root.copy()
|
||||||
for path, secret_id in secrets_map.items():
|
for path, secret_id in secrets_map.items():
|
||||||
[*parts, last] = path.split(".")
|
[*parts, last] = path.split(".")
|
||||||
@ -344,7 +342,7 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
|
|||||||
Returns:
|
Returns:
|
||||||
SerializedNotImplemented
|
SerializedNotImplemented
|
||||||
"""
|
"""
|
||||||
_id: List[str] = []
|
_id: list[str] = []
|
||||||
try:
|
try:
|
||||||
if hasattr(obj, "__name__"):
|
if hasattr(obj, "__name__"):
|
||||||
_id = [*obj.__module__.split("."), obj.__name__]
|
_id = [*obj.__module__.split("."), obj.__name__]
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List, Literal, Optional, Union
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import model_validator
|
from pydantic import model_validator
|
||||||
from typing_extensions import Self, TypedDict
|
from typing_extensions import Self, TypedDict
|
||||||
@ -69,9 +69,9 @@ class AIMessage(BaseMessage):
|
|||||||
At the moment, this is ignored by most models. Usage is discouraged.
|
At the moment, this is ignored by most models. Usage is discouraged.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tool_calls: List[ToolCall] = []
|
tool_calls: list[ToolCall] = []
|
||||||
"""If provided, tool calls associated with the message."""
|
"""If provided, tool calls associated with the message."""
|
||||||
invalid_tool_calls: List[InvalidToolCall] = []
|
invalid_tool_calls: list[InvalidToolCall] = []
|
||||||
"""If provided, tool calls with parsing errors associated with the message."""
|
"""If provided, tool calls with parsing errors associated with the message."""
|
||||||
usage_metadata: Optional[UsageMetadata] = None
|
usage_metadata: Optional[UsageMetadata] = None
|
||||||
"""If provided, usage metadata for a message, such as token counts.
|
"""If provided, usage metadata for a message, such as token counts.
|
||||||
@ -83,7 +83,7 @@ class AIMessage(BaseMessage):
|
|||||||
"""The type of the message (used for deserialization). Defaults to "ai"."""
|
"""The type of the message (used for deserialization). Defaults to "ai"."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Pass in content as positional arg.
|
"""Pass in content as positional arg.
|
||||||
|
|
||||||
@ -94,7 +94,7 @@ class AIMessage(BaseMessage):
|
|||||||
super().__init__(content=content, **kwargs)
|
super().__init__(content=content, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -104,7 +104,7 @@ class AIMessage(BaseMessage):
|
|||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lc_attributes(self) -> Dict:
|
def lc_attributes(self) -> dict:
|
||||||
"""Attrs to be serialized even if they are derived from other init args."""
|
"""Attrs to be serialized even if they are derived from other init args."""
|
||||||
return {
|
return {
|
||||||
"tool_calls": self.tool_calls,
|
"tool_calls": self.tool_calls,
|
||||||
@ -137,7 +137,7 @@ class AIMessage(BaseMessage):
|
|||||||
|
|
||||||
# Ensure "type" is properly set on all tool call-like dicts.
|
# Ensure "type" is properly set on all tool call-like dicts.
|
||||||
if tool_calls := values.get("tool_calls"):
|
if tool_calls := values.get("tool_calls"):
|
||||||
updated: List = []
|
updated: list = []
|
||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
updated.append(
|
updated.append(
|
||||||
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
|
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
|
||||||
@ -178,7 +178,7 @@ class AIMessage(BaseMessage):
|
|||||||
base = super().pretty_repr(html=html)
|
base = super().pretty_repr(html=html)
|
||||||
lines = []
|
lines = []
|
||||||
|
|
||||||
def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> List[str]:
|
def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> list[str]:
|
||||||
lines = [
|
lines = [
|
||||||
f" {tc.get('name', 'Tool')} ({tc.get('id')})",
|
f" {tc.get('name', 'Tool')} ({tc.get('id')})",
|
||||||
f" Call ID: {tc.get('id')}",
|
f" Call ID: {tc.get('id')}",
|
||||||
@ -218,11 +218,11 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
|||||||
"""The type of the message (used for deserialization).
|
"""The type of the message (used for deserialization).
|
||||||
Defaults to "AIMessageChunk"."""
|
Defaults to "AIMessageChunk"."""
|
||||||
|
|
||||||
tool_call_chunks: List[ToolCallChunk] = []
|
tool_call_chunks: list[ToolCallChunk] = []
|
||||||
"""If provided, tool call chunks associated with the message."""
|
"""If provided, tool call chunks associated with the message."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -232,7 +232,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
|||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lc_attributes(self) -> Dict:
|
def lc_attributes(self) -> dict:
|
||||||
"""Attrs to be serialized even if they are derived from other init args."""
|
"""Attrs to be serialized even if they are derived from other init args."""
|
||||||
return {
|
return {
|
||||||
"tool_calls": self.tool_calls,
|
"tool_calls": self.tool_calls,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast
|
from collections.abc import Sequence
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||||
|
|
||||||
from pydantic import ConfigDict, Field, field_validator
|
from pydantic import ConfigDict, Field, field_validator
|
||||||
|
|
||||||
@ -143,7 +144,7 @@ def merge_content(
|
|||||||
merged = [merged] + content # type: ignore
|
merged = [merged] + content # type: ignore
|
||||||
elif isinstance(content, list):
|
elif isinstance(content, list):
|
||||||
# If both are lists
|
# If both are lists
|
||||||
merged = merge_lists(cast(List, merged), content) # type: ignore
|
merged = merge_lists(cast(list, merged), content) # type: ignore
|
||||||
# If the first content is a list, and the second content is a string
|
# If the first content is a list, and the second content is a string
|
||||||
else:
|
else:
|
||||||
# If the last element of the first content is a string
|
# If the last element of the first content is a string
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from langchain_core.messages.base import (
|
from langchain_core.messages.base import (
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
@ -18,7 +18,7 @@ class ChatMessage(BaseMessage):
|
|||||||
"""The type of the message (used during serialization). Defaults to "chat"."""
|
"""The type of the message (used during serialization). Defaults to "chat"."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
Default is ["langchain", "schema", "messages"].
|
Default is ["langchain", "schema", "messages"].
|
||||||
"""
|
"""
|
||||||
@ -39,7 +39,7 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
|||||||
Defaults to "ChatMessageChunk"."""
|
Defaults to "ChatMessageChunk"."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
Default is ["langchain", "schema", "messages"].
|
Default is ["langchain", "schema", "messages"].
|
||||||
"""
|
"""
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from langchain_core.messages.base import (
|
from langchain_core.messages.base import (
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
@ -26,7 +26,7 @@ class FunctionMessage(BaseMessage):
|
|||||||
"""The type of the message (used for serialization). Defaults to "function"."""
|
"""The type of the message (used for serialization). Defaults to "function"."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
Default is ["langchain", "schema", "messages"]."""
|
Default is ["langchain", "schema", "messages"]."""
|
||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
@ -46,7 +46,7 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
|||||||
Defaults to "FunctionMessageChunk"."""
|
Defaults to "FunctionMessageChunk"."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
Default is ["langchain", "schema", "messages"]."""
|
Default is ["langchain", "schema", "messages"]."""
|
||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Literal, Union
|
from typing import Any, Literal, Union
|
||||||
|
|
||||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||||
|
|
||||||
@ -39,13 +39,13 @@ class HumanMessage(BaseMessage):
|
|||||||
"""The type of the message (used for serialization). Defaults to "human"."""
|
"""The type of the message (used for serialization). Defaults to "human"."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
Default is ["langchain", "schema", "messages"]."""
|
Default is ["langchain", "schema", "messages"]."""
|
||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Pass in content as positional arg.
|
"""Pass in content as positional arg.
|
||||||
|
|
||||||
@ -70,7 +70,7 @@ class HumanMessageChunk(HumanMessage, BaseMessageChunk):
|
|||||||
Defaults to "HumanMessageChunk"."""
|
Defaults to "HumanMessageChunk"."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
Default is ["langchain", "schema", "messages"]."""
|
Default is ["langchain", "schema", "messages"]."""
|
||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from langchain_core.messages.base import BaseMessage
|
from langchain_core.messages.base import BaseMessage
|
||||||
|
|
||||||
@ -25,7 +25,7 @@ class RemoveMessage(BaseMessage):
|
|||||||
return super().__init__("", id=id, **kwargs)
|
return super().__init__("", id=id, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
Default is ["langchain", "schema", "messages"]."""
|
Default is ["langchain", "schema", "messages"]."""
|
||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Literal, Union
|
from typing import Any, Literal, Union
|
||||||
|
|
||||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
|
||||||
|
|
||||||
@ -33,13 +33,13 @@ class SystemMessage(BaseMessage):
|
|||||||
"""The type of the message (used for serialization). Defaults to "system"."""
|
"""The type of the message (used for serialization). Defaults to "system"."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
Default is ["langchain", "schema", "messages"]."""
|
Default is ["langchain", "schema", "messages"]."""
|
||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Pass in content as positional arg.
|
"""Pass in content as positional arg.
|
||||||
|
|
||||||
@ -64,7 +64,7 @@ class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
|||||||
Defaults to "SystemMessageChunk"."""
|
Defaults to "SystemMessageChunk"."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
Default is ["langchain", "schema", "messages"]."""
|
Default is ["langchain", "schema", "messages"]."""
|
||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
from typing import Any, Literal, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import Field, model_validator
|
from pydantic import Field, model_validator
|
||||||
@ -78,7 +78,7 @@ class ToolMessage(BaseMessage):
|
|||||||
"""Currently inherited from BaseMessage, but not used."""
|
"""Currently inherited from BaseMessage, but not used."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
Default is ["langchain", "schema", "messages"]."""
|
Default is ["langchain", "schema", "messages"]."""
|
||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
@ -123,7 +123,7 @@ class ToolMessage(BaseMessage):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any
|
self, content: Union[str, list[Union[str, dict]]], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(content=content, **kwargs)
|
super().__init__(content=content, **kwargs)
|
||||||
|
|
||||||
@ -140,7 +140,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
|||||||
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
|
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
|
|
||||||
@ -187,7 +187,7 @@ class ToolCall(TypedDict):
|
|||||||
|
|
||||||
name: str
|
name: str
|
||||||
"""The name of the tool to be called."""
|
"""The name of the tool to be called."""
|
||||||
args: Dict[str, Any]
|
args: dict[str, Any]
|
||||||
"""The arguments to the tool call."""
|
"""The arguments to the tool call."""
|
||||||
id: Optional[str]
|
id: Optional[str]
|
||||||
"""An identifier associated with the tool call.
|
"""An identifier associated with the tool call.
|
||||||
@ -198,7 +198,7 @@ class ToolCall(TypedDict):
|
|||||||
type: NotRequired[Literal["tool_call"]]
|
type: NotRequired[Literal["tool_call"]]
|
||||||
|
|
||||||
|
|
||||||
def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall:
|
def tool_call(*, name: str, args: dict[str, Any], id: Optional[str]) -> ToolCall:
|
||||||
return ToolCall(name=name, args=args, id=id, type="tool_call")
|
return ToolCall(name=name, args=args, id=id, type="tool_call")
|
||||||
|
|
||||||
|
|
||||||
@ -276,8 +276,8 @@ def invalid_tool_call(
|
|||||||
|
|
||||||
|
|
||||||
def default_tool_parser(
|
def default_tool_parser(
|
||||||
raw_tool_calls: List[dict],
|
raw_tool_calls: list[dict],
|
||||||
) -> Tuple[List[ToolCall], List[InvalidToolCall]]:
|
) -> tuple[list[ToolCall], list[InvalidToolCall]]:
|
||||||
"""Best-effort parsing of tools."""
|
"""Best-effort parsing of tools."""
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
invalid_tool_calls = []
|
invalid_tool_calls = []
|
||||||
@ -306,7 +306,7 @@ def default_tool_parser(
|
|||||||
return tool_calls, invalid_tool_calls
|
return tool_calls, invalid_tool_calls
|
||||||
|
|
||||||
|
|
||||||
def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]:
|
def default_tool_chunk_parser(raw_tool_calls: list[dict]) -> list[ToolCallChunk]:
|
||||||
"""Best-effort parsing of tool chunks."""
|
"""Best-effort parsing of tool chunks."""
|
||||||
tool_call_chunks = []
|
tool_call_chunks = []
|
||||||
for tool_call in raw_tool_calls:
|
for tool_call in raw_tool_calls:
|
||||||
|
@@ -11,25 +11,21 @@ from __future__ import annotations

 import inspect
 import json
+from collections.abc import Iterable, Sequence
 from functools import partial
 from typing import (
     TYPE_CHECKING,
+    Annotated,
     Any,
     Callable,
-    Dict,
-    Iterable,
-    List,
     Literal,
     Optional,
-    Sequence,
-    Tuple,
     Union,
     cast,
     overload,
 )

 from pydantic import Discriminator, Field, Tag
-from typing_extensions import Annotated

 from langchain_core.messages.ai import AIMessage, AIMessageChunk
 from langchain_core.messages.base import BaseMessage, BaseMessageChunk
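The same import cleanup recurs throughout the commit: the container ABCs (`Sequence`, `Iterable`, and friends) now come from `collections.abc`, whose classes accept subscripts since Python 3.9, and `Annotated` is imported from `typing` directly instead of `typing_extensions`. A small, self-contained sketch of the new import style (the function name is made up for the illustration):

```pycon
>>> from collections.abc import Sequence   # replaces the deprecated typing.Sequence alias
>>> from typing import Annotated           # in the stdlib typing module since 3.9
>>> def head(xs: Annotated[Sequence[int], "non-empty"]) -> int:
...     return xs[0]
...
>>> head([3, 2, 1])
3
```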
@ -198,7 +194,7 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
|
|||||||
|
|
||||||
|
|
||||||
MessageLikeRepresentation = Union[
|
MessageLikeRepresentation = Union[
|
||||||
BaseMessage, List[str], Tuple[str, str], str, Dict[str, Any]
|
BaseMessage, list[str], tuple[str, str], str, dict[str, Any]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,12 +2,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Any, Optional, TypeVar, Union
|
from typing import Annotated, Any, Optional, TypeVar, Union
|
||||||
|
|
||||||
import jsonpatch # type: ignore[import]
|
import jsonpatch # type: ignore[import]
|
||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import SkipValidation
|
from pydantic import SkipValidation
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
||||||
|
@ -3,8 +3,9 @@ from __future__ import annotations
|
|||||||
import re
|
import re
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import AsyncIterator, Iterator, List, TypeVar, Union
|
from collections.abc import AsyncIterator, Iterator
|
||||||
from typing import Optional as Optional
|
from typing import Optional as Optional
|
||||||
|
from typing import TypeVar, Union
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||||
@ -29,7 +30,7 @@ def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
|
|||||||
yield buffer.popleft()
|
yield buffer.popleft()
|
||||||
|
|
||||||
|
|
||||||
class ListOutputParser(BaseTransformOutputParser[List[str]]):
|
class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
||||||
"""Parse the output of an LLM call to a list."""
|
"""Parse the output of an LLM call to a list."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List, Optional, Type, Union
|
from types import GenericAlias
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import jsonpatch # type: ignore[import]
|
import jsonpatch # type: ignore[import]
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
@ -20,7 +21,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
|||||||
args_only: bool = True
|
args_only: bool = True
|
||||||
"""Whether to only return the arguments to the function call."""
|
"""Whether to only return the arguments to the function call."""
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -72,7 +73,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
||||||
return jsonpatch.make_patch(prev, next).patch
|
return jsonpatch.make_patch(prev, next).patch
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -166,7 +167,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
|||||||
key_name: str
|
key_name: str
|
||||||
"""The name of the key to return."""
|
"""The name of the key to return."""
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -223,7 +224,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
|||||||
result = parser.parse_result([chat_generation])
|
result = parser.parse_result([chat_generation])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
|
pydantic_schema: Union[type[BaseModel], dict[str, type[BaseModel]]]
|
||||||
"""The pydantic schema to parse the output with.
|
"""The pydantic schema to parse the output with.
|
||||||
|
|
||||||
If multiple schemas are provided, then the function name will be used to
|
If multiple schemas are provided, then the function name will be used to
|
||||||
@ -232,7 +233,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
|||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_schema(cls, values: Dict) -> Any:
|
def validate_schema(cls, values: dict) -> Any:
|
||||||
"""Validate the pydantic schema.
|
"""Validate the pydantic schema.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -246,17 +247,19 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
         """
         schema = values["pydantic_schema"]
         if "args_only" not in values:
-            values["args_only"] = isinstance(schema, type) and issubclass(
-                schema, BaseModel
+            values["args_only"] = (
+                isinstance(schema, type)
+                and not isinstance(schema, GenericAlias)
+                and issubclass(schema, BaseModel)
             )
-        elif values["args_only"] and isinstance(schema, Dict):
+        elif values["args_only"] and isinstance(schema, dict):
             raise ValueError(
                 "If multiple pydantic schemas are provided then args_only should be"
                 " False."
             )
         return values

-    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+    def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
         """Parse the result of an LLM call to a JSON object.

         Args:
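This hunk is the behavioural one the `GenericAlias` guard addresses: on interpreters before 3.11 a parameterized generic such as `list[str]` still passes `isinstance(..., type)`, so the old code could reach `issubclass(schema, BaseModel)` with something that is not a real class. A rough sketch of what the added check distinguishes (illustration only, not part of the diff):

```pycon
>>> from types import GenericAlias
>>> isinstance(list[str], GenericAlias)   # a parameterized builtin generic
True
>>> isinstance(list, GenericAlias)        # a plain class is not a GenericAlias
False
```

With `not isinstance(schema, GenericAlias)` in the chain, `args_only` defaults to `True` only when `schema` is a genuine `BaseModel` subclass.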
@ -292,7 +295,7 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
|||||||
attr_name: str
|
attr_name: str
|
||||||
"""The name of the attribute to return."""
|
"""The name of the attribute to return."""
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||||
"""Parse the result of an LLM call to a JSON object.
|
"""Parse the result of an LLM call to a JSON object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Annotated, Any, Optional
|
||||||
|
|
||||||
from pydantic import SkipValidation, ValidationError
|
from pydantic import SkipValidation, ValidationError
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||||
@ -17,12 +16,12 @@ from langchain_core.utils.pydantic import TypeBaseModel
|
|||||||
|
|
||||||
|
|
||||||
def parse_tool_call(
|
def parse_tool_call(
|
||||||
raw_tool_call: Dict[str, Any],
|
raw_tool_call: dict[str, Any],
|
||||||
*,
|
*,
|
||||||
partial: bool = False,
|
partial: bool = False,
|
||||||
strict: bool = False,
|
strict: bool = False,
|
||||||
return_id: bool = True,
|
return_id: bool = True,
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[dict[str, Any]]:
|
||||||
"""Parse a single tool call.
|
"""Parse a single tool call.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -69,7 +68,7 @@ def parse_tool_call(
|
|||||||
|
|
||||||
|
|
||||||
def make_invalid_tool_call(
|
def make_invalid_tool_call(
|
||||||
raw_tool_call: Dict[str, Any],
|
raw_tool_call: dict[str, Any],
|
||||||
error_msg: Optional[str],
|
error_msg: Optional[str],
|
||||||
) -> InvalidToolCall:
|
) -> InvalidToolCall:
|
||||||
"""Create an InvalidToolCall from a raw tool call.
|
"""Create an InvalidToolCall from a raw tool call.
|
||||||
@ -90,12 +89,12 @@ def make_invalid_tool_call(
|
|||||||
|
|
||||||
|
|
||||||
def parse_tool_calls(
|
def parse_tool_calls(
|
||||||
raw_tool_calls: List[dict],
|
raw_tool_calls: list[dict],
|
||||||
*,
|
*,
|
||||||
partial: bool = False,
|
partial: bool = False,
|
||||||
strict: bool = False,
|
strict: bool = False,
|
||||||
return_id: bool = True,
|
return_id: bool = True,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Parse a list of tool calls.
|
"""Parse a list of tool calls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -111,7 +110,7 @@ def parse_tool_calls(
|
|||||||
Raises:
|
Raises:
|
||||||
OutputParserException: If any of the tool calls are not valid JSON.
|
OutputParserException: If any of the tool calls are not valid JSON.
|
||||||
"""
|
"""
|
||||||
final_tools: List[Dict[str, Any]] = []
|
final_tools: list[dict[str, Any]] = []
|
||||||
exceptions = []
|
exceptions = []
|
||||||
for tool_call in raw_tool_calls:
|
for tool_call in raw_tool_calls:
|
||||||
try:
|
try:
|
||||||
@ -151,7 +150,7 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
|
|||||||
If no tool calls are found, None will be returned.
|
If no tool calls are found, None will be returned.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||||
"""Parse the result of an LLM call to a list of tool calls.
|
"""Parse the result of an LLM call to a list of tool calls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -217,7 +216,7 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
|||||||
key_name: str
|
key_name: str
|
||||||
"""The type of tools to return."""
|
"""The type of tools to return."""
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||||
"""Parse the result of an LLM call to a list of tool calls.
|
"""Parse the result of an LLM call to a list of tool calls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -254,12 +253,12 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
|||||||
class PydanticToolsParser(JsonOutputToolsParser):
|
class PydanticToolsParser(JsonOutputToolsParser):
|
||||||
"""Parse tools from OpenAI response."""
|
"""Parse tools from OpenAI response."""
|
||||||
|
|
||||||
tools: Annotated[List[TypeBaseModel], SkipValidation()]
|
tools: Annotated[list[TypeBaseModel], SkipValidation()]
|
||||||
"""The tools to parse."""
|
"""The tools to parse."""
|
||||||
|
|
||||||
# TODO: Support more granular streaming of objects. Currently only streams once all
|
# TODO: Support more granular streaming of objects. Currently only streams once all
|
||||||
# Pydantic object fields are present.
|
# Pydantic object fields are present.
|
||||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||||
"""Parse the result of an LLM call to a list of Pydantic objects.
|
"""Parse the result of an LLM call to a list of Pydantic objects.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Generic, List, Optional, Type
|
from typing import Annotated, Generic, Optional
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import SkipValidation
|
from pydantic import SkipValidation
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
@ -18,7 +17,7 @@ from langchain_core.utils.pydantic import (
|
|||||||
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||||
"""Parse an output using a pydantic model."""
|
"""Parse an output using a pydantic model."""
|
||||||
|
|
||||||
pydantic_object: Annotated[Type[TBaseModel], SkipValidation()] # type: ignore
|
pydantic_object: Annotated[type[TBaseModel], SkipValidation()] # type: ignore
|
||||||
"""The pydantic model to parse."""
|
"""The pydantic model to parse."""
|
||||||
|
|
||||||
def _parse_obj(self, obj: dict) -> TBaseModel:
|
def _parse_obj(self, obj: dict) -> TBaseModel:
|
||||||
@ -50,7 +49,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|||||||
return OutputParserException(msg, llm_output=json_string)
|
return OutputParserException(msg, llm_output=json_string)
|
||||||
|
|
||||||
def parse_result(
|
def parse_result(
|
||||||
self, result: List[Generation], *, partial: bool = False
|
self, result: list[Generation], *, partial: bool = False
|
||||||
) -> Optional[TBaseModel]:
|
) -> Optional[TBaseModel]:
|
||||||
"""Parse the result of an LLM call to a pydantic object.
|
"""Parse the result of an LLM call to a pydantic object.
|
||||||
|
|
||||||
@ -108,7 +107,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|||||||
return "pydantic"
|
return "pydantic"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def OutputType(self) -> Type[TBaseModel]:
|
def OutputType(self) -> type[TBaseModel]:
|
||||||
"""Return the pydantic model."""
|
"""Return the pydantic model."""
|
||||||
return self.pydantic_object
|
return self.pydantic_object
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
from typing import List
|
|
||||||
from typing import Optional as Optional
|
from typing import Optional as Optional
|
||||||
|
|
||||||
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
from langchain_core.output_parsers.transform import BaseTransformOutputParser
|
||||||
@ -13,7 +12,7 @@ class StrOutputParser(BaseTransformOutputParser[str]):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "schema", "output_parser"]
|
return ["langchain", "schema", "output_parser"]
|
||||||
|
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator, Iterator
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
|
||||||
Iterator,
|
|
||||||
Optional,
|
Optional,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import re
|
import re
|
||||||
import xml
|
import xml
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union
|
from collections.abc import AsyncIterator, Iterator
|
||||||
|
from typing import Any, Literal, Optional, Union
|
||||||
from xml.etree.ElementTree import TreeBuilder
|
from xml.etree.ElementTree import TreeBuilder
|
||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
@ -57,7 +58,7 @@ class _StreamingParser:
|
|||||||
_parser = None
|
_parser = None
|
||||||
self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
|
self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
|
||||||
self.xml_start_re = re.compile(r"<[a-zA-Z:_]")
|
self.xml_start_re = re.compile(r"<[a-zA-Z:_]")
|
||||||
self.current_path: List[str] = []
|
self.current_path: list[str] = []
|
||||||
self.current_path_has_children = False
|
self.current_path_has_children = False
|
||||||
self.buffer = ""
|
self.buffer = ""
|
||||||
self.xml_started = False
|
self.xml_started = False
|
||||||
@ -140,7 +141,7 @@ class _StreamingParser:
|
|||||||
class XMLOutputParser(BaseTransformOutputParser):
|
class XMLOutputParser(BaseTransformOutputParser):
|
||||||
"""Parse an output using xml format."""
|
"""Parse an output using xml format."""
|
||||||
|
|
||||||
tags: Optional[List[str]] = None
|
tags: Optional[list[str]] = None
|
||||||
encoding_matcher: re.Pattern = re.compile(
|
encoding_matcher: re.Pattern = re.compile(
|
||||||
r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
|
r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
|
||||||
)
|
)
|
||||||
@ -169,7 +170,7 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|||||||
"""Return the format instructions for the XML output."""
|
"""Return the format instructions for the XML output."""
|
||||||
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
|
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
|
||||||
|
|
||||||
def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]:
|
def parse(self, text: str) -> dict[str, Union[str, list[Any]]]:
|
||||||
"""Parse the output of an LLM call.
|
"""Parse the output of an LLM call.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -234,13 +235,13 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|||||||
yield output
|
yield output
|
||||||
streaming_parser.close()
|
streaming_parser.close()
|
||||||
|
|
||||||
def _root_to_dict(self, root: ET.Element) -> Dict[str, Union[str, List[Any]]]:
|
def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]:
|
||||||
"""Converts xml tree to python dictionary."""
|
"""Converts xml tree to python dictionary."""
|
||||||
if root.text and bool(re.search(r"\S", root.text)):
|
if root.text and bool(re.search(r"\S", root.text)):
|
||||||
# If root text contains any non-whitespace character it
|
# If root text contains any non-whitespace character it
|
||||||
# returns {root.tag: root.text}
|
# returns {root.tag: root.text}
|
||||||
return {root.tag: root.text}
|
return {root.tag: root.text}
|
||||||
result: Dict = {root.tag: []}
|
result: dict = {root.tag: []}
|
||||||
for child in root:
|
for child in root:
|
||||||
if len(child) == 0:
|
if len(child) == 0:
|
||||||
result[root.tag].append({child.tag: child.text})
|
result[root.tag].append({child.tag: child.text})
|
||||||
@ -253,7 +254,7 @@ class XMLOutputParser(BaseTransformOutputParser):
|
|||||||
return "xml"
|
return "xml"
|
||||||
|
|
||||||
|
|
||||||
def nested_element(path: List[str], elem: ET.Element) -> Any:
|
def nested_element(path: list[str], elem: ET.Element) -> Any:
|
||||||
"""Get nested element from path.
|
"""Get nested element from path.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -18,7 +18,7 @@ class ChatResult(BaseModel):
|
|||||||
for more information.
|
for more information.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
generations: List[ChatGeneration]
|
generations: list[ChatGeneration]
|
||||||
"""List of the chat generations.
|
"""List of the chat generations.
|
||||||
|
|
||||||
Generations is a list to allow for multiple candidate generations for a single
|
Generations is a list to allow for multiple candidate generations for a single
|
||||||
|
@ -7,7 +7,8 @@ They can be used to represent text, images, or chat message pieces.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Literal, Sequence, cast
|
from collections.abc import Sequence
|
||||||
|
from typing import Literal, cast
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
@ -1,16 +1,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import typing
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Mapping
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
|
||||||
Generic,
|
Generic,
|
||||||
Mapping,
|
|
||||||
Optional,
|
Optional,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
@ -39,7 +39,7 @@ FormatOutputType = TypeVar("FormatOutputType")
|
|||||||
|
|
||||||
|
|
||||||
class BasePromptTemplate(
|
class BasePromptTemplate(
|
||||||
RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC
|
RunnableSerializable[dict, PromptValue], Generic[FormatOutputType], ABC
|
||||||
):
|
):
|
||||||
"""Base class for all prompt templates, returning a prompt."""
|
"""Base class for all prompt templates, returning a prompt."""
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ class BasePromptTemplate(
|
|||||||
"""optional_variables: A list of the names of the variables for placeholder
|
"""optional_variables: A list of the names of the variables for placeholder
|
||||||
or MessagePlaceholder that are optional. These variables are auto inferred
|
or MessagePlaceholder that are optional. These variables are auto inferred
|
||||||
from the prompt and user need not provide them."""
|
from the prompt and user need not provide them."""
|
||||||
input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006
|
input_types: typing.Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006
|
||||||
"""A dictionary of the types of the variables the prompt template expects.
|
"""A dictionary of the types of the variables the prompt template expects.
|
||||||
If not provided, all variables are assumed to be strings."""
|
If not provided, all variables are assumed to be strings."""
|
||||||
output_parser: Optional[BaseOutputParser] = None
|
output_parser: Optional[BaseOutputParser] = None
|
||||||
@ -60,7 +60,7 @@ class BasePromptTemplate(
|
|||||||
|
|
||||||
Partial variables populate the template so that you don't need to
|
Partial variables populate the template so that you don't need to
|
||||||
pass them in every time you call the prompt."""
|
pass them in every time you call the prompt."""
|
||||||
metadata: Optional[Dict[str, Any]] = None # noqa: UP006
|
metadata: Optional[typing.Dict[str, Any]] = None # noqa: UP006
|
||||||
"""Metadata to be used for tracing."""
|
"""Metadata to be used for tracing."""
|
||||||
tags: Optional[list[str]] = None
|
tags: Optional[list[str]] = None
|
||||||
"""Tags to be used for tracing."""
|
"""Tags to be used for tracing."""
|
||||||
|
@ -3,15 +3,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Sequence
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
List,
|
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
TypedDict,
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
@ -25,7 +23,6 @@ from pydantic import (
|
|||||||
SkipValidation,
|
SkipValidation,
|
||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.load import Serializable
|
from langchain_core.load import Serializable
|
||||||
@ -816,9 +813,9 @@ MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTempla
|
|||||||
|
|
||||||
MessageLikeRepresentation = Union[
|
MessageLikeRepresentation = Union[
|
||||||
MessageLike,
|
MessageLike,
|
||||||
Tuple[
|
tuple[
|
||||||
Union[str, Type],
|
Union[str, type],
|
||||||
Union[str, List[dict], List[object]],
|
Union[str, list[dict], list[object]],
|
||||||
],
|
],
|
||||||
str,
|
str,
|
||||||
]
|
]
|
||||||
@ -1017,7 +1014,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
),
|
),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
cast(Type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)
|
cast(type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> list[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
@ -1083,7 +1080,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
values["partial_variables"][message.variable_name] = []
|
values["partial_variables"][message.variable_name] = []
|
||||||
optional_variables.add(message.variable_name)
|
optional_variables.add(message.variable_name)
|
||||||
if message.variable_name not in input_types:
|
if message.variable_name not in input_types:
|
||||||
input_types[message.variable_name] = List[AnyMessage]
|
input_types[message.variable_name] = list[AnyMessage]
|
||||||
if "partial_variables" in values:
|
if "partial_variables" in values:
|
||||||
input_vars = input_vars - set(values["partial_variables"])
|
input_vars = input_vars - set(values["partial_variables"])
|
||||||
if optional_variables:
|
if optional_variables:
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Prompt template that contains few shot examples."""
|
"""Prompt template that contains few shot examples."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from pydantic import ConfigDict, model_validator
|
from pydantic import ConfigDict, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
@ -16,7 +16,7 @@ from langchain_core.prompts.string import (
|
|||||||
class FewShotPromptWithTemplates(StringPromptTemplate):
|
class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||||
"""Prompt template that contains few shot examples."""
|
"""Prompt template that contains few shot examples."""
|
||||||
|
|
||||||
examples: Optional[List[dict]] = None
|
examples: Optional[list[dict]] = None
|
||||||
"""Examples to format into the prompt.
|
"""Examples to format into the prompt.
|
||||||
Either this or example_selector should be provided."""
|
Either this or example_selector should be provided."""
|
||||||
|
|
||||||
@ -43,13 +43,13 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
|||||||
"""Whether or not to try validating the template."""
|
"""Whether or not to try validating the template."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "prompts", "few_shot_with_templates"]
|
return ["langchain", "prompts", "few_shot_with_templates"]
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_examples_and_selector(cls, values: Dict) -> Any:
|
def check_examples_and_selector(cls, values: dict) -> Any:
|
||||||
"""Check that one and only one of examples/example_selector are provided."""
|
"""Check that one and only one of examples/example_selector are provided."""
|
||||||
examples = values.get("examples", None)
|
examples = values.get("examples", None)
|
||||||
example_selector = values.get("example_selector", None)
|
example_selector = values.get("example_selector", None)
|
||||||
@ -93,7 +93,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
|||||||
extra="forbid",
|
extra="forbid",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_examples(self, **kwargs: Any) -> List[dict]:
|
def _get_examples(self, **kwargs: Any) -> list[dict]:
|
||||||
if self.examples is not None:
|
if self.examples is not None:
|
||||||
return self.examples
|
return self.examples
|
||||||
elif self.example_selector is not None:
|
elif self.example_selector is not None:
|
||||||
@ -101,7 +101,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
|||||||
else:
|
else:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
async def _aget_examples(self, **kwargs: Any) -> List[dict]:
|
async def _aget_examples(self, **kwargs: Any) -> list[dict]:
|
||||||
if self.examples is not None:
|
if self.examples is not None:
|
||||||
return self.examples
|
return self.examples
|
||||||
elif self.example_selector is not None:
|
elif self.example_selector is not None:
|
||||||
|
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any

 from pydantic import Field

@ -33,7 +33,7 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
|||||||
return "image-prompt"
|
return "image-prompt"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "prompts", "image"]
|
return ["langchain", "prompts", "image"]
|
||||||
|
|
||||||
|
@@ -3,7 +3,7 @@
 import json
 import logging
 from pathlib import Path
-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Optional, Union

 import yaml

@ -181,7 +181,7 @@ def _load_prompt_from_file(
|
|||||||
return load_prompt_from_config(config)
|
return load_prompt_from_config(config)
|
||||||
|
|
||||||
|
|
||||||
def _load_chat_prompt(config: Dict) -> ChatPromptTemplate:
|
def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
|
||||||
"""Load chat prompt from config"""
|
"""Load chat prompt from config"""
|
||||||
|
|
||||||
messages = config.pop("messages")
|
messages = config.pop("messages")
|
||||||
@ -194,7 +194,7 @@ def _load_chat_prompt(config: Dict) -> ChatPromptTemplate:
|
|||||||
return ChatPromptTemplate.from_template(template=template, **config)
|
return ChatPromptTemplate.from_template(template=template, **config)
|
||||||
|
|
||||||
|
|
||||||
type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = {
|
type_to_loader_dict: dict[str, Callable[[dict], BasePromptTemplate]] = {
|
||||||
"prompt": _load_prompt,
|
"prompt": _load_prompt,
|
||||||
"few_shot": _load_few_shot_prompt,
|
"few_shot": _load_few_shot_prompt,
|
||||||
"chat": _load_chat_prompt,
|
"chat": _load_chat_prompt,
|
||||||
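The `type_to_loader_dict` annotation above changes from `Dict[...]` to `dict[str, Callable[[dict], BasePromptTemplate]]`, but the dispatch pattern itself stays the same: the `_type` key of a config selects a loader callable. A simplified, hypothetical sketch of that pattern (the loader names and string outputs here are stand-ins, not the real prompt classes):

```python
from typing import Callable


def _load_plain(config: dict) -> str:
    return config["template"]


def _load_shouty(config: dict) -> str:
    return config["template"].upper()


# Registry keyed by the config's "_type" field, mirroring type_to_loader_dict.
type_to_loader: dict[str, Callable[[dict], str]] = {
    "prompt": _load_plain,
    "shouty": _load_shouty,
}


def load_from_config(config: dict) -> str:
    config = dict(config)  # avoid mutating the caller's dict
    loader = type_to_loader[config.pop("_type", "prompt")]
    return loader(config)


print(load_from_config({"_type": "shouty", "template": "hello"}))  # "HELLO"
```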
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any
|
||||||
from typing import Optional as Optional
|
from typing import Optional as Optional
|
||||||
|
|
||||||
from pydantic import model_validator
|
from pydantic import model_validator
|
||||||
@ -8,7 +8,7 @@ from langchain_core.prompts.base import BasePromptTemplate
|
|||||||
from langchain_core.prompts.chat import BaseChatPromptTemplate
|
from langchain_core.prompts.chat import BaseChatPromptTemplate
|
||||||
|
|
||||||
|
|
||||||
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
|
def _get_inputs(inputs: dict, input_variables: list[str]) -> dict:
|
||||||
return {k: inputs[k] for k in input_variables}
|
return {k: inputs[k] for k in input_variables}
|
||||||
|
|
||||||
|
|
||||||
@ -28,17 +28,17 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
|||||||
|
|
||||||
final_prompt: BasePromptTemplate
|
final_prompt: BasePromptTemplate
|
||||||
"""The final prompt that is returned."""
|
"""The final prompt that is returned."""
|
||||||
pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
|
pipeline_prompts: list[tuple[str, BasePromptTemplate]]
|
||||||
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
|
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "prompts", "pipeline"]
|
return ["langchain", "prompts", "pipeline"]
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_input_variables(cls, values: Dict) -> Any:
|
def get_input_variables(cls, values: dict) -> Any:
|
||||||
"""Get input variables."""
|
"""Get input variables."""
|
||||||
created_variables = set()
|
created_variables = set()
|
||||||
all_variables = set()
|
all_variables = set()
|
||||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import warnings
|
import warnings
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import Any, Callable, Dict
|
from typing import Any, Callable
|
||||||
|
|
||||||
from pydantic import BaseModel, create_model
|
from pydantic import BaseModel, create_model
|
||||||
|
|
||||||
@@ -139,7 +139,7 @@ def mustache_template_vars(
     return vars


-Defs = Dict[str, "Defs"]
+Defs = dict[str, "Defs"]


 def mustache_schema(
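`Defs` stays a self-referential alias after the change; the quoted forward reference means the builtin `dict[...]` form behaves exactly like `typing.Dict[...]` did here. A small illustration of what such nested definitions look like:

```python
# The forward reference "Defs" is not evaluated at runtime, so the alias
# works the same with the builtin dict as it did with typing.Dict.
Defs = dict[str, "Defs"]

# e.g. a template like {{#outer}}{{inner}}{{/outer}} yields nested definitions
defs: Defs = {"outer": {"inner": {}}}
print(defs)  # {'outer': {'inner': {}}}
```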
|
@ -1,13 +1,8 @@
|
|||||||
|
from collections.abc import Iterator, Mapping, Sequence
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
|
||||||
Type,
|
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -32,16 +27,16 @@ from langchain_core.utils import get_pydantic_field_names
|
|||||||
class StructuredPrompt(ChatPromptTemplate):
|
class StructuredPrompt(ChatPromptTemplate):
|
||||||
"""Structured prompt template for a language model."""
|
"""Structured prompt template for a language model."""
|
||||||
|
|
||||||
schema_: Union[Dict, Type[BaseModel]]
|
schema_: Union[dict, type[BaseModel]]
|
||||||
"""Schema for the structured prompt."""
|
"""Schema for the structured prompt."""
|
||||||
structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
structured_output_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[MessageLikeRepresentation],
|
messages: Sequence[MessageLikeRepresentation],
|
||||||
schema_: Optional[Union[Dict, Type[BaseModel]]] = None,
|
schema_: Optional[Union[dict, type[BaseModel]]] = None,
|
||||||
*,
|
*,
|
||||||
structured_output_kwargs: Optional[Dict[str, Any]] = None,
|
structured_output_kwargs: Optional[dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
schema_ = schema_ or kwargs.pop("schema")
|
schema_ = schema_ or kwargs.pop("schema")
|
||||||
@ -56,7 +51,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
|
|
||||||
For example, if the class is `langchain.llms.openai.OpenAI`, then the
|
For example, if the class is `langchain.llms.openai.OpenAI`, then the
|
||||||
@ -68,7 +63,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
|||||||
def from_messages_and_schema(
|
def from_messages_and_schema(
|
||||||
cls,
|
cls,
|
||||||
messages: Sequence[MessageLikeRepresentation],
|
messages: Sequence[MessageLikeRepresentation],
|
||||||
schema: Union[Dict, Type[BaseModel]],
|
schema: Union[dict, type[BaseModel]],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatPromptTemplate:
|
) -> ChatPromptTemplate:
|
||||||
"""Create a chat prompt template from a variety of message formats.
|
"""Create a chat prompt template from a variety of message formats.
|
||||||
@ -118,7 +113,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
|||||||
Callable[[Iterator[Any]], Iterator[Other]],
|
Callable[[Iterator[Any]], Iterator[Other]],
|
||||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||||
],
|
],
|
||||||
) -> RunnableSerializable[Dict, Other]:
|
) -> RunnableSerializable[dict, Other]:
|
||||||
return self.pipe(other)
|
return self.pipe(other)
|
||||||
|
|
||||||
def pipe(
|
def pipe(
|
||||||
@ -130,7 +125,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
|||||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||||
],
|
],
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
) -> RunnableSerializable[Dict, Other]:
|
) -> RunnableSerializable[dict, Other]:
|
||||||
"""Pipe the structured prompt to a language model.
|
"""Pipe the structured prompt to a language model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -24,7 +24,7 @@ from __future__ import annotations
|
|||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import TYPE_CHECKING, Any, List, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
@ -47,7 +47,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
RetrieverInput = str
|
RetrieverInput = str
|
||||||
RetrieverOutput = List[Document]
|
RetrieverOutput = list[Document]
|
||||||
RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
|
RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
|
||||||
RetrieverOutputLike = Runnable[Any, RetrieverOutput]
|
RetrieverOutputLike = Runnable[Any, RetrieverOutput]
|
||||||
|
|
||||||
|
@ -6,36 +6,37 @@ import functools
|
|||||||
import inspect
|
import inspect
|
||||||
import threading
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import (
|
||||||
|
AsyncGenerator,
|
||||||
|
AsyncIterator,
|
||||||
|
Awaitable,
|
||||||
|
Coroutine,
|
||||||
|
Iterator,
|
||||||
|
Mapping,
|
||||||
|
Sequence,
|
||||||
|
)
|
||||||
from concurrent.futures import FIRST_COMPLETED, wait
|
from concurrent.futures import FIRST_COMPLETED, wait
|
||||||
from contextvars import copy_context
|
from contextvars import copy_context
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from itertools import groupby, tee
|
from itertools import groupby, tee
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
|
from types import GenericAlias
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
|
||||||
AsyncIterator,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
Callable,
|
||||||
Coroutine,
|
|
||||||
Dict,
|
|
||||||
Generic,
|
Generic,
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
Optional,
|
||||||
Protocol,
|
Protocol,
|
||||||
Sequence,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
|
get_type_hints,
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel
|
from pydantic import BaseModel, ConfigDict, Field, RootModel
|
||||||
from typing_extensions import Literal, get_args, get_type_hints
|
from typing_extensions import Literal, get_args
|
||||||
|
|
||||||
from langchain_core._api import beta_decorator
|
from langchain_core._api import beta_decorator
|
||||||
from langchain_core.load.serializable import (
|
from langchain_core.load.serializable import (
|
||||||
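The import block above also shows the second half of the migration: the container ABCs (`Iterator`, `Mapping`, `Sequence`, and friends) now come from `collections.abc`, which has been subscriptable since Python 3.9, while `typing` keeps only the names that have no `collections.abc` counterpart. A small self-contained example of the new import style:

```python
from collections.abc import Iterator, Mapping, Sequence


def first_values(rows: Sequence[Mapping[str, int]]) -> Iterator[int]:
    # The ABCs are usable both as annotations and for isinstance checks.
    for row in rows:
        yield next(iter(row.values()))


print(list(first_values([{"a": 1}, {"b": 2}])))  # [1, 2]
```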
@@ -340,7 +341,11 @@ class Runnable(Generic[Input, Output], ABC):
         """
         root_type = self.InputType

-        if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
+        if (
+            inspect.isclass(root_type)
+            and not isinstance(root_type, GenericAlias)
+            and issubclass(root_type, BaseModel)
+        ):
             return root_type

         return create_model_v2(
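The guard added in this hunk (and repeated in the other schema helpers below) exists because `inspect.isclass` reports `True` for parameterized builtins such as `list[str]` on Python 3.9 and 3.10, and `issubclass` then raises `TypeError` on them. A standalone sketch of the check, using a throwaway `Example` model:

```python
import inspect
from types import GenericAlias

from pydantic import BaseModel


class Example(BaseModel):
    x: int


def usable_as_root_model(tp: object) -> bool:
    # On Python < 3.11, inspect.isclass(list[str]) is True and
    # issubclass(list[str], BaseModel) raises TypeError, so the
    # GenericAlias test has to run before issubclass.
    return (
        inspect.isclass(tp)
        and not isinstance(tp, GenericAlias)
        and issubclass(tp, BaseModel)
    )


print(usable_as_root_model(Example))    # True
print(usable_as_root_model(list[str]))  # False, and no TypeError on 3.9/3.10
```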
@@ -408,7 +413,11 @@ class Runnable(Generic[Input, Output], ABC):
         """
         root_type = self.OutputType

-        if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
+        if (
+            inspect.isclass(root_type)
+            and not isinstance(root_type, GenericAlias)
+            and issubclass(root_type, BaseModel)
+        ):
             return root_type

         return create_model_v2(
@ -771,10 +780,10 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
# If there's only one input, don't bother with the executor
|
# If there's only one input, don't bother with the executor
|
||||||
if len(inputs) == 1:
|
if len(inputs) == 1:
|
||||||
return cast(List[Output], [invoke(inputs[0], configs[0])])
|
return cast(list[Output], [invoke(inputs[0], configs[0])])
|
||||||
|
|
||||||
with get_executor_for_config(configs[0]) as executor:
|
with get_executor_for_config(configs[0]) as executor:
|
||||||
return cast(List[Output], list(executor.map(invoke, inputs, configs)))
|
return cast(list[Output], list(executor.map(invoke, inputs, configs)))
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def batch_as_completed(
|
def batch_as_completed(
|
||||||
@ -2024,7 +2033,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
for run_manager in run_managers:
|
for run_manager in run_managers:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
return cast(List[Output], [e for _ in input])
|
return cast(list[Output], [e for _ in input])
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
@ -2036,7 +2045,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
else:
|
else:
|
||||||
run_manager.on_chain_end(out)
|
run_manager.on_chain_end(out)
|
||||||
if return_exceptions or first_exception is None:
|
if return_exceptions or first_exception is None:
|
||||||
return cast(List[Output], output)
|
return cast(list[Output], output)
|
||||||
else:
|
else:
|
||||||
raise first_exception
|
raise first_exception
|
||||||
|
|
||||||
@ -2099,7 +2108,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
*(run_manager.on_chain_error(e) for run_manager in run_managers)
|
*(run_manager.on_chain_error(e) for run_manager in run_managers)
|
||||||
)
|
)
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
return cast(List[Output], [e for _ in input])
|
return cast(list[Output], [e for _ in input])
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
@ -2113,7 +2122,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
coros.append(run_manager.on_chain_end(out))
|
coros.append(run_manager.on_chain_end(out))
|
||||||
await asyncio.gather(*coros)
|
await asyncio.gather(*coros)
|
||||||
if return_exceptions or first_exception is None:
|
if return_exceptions or first_exception is None:
|
||||||
return cast(List[Output], output)
|
return cast(list[Output], output)
|
||||||
else:
|
else:
|
||||||
raise first_exception
|
raise first_exception
|
||||||
|
|
||||||
@ -3171,7 +3180,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
for rm in run_managers:
|
for rm in run_managers:
|
||||||
rm.on_chain_error(e)
|
rm.on_chain_error(e)
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
return cast(List[Output], [e for _ in inputs])
|
return cast(list[Output], [e for _ in inputs])
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
@ -3183,7 +3192,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
else:
|
else:
|
||||||
run_manager.on_chain_end(out)
|
run_manager.on_chain_end(out)
|
||||||
if return_exceptions or first_exception is None:
|
if return_exceptions or first_exception is None:
|
||||||
return cast(List[Output], inputs)
|
return cast(list[Output], inputs)
|
||||||
else:
|
else:
|
||||||
raise first_exception
|
raise first_exception
|
||||||
|
|
||||||
@ -3298,7 +3307,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
return cast(List[Output], [e for _ in inputs])
|
return cast(list[Output], [e for _ in inputs])
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
@ -3312,7 +3321,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
coros.append(run_manager.on_chain_end(out))
|
coros.append(run_manager.on_chain_end(out))
|
||||||
await asyncio.gather(*coros)
|
await asyncio.gather(*coros)
|
||||||
if return_exceptions or first_exception is None:
|
if return_exceptions or first_exception is None:
|
||||||
return cast(List[Output], inputs)
|
return cast(list[Output], inputs)
|
||||||
else:
|
else:
|
||||||
raise first_exception
|
raise first_exception
|
||||||
|
|
||||||
@ -3420,7 +3429,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||||
"""Runnable that runs a mapping of Runnables in parallel, and returns a mapping
|
"""Runnable that runs a mapping of Runnables in parallel, and returns a mapping
|
||||||
of their outputs.
|
of their outputs.
|
||||||
|
|
||||||
@ -4071,7 +4080,11 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
func = getattr(self, "_transform", None) or self._atransform
|
func = getattr(self, "_transform", None) or self._atransform
|
||||||
module = getattr(func, "__module__", None)
|
module = getattr(func, "__module__", None)
|
||||||
|
|
||||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
if (
|
||||||
|
inspect.isclass(root_type)
|
||||||
|
and not isinstance(root_type, GenericAlias)
|
||||||
|
and issubclass(root_type, BaseModel)
|
||||||
|
):
|
||||||
return root_type
|
return root_type
|
||||||
|
|
||||||
return create_model_v2(
|
return create_model_v2(
|
||||||
@ -4106,7 +4119,11 @@ class RunnableGenerator(Runnable[Input, Output]):
|
|||||||
func = getattr(self, "_transform", None) or self._atransform
|
func = getattr(self, "_transform", None) or self._atransform
|
||||||
module = getattr(func, "__module__", None)
|
module = getattr(func, "__module__", None)
|
||||||
|
|
||||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
if (
|
||||||
|
inspect.isclass(root_type)
|
||||||
|
and not isinstance(root_type, GenericAlias)
|
||||||
|
and issubclass(root_type, BaseModel)
|
||||||
|
):
|
||||||
return root_type
|
return root_type
|
||||||
|
|
||||||
return create_model_v2(
|
return create_model_v2(
|
||||||
@ -4369,7 +4386,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
module = getattr(func, "__module__", None)
|
module = getattr(func, "__module__", None)
|
||||||
return create_model_v2(
|
return create_model_v2(
|
||||||
self.get_name("Input"),
|
self.get_name("Input"),
|
||||||
root=List[Any],
|
root=list[Any],
|
||||||
# To create the schema, we need to provide the module
|
# To create the schema, we need to provide the module
|
||||||
# where the underlying function is defined.
|
# where the underlying function is defined.
|
||||||
# This allows pydantic to resolve type annotations appropriately.
|
# This allows pydantic to resolve type annotations appropriately.
|
||||||
@ -4420,7 +4437,11 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
func = getattr(self, "func", None) or self.afunc
|
func = getattr(self, "func", None) or self.afunc
|
||||||
module = getattr(func, "__module__", None)
|
module = getattr(func, "__module__", None)
|
||||||
|
|
||||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
if (
|
||||||
|
inspect.isclass(root_type)
|
||||||
|
and not isinstance(root_type, GenericAlias)
|
||||||
|
and issubclass(root_type, BaseModel)
|
||||||
|
):
|
||||||
return root_type
|
return root_type
|
||||||
|
|
||||||
return create_model_v2(
|
return create_model_v2(
|
||||||
@ -4921,7 +4942,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
|
||||||
"""Runnable that delegates calls to another Runnable
|
"""Runnable that delegates calls to another Runnable
|
||||||
with each element of the input sequence.
|
with each element of the input sequence.
|
||||||
|
|
||||||
@@ -4938,7 +4959,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):

     @property
     def InputType(self) -> Any:
-        return List[self.bound.InputType]  # type: ignore[name-defined]
+        return list[self.bound.InputType]  # type: ignore[name-defined]

     def get_input_schema(
         self, config: Optional[RunnableConfig] = None
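`list[self.bound.InputType]` parameterizes the builtin generic with a type that is only known at runtime, which is exactly what `typing.List[...]` allowed before. A small sketch of the same idea, validating against a runtime-built `list[...]` with pydantic v2's `TypeAdapter`; the `Item` model is illustrative:

```python
from pydantic import BaseModel, TypeAdapter


class Item(BaseModel):
    x: int


element_type = Item             # imagine this came from self.bound.InputType
list_type = list[element_type]  # builtin generics accept runtime parameters

print(TypeAdapter(list_type).validate_python([{"x": 1}, {"x": 2}]))
# [Item(x=1), Item(x=2)]
```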
@ -4946,7 +4967,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
return create_model_v2(
|
return create_model_v2(
|
||||||
self.get_name("Input"),
|
self.get_name("Input"),
|
||||||
root=(
|
root=(
|
||||||
List[self.bound.get_input_schema(config)], # type: ignore
|
list[self.bound.get_input_schema(config)], # type: ignore
|
||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
# create model needs access to appropriate type annotations to be
|
# create model needs access to appropriate type annotations to be
|
||||||
@ -4961,7 +4982,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def OutputType(self) -> type[list[Output]]:
|
def OutputType(self) -> type[list[Output]]:
|
||||||
return List[self.bound.OutputType] # type: ignore[name-defined]
|
return list[self.bound.OutputType] # type: ignore[name-defined]
|
||||||
|
|
||||||
def get_output_schema(
|
def get_output_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
@ -4969,7 +4990,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
schema = self.bound.get_output_schema(config)
|
schema = self.bound.get_output_schema(config)
|
||||||
return create_model_v2(
|
return create_model_v2(
|
||||||
self.get_name("Output"),
|
self.get_name("Output"),
|
||||||
root=List[schema], # type: ignore[valid-type]
|
root=list[schema], # type: ignore[valid-type]
|
||||||
# create model needs access to appropriate type annotations to be
|
# create model needs access to appropriate type annotations to be
|
||||||
# able to construct the pydantic model.
|
# able to construct the pydantic model.
|
||||||
# When we create the model, we pass information about the namespace
|
# When we create the model, we pass information about the namespace
|
||||||
@ -5255,7 +5276,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
|||||||
@property
|
@property
|
||||||
def InputType(self) -> type[Input]:
|
def InputType(self) -> type[Input]:
|
||||||
return (
|
return (
|
||||||
cast(Type[Input], self.custom_input_type)
|
cast(type[Input], self.custom_input_type)
|
||||||
if self.custom_input_type is not None
|
if self.custom_input_type is not None
|
||||||
else self.bound.InputType
|
else self.bound.InputType
|
||||||
)
|
)
|
||||||
@ -5263,7 +5284,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
|||||||
@property
|
@property
|
||||||
def OutputType(self) -> type[Output]:
|
def OutputType(self) -> type[Output]:
|
||||||
return (
|
return (
|
||||||
cast(Type[Output], self.custom_output_type)
|
cast(type[Output], self.custom_output_type)
|
||||||
if self.custom_output_type is not None
|
if self.custom_output_type is not None
|
||||||
else self.bound.OutputType
|
else self.bound.OutputType
|
||||||
)
|
)
|
||||||
@ -5336,7 +5357,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
|||||||
) -> list[Output]:
|
) -> list[Output]:
|
||||||
if isinstance(config, list):
|
if isinstance(config, list):
|
||||||
configs = cast(
|
configs = cast(
|
||||||
List[RunnableConfig],
|
list[RunnableConfig],
|
||||||
[self._merge_configs(conf) for conf in config],
|
[self._merge_configs(conf) for conf in config],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -5358,7 +5379,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
|||||||
) -> list[Output]:
|
) -> list[Output]:
|
||||||
if isinstance(config, list):
|
if isinstance(config, list):
|
||||||
configs = cast(
|
configs = cast(
|
||||||
List[RunnableConfig],
|
list[RunnableConfig],
|
||||||
[self._merge_configs(conf) for conf in config],
|
[self._merge_configs(conf) for conf in config],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -5400,7 +5421,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
|||||||
) -> Iterator[tuple[int, Union[Output, Exception]]]:
|
) -> Iterator[tuple[int, Union[Output, Exception]]]:
|
||||||
if isinstance(config, Sequence):
|
if isinstance(config, Sequence):
|
||||||
configs = cast(
|
configs = cast(
|
||||||
List[RunnableConfig],
|
list[RunnableConfig],
|
||||||
[self._merge_configs(conf) for conf in config],
|
[self._merge_configs(conf) for conf in config],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -5451,7 +5472,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
|||||||
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]:
|
) -> AsyncIterator[tuple[int, Union[Output, Exception]]]:
|
||||||
if isinstance(config, Sequence):
|
if isinstance(config, Sequence):
|
||||||
configs = cast(
|
configs = cast(
|
||||||
List[RunnableConfig],
|
list[RunnableConfig],
|
||||||
[self._merge_configs(conf) for conf in config],
|
[self._merge_configs(conf) for conf in config],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -1,15 +1,8 @@
|
|||||||
|
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
Callable,
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
@ -69,13 +62,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
branch.invoke(None) # "goodbye"
|
branch.invoke(None) # "goodbye"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
branches: Sequence[tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
||||||
default: Runnable[Input, Output]
|
default: Runnable[Input, Output]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*branches: Union[
|
*branches: Union[
|
||||||
Tuple[
|
tuple[
|
||||||
Union[
|
Union[
|
||||||
Runnable[Input, bool],
|
Runnable[Input, bool],
|
||||||
Callable[[Input], bool],
|
Callable[[Input], bool],
|
||||||
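The `(condition, runnable)` pairs typed above are how `RunnableBranch` is constructed in practice; the trailing positional argument is the default branch. A usage sketch, assuming `langchain_core` is installed (conditions and branches given as plain callables are coerced to runnables):

```python
from langchain_core.runnables import RunnableBranch, RunnableLambda

branch = RunnableBranch(
    (lambda x: isinstance(x, str), RunnableLambda(lambda x: x.upper())),
    (lambda x: isinstance(x, int), RunnableLambda(lambda x: x + 1)),
    RunnableLambda(lambda x: "goodbye"),  # default when no condition matches
)

print(branch.invoke("hello"))  # "HELLO"
print(branch.invoke(7))        # 8
print(branch.invoke(None))     # "goodbye"
```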
@ -149,13 +142,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "schema", "runnable"]
|
return ["langchain", "schema", "runnable"]
|
||||||
|
|
||||||
def get_input_schema(
|
def get_input_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
runnables = (
|
runnables = (
|
||||||
[self.default]
|
[self.default]
|
||||||
+ [r for _, r in self.branches]
|
+ [r for _, r in self.branches]
|
||||||
@ -172,7 +165,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
|||||||
return super().get_input_schema(config)
|
return super().get_input_schema(config)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||||
from langchain_core.beta.runnables.context import (
|
from langchain_core.beta.runnables.context import (
|
||||||
CONTEXT_CONFIG_PREFIX,
|
CONTEXT_CONFIG_PREFIX,
|
||||||
CONTEXT_CONFIG_SUFFIX_SET,
|
CONTEXT_CONFIG_SUFFIX_SET,
|
||||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence
|
||||||
from concurrent.futures import Executor, Future, ThreadPoolExecutor
|
from concurrent.futures import Executor, Future, ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from contextvars import ContextVar, copy_context
|
from contextvars import ContextVar, copy_context
|
||||||
@ -10,14 +11,8 @@ from functools import partial
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
|
||||||
Callable,
|
Callable,
|
||||||
Generator,
|
|
||||||
Iterable,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
@ -43,7 +38,7 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
# Pydantic validates through typed dicts, but
|
# Pydantic validates through typed dicts, but
|
||||||
# the callbacks need forward refs updated
|
# the callbacks need forward refs updated
|
||||||
Callbacks = Optional[Union[List, Any]]
|
Callbacks = Optional[Union[list, Any]]
|
||||||
|
|
||||||
|
|
||||||
class EmptyDict(TypedDict, total=False):
|
class EmptyDict(TypedDict, total=False):
|
||||||
|
@ -3,20 +3,16 @@ from __future__ import annotations
|
|||||||
import enum
|
import enum
|
||||||
import threading
|
import threading
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||||
|
from collections.abc import Mapping as Mapping
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
|
||||||
Callable,
|
Callable,
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
|
||||||
Type,
|
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
from typing import Mapping as Mapping
|
|
||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
@ -176,10 +172,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
# If there's only one input, don't bother with the executor
|
# If there's only one input, don't bother with the executor
|
||||||
if len(inputs) == 1:
|
if len(inputs) == 1:
|
||||||
return cast(List[Output], [invoke(prepared[0], inputs[0])])
|
return cast(list[Output], [invoke(prepared[0], inputs[0])])
|
||||||
|
|
||||||
with get_executor_for_config(configs[0]) as executor:
|
with get_executor_for_config(configs[0]) as executor:
|
||||||
return cast(List[Output], list(executor.map(invoke, prepared, inputs)))
|
return cast(list[Output], list(executor.map(invoke, prepared, inputs)))
|
||||||
|
|
||||||
async def abatch(
|
async def abatch(
|
||||||
self,
|
self,
|
||||||
@ -562,7 +558,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
|||||||
for v in list(self.alternatives.keys()) + [self.default_key]
|
for v in list(self.alternatives.keys()) + [self.default_key]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
_enums_for_spec[self.which] = cast(Type[StrEnum], which_enum)
|
_enums_for_spec[self.which] = cast(type[StrEnum], which_enum)
|
||||||
return get_unique_config_specs(
|
return get_unique_config_specs(
|
||||||
# which alternative
|
# which alternative
|
||||||
[
|
[
|
||||||
@ -694,7 +690,7 @@ def make_options_spec(
|
|||||||
spec.name or spec.id,
|
spec.name or spec.id,
|
||||||
((v, v) for v in list(spec.options.keys())),
|
((v, v) for v in list(spec.options.keys())),
|
||||||
)
|
)
|
||||||
_enums_for_spec[spec] = cast(Type[StrEnum], enum)
|
_enums_for_spec[spec] = cast(type[StrEnum], enum)
|
||||||
if isinstance(spec, ConfigurableFieldSingleOption):
|
if isinstance(spec, ConfigurableFieldSingleOption):
|
||||||
return ConfigurableFieldSpec(
|
return ConfigurableFieldSpec(
|
||||||
id=spec.id,
|
id=spec.id,
|
||||||
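The `cast(type[StrEnum], ...)` calls above compensate for the functional `Enum` API, whose return type is too loose for type checkers even though it builds a proper class at runtime. A self-contained sketch with a local `StrEnum` stand-in (the real helper lives elsewhere in `langchain_core`):

```python
from enum import Enum
from typing import cast


class StrEnum(str, Enum):
    """Illustrative stand-in for the internal string-enum helper."""


options = ["default", "fast", "accurate"]

# The functional Enum API creates the class dynamically from (name, value)
# pairs; the cast tells the type checker what was actually built.
which_enum = cast(type[StrEnum], StrEnum("ModelChoice", ((v, v) for v in options)))

print([member.value for member in which_enum])  # ['default', 'fast', 'accurate']
```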
|
@ -1,19 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import typing
|
import typing
|
||||||
|
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||||
from contextvars import copy_context
|
from contextvars import copy_context
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
|
||||||
Dict,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
@ -96,7 +90,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
"""The Runnable to run first."""
|
"""The Runnable to run first."""
|
||||||
fallbacks: Sequence[Runnable[Input, Output]]
|
fallbacks: Sequence[Runnable[Input, Output]]
|
||||||
"""A sequence of fallbacks to try."""
|
"""A sequence of fallbacks to try."""
|
||||||
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
|
exceptions_to_handle: tuple[type[BaseException], ...] = (Exception,)
|
||||||
"""The exceptions on which fallbacks should be tried.
|
"""The exceptions on which fallbacks should be tried.
|
||||||
|
|
||||||
Any exception that is not a subclass of these exceptions will be raised immediately.
|
Any exception that is not a subclass of these exceptions will be raised immediately.
|
||||||
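`exceptions_to_handle` is typically set through `Runnable.with_fallbacks`. A usage sketch, assuming `langchain_core` is installed; the two lambdas are placeholders for a real primary and backup runnable:

```python
from langchain_core.runnables import RunnableLambda


def primary(_: str) -> str:
    raise TimeoutError("primary backend unavailable")


chain = RunnableLambda(primary).with_fallbacks(
    [RunnableLambda(lambda _: "served by fallback")],
    exceptions_to_handle=(TimeoutError,),  # anything else is raised immediately
)

print(chain.invoke("hi"))  # "served by fallback"
```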
@ -112,25 +106,25 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def InputType(self) -> Type[Input]:
|
def InputType(self) -> type[Input]:
|
||||||
return self.runnable.InputType
|
return self.runnable.InputType
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def OutputType(self) -> Type[Output]:
|
def OutputType(self) -> type[Output]:
|
||||||
return self.runnable.OutputType
|
return self.runnable.OutputType
|
||||||
|
|
||||||
def get_input_schema(
|
def get_input_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
return self.runnable.get_input_schema(config)
|
return self.runnable.get_input_schema(config)
|
||||||
|
|
||||||
def get_output_schema(
|
def get_output_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
return self.runnable.get_output_schema(config)
|
return self.runnable.get_output_schema(config)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config_specs(self) -> List[ConfigurableFieldSpec]:
|
def config_specs(self) -> list[ConfigurableFieldSpec]:
|
||||||
return get_unique_config_specs(
|
return get_unique_config_specs(
|
||||||
spec
|
spec
|
||||||
for step in [self.runnable, *self.fallbacks]
|
for step in [self.runnable, *self.fallbacks]
|
||||||
@ -142,7 +136,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "schema", "runnable"]
|
return ["langchain", "schema", "runnable"]
|
||||||
|
|
||||||
@ -252,12 +246,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
def batch(
|
def batch(
|
||||||
self,
|
self,
|
||||||
inputs: List[Input],
|
inputs: list[Input],
|
||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||||
*,
|
*,
|
||||||
return_exceptions: bool = False,
|
return_exceptions: bool = False,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> List[Output]:
|
) -> list[Output]:
|
||||||
from langchain_core.callbacks.manager import CallbackManager
|
from langchain_core.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
if self.exception_key is not None and not all(
|
if self.exception_key is not None and not all(
|
||||||
@ -296,9 +290,9 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||||
]
|
]
|
||||||
|
|
||||||
to_return: Dict[int, Any] = {}
|
to_return: dict[int, Any] = {}
|
||||||
run_again = {i: input for i, input in enumerate(inputs)}
|
run_again = {i: input for i, input in enumerate(inputs)}
|
||||||
handled_exceptions: Dict[int, BaseException] = {}
|
handled_exceptions: dict[int, BaseException] = {}
|
||||||
first_to_raise = None
|
first_to_raise = None
|
||||||
for runnable in self.runnables:
|
for runnable in self.runnables:
|
||||||
outputs = runnable.batch(
|
outputs = runnable.batch(
|
||||||
@ -344,12 +338,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
async def abatch(
|
async def abatch(
|
||||||
self,
|
self,
|
||||||
inputs: List[Input],
|
inputs: list[Input],
|
||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||||
*,
|
*,
|
||||||
return_exceptions: bool = False,
|
return_exceptions: bool = False,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> List[Output]:
|
) -> list[Output]:
|
||||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||||
|
|
||||||
if self.exception_key is not None and not all(
|
if self.exception_key is not None and not all(
|
||||||
@ -378,7 +372,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
for config in configs
|
for config in configs
|
||||||
]
|
]
|
||||||
# start the root runs, one per input
|
# start the root runs, one per input
|
||||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
cm.on_chain_start(
|
cm.on_chain_start(
|
||||||
None,
|
None,
|
||||||
@ -392,7 +386,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
to_return = {}
|
to_return = {}
|
||||||
run_again = {i: input for i, input in enumerate(inputs)}
|
run_again = {i: input for i, input in enumerate(inputs)}
|
||||||
handled_exceptions: Dict[int, BaseException] = {}
|
handled_exceptions: dict[int, BaseException] = {}
|
||||||
first_to_raise = None
|
first_to_raise = None
|
||||||
for runnable in self.runnables:
|
for runnable in self.runnables:
|
||||||
outputs = await runnable.abatch(
|
outputs = await runnable.abatch(
|
||||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -11,7 +12,6 @@ from typing import (
|
|||||||
NamedTuple,
|
NamedTuple,
|
||||||
Optional,
|
Optional,
|
||||||
Protocol,
|
Protocol,
|
||||||
Sequence,
|
|
||||||
TypedDict,
|
TypedDict,
|
||||||
Union,
|
Union,
|
||||||
overload,
|
overload,
|
||||||
|
@@ -3,7 +3,8 @@ Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py"""

 import math
 import os
-from typing import Any, Mapping, Sequence
+from collections.abc import Mapping, Sequence
+from typing import Any

 from langchain_core.runnables.graph import Edge as LangEdge

|
@ -1,7 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import Dict, List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langchain_core.runnables.graph import (
|
from langchain_core.runnables.graph import (
|
||||||
CurveStyle,
|
CurveStyle,
|
||||||
@ -15,8 +15,8 @@ MARKDOWN_SPECIAL_CHARS = "*_`"
|
|||||||
|
|
||||||
|
|
||||||
def draw_mermaid(
|
def draw_mermaid(
|
||||||
nodes: Dict[str, Node],
|
nodes: dict[str, Node],
|
||||||
edges: List[Edge],
|
edges: list[Edge],
|
||||||
*,
|
*,
|
||||||
first_node: Optional[str] = None,
|
first_node: Optional[str] = None,
|
||||||
last_node: Optional[str] = None,
|
last_node: Optional[str] = None,
|
||||||
@ -87,7 +87,7 @@ def draw_mermaid(
|
|||||||
mermaid_graph += f"\t{node_label}\n"
|
mermaid_graph += f"\t{node_label}\n"
|
||||||
|
|
||||||
# Group edges by their common prefixes
|
# Group edges by their common prefixes
|
||||||
edge_groups: Dict[str, List[Edge]] = {}
|
edge_groups: dict[str, list[Edge]] = {}
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
src_parts = edge.source.split(":")
|
src_parts = edge.source.split(":")
|
||||||
tgt_parts = edge.target.split(":")
|
tgt_parts = edge.target.split(":")
|
||||||
@ -98,7 +98,7 @@ def draw_mermaid(
|
|||||||
|
|
||||||
seen_subgraphs = set()
|
seen_subgraphs = set()
|
||||||
|
|
||||||
def add_subgraph(edges: List[Edge], prefix: str) -> None:
|
def add_subgraph(edges: list[Edge], prefix: str) -> None:
|
||||||
nonlocal mermaid_graph
|
nonlocal mermaid_graph
|
||||||
self_loop = len(edges) == 1 and edges[0].source == edges[0].target
|
self_loop = len(edges) == 1 and edges[0].source == edges[0].target
|
||||||
if prefix and not self_loop:
|
if prefix and not self_loop:
|
||||||
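The grouping above collects edges whose source and target share a `:`-separated prefix so they can be rendered as Mermaid subgraphs. A rough, simplified approximation of that bookkeeping (the real code derives the common prefix slightly differently; `Edge` here is a minimal stand-in for the graph's edge type):

```python
from dataclasses import dataclass


@dataclass
class Edge:
    source: str
    target: str


def group_edges_by_prefix(edges: list[Edge]) -> dict[str, list[Edge]]:
    groups: dict[str, list[Edge]] = {}
    for edge in edges:
        src_prefix = edge.source.split(":")[0] if ":" in edge.source else ""
        tgt_prefix = edge.target.split(":")[0] if ":" in edge.target else ""
        # An edge belongs to a subgraph only when both endpoints share it.
        common = src_prefix if src_prefix == tgt_prefix else ""
        groups.setdefault(common, []).append(edge)
    return groups


edges = [Edge("agent:plan", "agent:act"), Edge("agent:act", "tools:search")]
print({k: len(v) for k, v in group_edges_by_prefix(edges).items()})  # {'agent': 1, '': 1}
```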
|
@ -1,13 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from types import GenericAlias
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.tracers.schemas import Run
|
from langchain_core.tracers.schemas import Run
|
||||||
|
|
||||||
|
|
||||||
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
|
MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], dict[str, Any]]
|
||||||
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
|
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
|
||||||
|
|
||||||
|
|
||||||
@ -419,7 +419,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
|||||||
"""
|
"""
|
||||||
root_type = self.OutputType
|
root_type = self.OutputType
|
||||||
|
|
||||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
if (
|
||||||
|
inspect.isclass(root_type)
|
||||||
|
and not isinstance(root_type, GenericAlias)
|
||||||
|
and issubclass(root_type, BaseModel)
|
||||||
|
):
|
||||||
return root_type
|
return root_type
|
||||||
|
|
||||||
return create_model_v2(
|
return create_model_v2(
|
||||||
|
@ -5,15 +5,11 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import threading
|
import threading
|
||||||
|
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
|
||||||
Iterator,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
Optional,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
@ -349,7 +345,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
|||||||
_graph_passthrough: RunnablePassthrough = RunnablePassthrough()
|
_graph_passthrough: RunnablePassthrough = RunnablePassthrough()
|
||||||
|
|
||||||
|
|
||||||
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||||
"""Runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
"""Runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
||||||
|
|
||||||
The `RunnableAssign` class takes input dictionaries and, through a
|
The `RunnableAssign` class takes input dictionaries and, through a
|
||||||
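`RunnableAssign` is normally built through `RunnablePassthrough.assign`: the input dict passes through unchanged and the computed keys are merged in. A short usage sketch, assuming `langchain_core` is installed:

```python
from langchain_core.runnables import RunnablePassthrough

chain = RunnablePassthrough.assign(total=lambda d: d["a"] + d["b"])

print(chain.invoke({"a": 1, "b": 2}))  # {'a': 1, 'b': 2, 'total': 3}
```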
@ -564,7 +560,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
if filtered:
|
if filtered:
|
||||||
yield filtered
|
yield filtered
|
||||||
# yield map output
|
# yield map output
|
||||||
yield cast(Dict[str, Any], first_map_chunk_future.result())
|
yield cast(dict[str, Any], first_map_chunk_future.result())
|
||||||
for chunk in map_output:
|
for chunk in map_output:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
@ -650,7 +646,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||||
"""Runnable that picks keys from Dict[str, Any] inputs.
|
"""Runnable that picks keys from Dict[str, Any] inputs.
|
||||||
|
|
||||||
RunnablePick class represents a Runnable that selectively picks keys from a
|
RunnablePick class represents a Runnable that selectively picks keys from a
|
||||||
|
@ -1,11 +1,7 @@
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
@ -98,7 +94,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
retryable_chain = chain.with_retry()
|
retryable_chain = chain.with_retry()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,)
|
retry_exception_types: tuple[type[BaseException], ...] = (Exception,)
|
||||||
"""The exception types to retry on. By default all exceptions are retried.
|
"""The exception types to retry on. By default all exceptions are retried.
|
||||||
|
|
||||||
In general you should only retry on exceptions that are likely to be
|
In general you should only retry on exceptions that are likely to be
|
||||||
@ -115,13 +111,13 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
"""The maximum number of attempts to retry the Runnable."""
|
"""The maximum number of attempts to retry the Runnable."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> list[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "schema", "runnable"]
|
return ["langchain", "schema", "runnable"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _kwargs_retrying(self) -> Dict[str, Any]:
|
def _kwargs_retrying(self) -> dict[str, Any]:
|
||||||
kwargs: Dict[str, Any] = dict()
|
kwargs: dict[str, Any] = dict()
|
||||||
|
|
||||||
if self.max_attempt_number:
|
if self.max_attempt_number:
|
||||||
kwargs["stop"] = stop_after_attempt(self.max_attempt_number)
|
kwargs["stop"] = stop_after_attempt(self.max_attempt_number)
|
||||||
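`retry_exception_types` and `max_attempt_number` map onto tenacity's `retry_if_exception_type` and `stop_after_attempt`, and are usually configured through `Runnable.with_retry`. A usage sketch, assuming `langchain_core` is installed; the flaky function is a stand-in for a transiently failing step:

```python
from langchain_core.runnables import RunnableLambda

attempts = {"count": 0}


def flaky(_: str) -> str:
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise ValueError("transient failure")
    return "ok"


chain = RunnableLambda(flaky).with_retry(
    retry_if_exception_type=(ValueError,),  # mirrors retry_exception_types
    stop_after_attempt=3,                   # mirrors max_attempt_number
)

print(chain.invoke("x"))  # "ok" on the third attempt
```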
@ -152,10 +148,10 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
|
|
||||||
def _patch_config_list(
|
def _patch_config_list(
|
||||||
self,
|
self,
|
||||||
config: List[RunnableConfig],
|
config: list[RunnableConfig],
|
||||||
run_manager: List["T"],
|
run_manager: list["T"],
|
||||||
retry_state: RetryCallState,
|
retry_state: RetryCallState,
|
||||||
) -> List[RunnableConfig]:
|
) -> list[RunnableConfig]:
|
||||||
return [
|
return [
|
||||||
self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
|
self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
|
||||||
]
|
]
|
||||||
@ -208,17 +204,17 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
|
|
||||||
def _batch(
|
def _batch(
|
||||||
self,
|
self,
|
||||||
inputs: List[Input],
|
inputs: list[Input],
|
||||||
run_manager: List["CallbackManagerForChainRun"],
|
run_manager: list["CallbackManagerForChainRun"],
|
||||||
config: List[RunnableConfig],
|
config: list[RunnableConfig],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Union[Output, Exception]]:
|
) -> list[Union[Output, Exception]]:
|
||||||
results_map: Dict[int, Output] = {}
|
results_map: dict[int, Output] = {}
|
||||||
|
|
||||||
def pending(iterable: List[U]) -> List[U]:
|
def pending(iterable: list[U]) -> list[U]:
|
||||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||||
|
|
||||||
not_set: List[Output] = []
|
not_set: list[Output] = []
|
||||||
result = not_set
|
result = not_set
|
||||||
try:
|
try:
|
||||||
for attempt in self._sync_retrying():
|
for attempt in self._sync_retrying():
|
||||||
@ -250,9 +246,9 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
|||||||
attempt.retry_state.set_result(result)
|
attempt.retry_state.set_result(result)
|
||||||
except RetryError as e:
|
except RetryError as e:
|
||||||
if result is not_set:
|
if result is not_set:
|
||||||
result = cast(List[Output], [e] * len(inputs))
|
result = cast(list[Output], [e] * len(inputs))
|
||||||
|
|
||||||
-outputs: List[Union[Output, Exception]] = []
+outputs: list[Union[Output, Exception]] = []
 for idx, _ in enumerate(inputs):
 if idx in results_map:
 outputs.append(results_map[idx])
@@ -262,29 +258,29 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):

 def batch(
 self,
-inputs: List[Input],
+inputs: list[Input],
-config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
 *,
 return_exceptions: bool = False,
 **kwargs: Any,
-) -> List[Output]:
+) -> list[Output]:
 return self._batch_with_config(
 self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
 )

 async def _abatch(
 self,
-inputs: List[Input],
+inputs: list[Input],
-run_manager: List["AsyncCallbackManagerForChainRun"],
+run_manager: list["AsyncCallbackManagerForChainRun"],
-config: List[RunnableConfig],
+config: list[RunnableConfig],
 **kwargs: Any,
-) -> List[Union[Output, Exception]]:
+) -> list[Union[Output, Exception]]:
-results_map: Dict[int, Output] = {}
+results_map: dict[int, Output] = {}

-def pending(iterable: List[U]) -> List[U]:
+def pending(iterable: list[U]) -> list[U]:
 return [item for idx, item in enumerate(iterable) if idx not in results_map]

-not_set: List[Output] = []
+not_set: list[Output] = []
 result = not_set
 try:
 async for attempt in self._async_retrying():
@@ -316,9 +312,9 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
 attempt.retry_state.set_result(result)
 except RetryError as e:
 if result is not_set:
-result = cast(List[Output], [e] * len(inputs))
+result = cast(list[Output], [e] * len(inputs))

-outputs: List[Union[Output, Exception]] = []
+outputs: list[Union[Output, Exception]] = []
 for idx, _ in enumerate(inputs):
 if idx in results_map:
 outputs.append(results_map[idx])
@@ -328,12 +324,12 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):

 async def abatch(
 self,
-inputs: List[Input],
+inputs: list[Input],
-config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
 *,
 return_exceptions: bool = False,
 **kwargs: Any,
-) -> List[Output]:
+) -> list[Output]:
 return await self._abatch_with_config(
 self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
 )
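
The hunks above and below all apply the same pattern: with Python 3.9 as the minimum, the deprecated `typing.List`/`typing.Dict` aliases are swapped for the builtin generics that PEP 585 made subscriptable. A minimal before/after sketch, with illustrative names that are not taken from the diff:

```python
# Illustrative sketch only; the function name and data are not from the diff.
from typing import Dict, List  # old-style aliases, deprecated since 3.9


def group_by_initial_old(words: List[str]) -> Dict[str, List[str]]:
    groups: Dict[str, List[str]] = {}
    for word in words:
        groups.setdefault(word[:1], []).append(word)
    return groups


def group_by_initial_new(words: list[str]) -> dict[str, list[str]]:
    # Same behaviour, annotated with PEP 585 builtin generics,
    # which are only subscriptable at runtime on Python >= 3.9.
    groups: dict[str, list[str]] = {}
    for word in words:
        groups.setdefault(word[:1], []).append(word)
    return groups


print(group_by_initial_new(["alpha", "ant", "bee"]))  # {'a': ['alpha', 'ant'], 'b': ['bee']}
```
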
@@ -1,12 +1,9 @@
 from __future__ import annotations

+from collections.abc import AsyncIterator, Iterator, Mapping
 from typing import (
 Any,
-AsyncIterator,
 Callable,
-Iterator,
-List,
-Mapping,
 Optional,
 Union,
 cast,
@@ -154,7 +151,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
 configs = get_config_list(config, len(inputs))
 with get_executor_for_config(configs[0]) as executor:
 return cast(
-List[Output],
+list[Output],
 list(executor.map(invoke, runnables, actual_inputs, configs)),
 )

@@ -2,7 +2,8 @@

 from __future__ import annotations

-from typing import Any, Literal, Sequence, Union
+from collections.abc import Sequence
+from typing import Any, Literal, Union

 from typing_extensions import NotRequired, TypedDict

@@ -6,23 +6,24 @@ import ast
 import asyncio
 import inspect
 import textwrap
+from collections.abc import (
+AsyncIterable,
+AsyncIterator,
+Awaitable,
+Coroutine,
+Iterable,
+Mapping,
+Sequence,
+)
 from functools import lru_cache
 from inspect import signature
 from itertools import groupby
 from typing import (
 Any,
-AsyncIterable,
-AsyncIterator,
-Awaitable,
 Callable,
-Coroutine,
-Dict,
-Iterable,
-Mapping,
 NamedTuple,
 Optional,
 Protocol,
-Sequence,
 TypeVar,
 Union,
 )
@@ -430,7 +431,7 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
 return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])


-class AddableDict(Dict[str, Any]):
+class AddableDict(dict[str, Any]):
 """
 Dictionary that can be added to another dictionary.
 """
@@ -7,16 +7,11 @@ The primary goal of these storages is to support implementation of caching.
 """

 from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator, Iterator, Sequence
 from typing import (
 Any,
-AsyncIterator,
-Dict,
 Generic,
-Iterator,
-List,
 Optional,
-Sequence,
-Tuple,
 TypeVar,
 Union,
 )
@@ -84,7 +79,7 @@ class BaseStore(Generic[K, V], ABC):
 """

 @abstractmethod
-def mget(self, keys: Sequence[K]) -> List[Optional[V]]:
+def mget(self, keys: Sequence[K]) -> list[Optional[V]]:
 """Get the values associated with the given keys.

 Args:
@@ -95,7 +90,7 @@ class BaseStore(Generic[K, V], ABC):
 If a key is not found, the corresponding value will be None.
 """

-async def amget(self, keys: Sequence[K]) -> List[Optional[V]]:
+async def amget(self, keys: Sequence[K]) -> list[Optional[V]]:
 """Async get the values associated with the given keys.

 Args:
@@ -108,14 +103,14 @@ class BaseStore(Generic[K, V], ABC):
 return await run_in_executor(None, self.mget, keys)

 @abstractmethod
-def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
+def mset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
 """Set the values for the given keys.

 Args:
 key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
 """

-async def amset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
+async def amset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
 """Async set the values for the given keys.

 Args:
@@ -184,9 +179,9 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):

 def __init__(self) -> None:
 """Initialize an empty store."""
-self.store: Dict[str, V] = {}
+self.store: dict[str, V] = {}

-def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
+def mget(self, keys: Sequence[str]) -> list[Optional[V]]:
 """Get the values associated with the given keys.

 Args:
@@ -198,7 +193,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
 """
 return [self.store.get(key) for key in keys]

-async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
+async def amget(self, keys: Sequence[str]) -> list[Optional[V]]:
 """Async get the values associated with the given keys.

 Args:
@@ -210,7 +205,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
 """
 return self.mget(keys)

-def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
+def mset(self, key_value_pairs: Sequence[tuple[str, V]]) -> None:
 """Set the values for the given keys.

 Args:
@@ -222,7 +217,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
 for key, value in key_value_pairs:
 self.store[key] = value

-async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
+async def amset(self, key_value_pairs: Sequence[tuple[str, V]]) -> None:
 """Async set the values for the given keys.

 Args:
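
The same modernization moves ABCs such as `Sequence`, `Iterator` and `AsyncIterator` from `typing` to `collections.abc`, where they have been subscriptable since Python 3.9. A small sketch, again with hypothetical names rather than code from the diff:

```python
# Hypothetical example, not code from the diff.
from collections.abc import Iterator, Sequence
from typing import Optional


def first(items: Sequence[int]) -> Optional[int]:
    # collections.abc.Sequence is generic at runtime on Python >= 3.9,
    # so it can replace typing.Sequence in annotations.
    return items[0] if items else None


def countdown(n: int) -> Iterator[int]:
    while n > 0:
        yield n
        n -= 1


print(first([4, 5, 6]), list(countdown(3)))  # 4 [3, 2, 1]
```
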
@@ -3,8 +3,9 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from enum import Enum
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Optional, Union

 from pydantic import BaseModel

@@ -2,10 +2,10 @@
 for debugging purposes.
 """

-from typing import List, Sequence
+from collections.abc import Sequence


-def _get_sub_deps(packages: Sequence[str]) -> List[str]:
+def _get_sub_deps(packages: Sequence[str]) -> list[str]:
 """Get any specified sub-dependencies."""
 from importlib import metadata

@@ -7,15 +7,15 @@ import json
 import uuid
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from contextvars import copy_context
 from inspect import signature
 from typing import (
+Annotated,
 Any,
 Callable,
-Dict,
 Literal,
 Optional,
-Sequence,
 TypeVar,
 Union,
 cast,
@@ -36,7 +36,6 @@ from pydantic import (
 )
 from pydantic.v1 import BaseModel as BaseModelV1
 from pydantic.v1 import validate_arguments as validate_arguments_v1
-from typing_extensions import Annotated

 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -324,7 +323,7 @@ class ToolException(Exception):
 pass


-class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
+class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
 """Interface LangChain tools must implement."""

 def __init_subclass__(cls, **kwargs: Any) -> None:
@@ -1,5 +1,5 @@
 import inspect
-from typing import Any, Callable, Dict, Literal, Optional, Type, Union, get_type_hints
+from typing import Any, Callable, Literal, Optional, Union, get_type_hints

 from pydantic import BaseModel, Field, create_model

@@ -13,7 +13,7 @@ from langchain_core.tools.structured import StructuredTool
 def tool(
 *args: Union[str, Callable, Runnable],
 return_direct: bool = False,
-args_schema: Optional[Type] = None,
+args_schema: Optional[type] = None,
 infer_schema: bool = True,
 response_format: Literal["content", "content_and_artifact"] = "content",
 parse_docstring: bool = False,
@@ -160,7 +160,7 @@ def tool(

 coroutine = ainvoke_wrapper
 func = invoke_wrapper
-schema: Optional[Type[BaseModel]] = runnable.input_schema
+schema: Optional[type[BaseModel]] = runnable.input_schema
 description = repr(runnable)
 elif inspect.iscoroutinefunction(dec_func):
 coroutine = dec_func
@@ -234,8 +234,8 @@ def _get_description_from_runnable(runnable: Runnable) -> str:
 def _get_schema_from_runnable_and_arg_types(
 runnable: Runnable,
 name: str,
-arg_types: Optional[Dict[str, Type]] = None,
+arg_types: Optional[dict[str, type]] = None,
-) -> Type[BaseModel]:
+) -> type[BaseModel]:
 """Infer args_schema for tool."""
 if arg_types is None:
 try:
@@ -252,11 +252,11 @@ def _get_schema_from_runnable_and_arg_types(

 def convert_runnable_to_tool(
 runnable: Runnable,
-args_schema: Optional[Type[BaseModel]] = None,
+args_schema: Optional[type[BaseModel]] = None,
 *,
 name: Optional[str] = None,
 description: Optional[str] = None,
-arg_types: Optional[Dict[str, Type]] = None,
+arg_types: Optional[dict[str, type]] = None,
 ) -> BaseTool:
 """Convert a Runnable into a BaseTool.

@@ -1,11 +1,11 @@
 from __future__ import annotations

 from inspect import signature
-from typing import Callable, List
+from typing import Callable

 from langchain_core.tools.base import BaseTool

-ToolsRenderer = Callable[[List[BaseTool]], str]
+ToolsRenderer = Callable[[list[BaseTool]], str]


 def render_text_description(tools: list[BaseTool]) -> str:
@@ -1,9 +1,9 @@
 from __future__ import annotations

+from collections.abc import Awaitable
 from inspect import signature
 from typing import (
 Any,
-Awaitable,
 Callable,
 Optional,
 Union,
@@ -1,10 +1,11 @@
 from __future__ import annotations

 import textwrap
+from collections.abc import Awaitable
 from inspect import signature
 from typing import (
+Annotated,
 Any,
-Awaitable,
 Callable,
 Literal,
 Optional,
@@ -12,7 +13,6 @@ from typing import (
 )

 from pydantic import BaseModel, Field, SkipValidation
-from typing_extensions import Annotated

 from langchain_core.callbacks import (
 AsyncCallbackManagerForToolRun,
@@ -1,7 +1,8 @@
 """Internal tracers used for stream_log and astream events implementations."""

 import abc
-from typing import AsyncIterator, Iterator, TypeVar
+from collections.abc import AsyncIterator, Iterator
+from typing import TypeVar
 from uuid import UUID

 T = TypeVar("T")
@@ -5,11 +5,11 @@ from __future__ import annotations
 import asyncio
 import logging
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from typing import (
 TYPE_CHECKING,
 Any,
 Optional,
-Sequence,
 Union,
 )
 from uuid import UUID
@@ -1,11 +1,11 @@
 from __future__ import annotations

+from collections.abc import Generator
 from contextlib import contextmanager
 from contextvars import ContextVar
 from typing import (
 TYPE_CHECKING,
 Any,
-Generator,
 Optional,
 Union,
 cast,
@@ -6,14 +6,13 @@ import logging
 import sys
 import traceback
 from abc import ABC, abstractmethod
+from collections.abc import Coroutine, Sequence
 from datetime import datetime, timezone
 from typing import (
 TYPE_CHECKING,
 Any,
-Coroutine,
 Literal,
 Optional,
-Sequence,
 Union,
 cast,
 )
@@ -5,8 +5,9 @@ from __future__ import annotations
 import logging
 import threading
 import weakref
+from collections.abc import Sequence
 from concurrent.futures import Future, ThreadPoolExecutor, wait
-from typing import Any, List, Optional, Sequence, Union, cast
+from typing import Any, Optional, Union, cast
 from uuid import UUID

 import langsmith
@@ -156,7 +157,7 @@ class EvaluatorCallbackHandler(BaseTracer):
 if isinstance(results, EvaluationResult):
 results_ = [results]
 elif isinstance(results, dict) and "results" in results:
-results_ = cast(List[EvaluationResult], results["results"])
+results_ = cast(list[EvaluationResult], results["results"])
 else:
 raise TypeError(
 f"Invalid evaluation result type {type(results)}."
@@ -4,14 +4,11 @@ from __future__ import annotations

 import asyncio
 import logging
+from collections.abc import AsyncIterator, Iterator, Sequence
 from typing import (
 TYPE_CHECKING,
 Any,
-AsyncIterator,
-Iterator,
-List,
 Optional,
-Sequence,
 TypeVar,
 Union,
 cast,
@@ -459,7 +456,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
 output: Union[dict, BaseMessage] = {}

 if run_info["run_type"] == "chat_model":
-generations = cast(List[List[ChatGenerationChunk]], response.generations)
+generations = cast(list[list[ChatGenerationChunk]], response.generations)
 for gen in generations:
 if output != {}:
 break
@@ -469,7 +466,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand

 event = "on_chat_model_end"
 elif run_info["run_type"] == "llm":
-generations = cast(List[List[GenerationChunk]], response.generations)
+generations = cast(list[list[GenerationChunk]], response.generations)
 output = {
 "generations": [
 [
@@ -4,13 +4,11 @@ import asyncio
 import copy
 import threading
 from collections import defaultdict
+from collections.abc import AsyncIterator, Iterator, Sequence
 from typing import (
 Any,
-AsyncIterator,
-Iterator,
 Literal,
 Optional,
-Sequence,
 TypeVar,
 Union,
 overload,
@@ -11,7 +11,8 @@ used in the code.

 import asyncio
 from asyncio import AbstractEventLoop, Queue
-from typing import AsyncIterator, Generic, TypeVar
+from collections.abc import AsyncIterator
+from typing import Generic, TypeVar

 T = TypeVar("T")

@@ -1,4 +1,5 @@
-from typing import Awaitable, Callable, Optional, Union
+from collections.abc import Awaitable
+from typing import Callable, Optional, Union
 from uuid import UUID

 from langchain_core.runnables.config import (
@@ -1,6 +1,6 @@
 """A tracer that collects all nested runs in a list."""

-from typing import Any, List, Optional, Union
+from typing import Any, Optional, Union
 from uuid import UUID

 from langchain_core.tracers.base import BaseTracer
@@ -38,7 +38,7 @@ class RunCollectorCallbackHandler(BaseTracer):
 self.example_id = (
 UUID(example_id) if isinstance(example_id, str) else example_id
 )
-self.traced_runs: List[Run] = []
+self.traced_runs: list[Run] = []

 def _persist_run(self, run: Run) -> None:
 """
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Callable, List
+from typing import Any, Callable

 from langchain_core.tracers.base import BaseTracer
 from langchain_core.tracers.schemas import Run
@@ -54,7 +54,7 @@ class FunctionCallbackHandler(BaseTracer):
 def _persist_run(self, run: Run) -> None:
 pass

-def get_parents(self, run: Run) -> List[Run]:
+def get_parents(self, run: Run) -> list[Run]:
 """Get the parents of a run.

 Args:
@@ -5,23 +5,20 @@ MIT License
 """

 from collections import deque
-from contextlib import AbstractAsyncContextManager
+from collections.abc import (
-from types import TracebackType
-from typing import (
-Any,
-AsyncContextManager,
 AsyncGenerator,
 AsyncIterable,
 AsyncIterator,
 Awaitable,
-Callable,
-Deque,
-Generic,
 Iterator,
-List,
+)
+from contextlib import AbstractAsyncContextManager
+from types import TracebackType
+from typing import (
+Any,
+Callable,
+Generic,
 Optional,
-Tuple,
-Type,
 TypeVar,
 Union,
 cast,
@@ -95,10 +92,10 @@ class NoLock:
 async def tee_peer(
 iterator: AsyncIterator[T],
 # the buffer specific to this peer
-buffer: Deque[T],
+buffer: deque[T],
 # the buffers of all peers, including our own
-peers: List[Deque[T]],
+peers: list[deque[T]],
-lock: AsyncContextManager[Any],
+lock: AbstractAsyncContextManager[Any],
 ) -> AsyncGenerator[T, None]:
 """An individual iterator of a :py:func:`~.tee`.

@@ -191,10 +188,10 @@ class Tee(Generic[T]):
 iterable: AsyncIterator[T],
 n: int = 2,
 *,
-lock: Optional[AsyncContextManager[Any]] = None,
+lock: Optional[AbstractAsyncContextManager[Any]] = None,
 ):
 self._iterator = iterable.__aiter__() # before 3.10 aiter() doesn't exist
-self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
+self._buffers: list[deque[T]] = [deque() for _ in range(n)]
 self._children = tuple(
 tee_peer(
 iterator=self._iterator,
@@ -212,11 +209,11 @@ class Tee(Generic[T]):
 def __getitem__(self, item: int) -> AsyncIterator[T]: ...

 @overload
-def __getitem__(self, item: slice) -> Tuple[AsyncIterator[T], ...]: ...
+def __getitem__(self, item: slice) -> tuple[AsyncIterator[T], ...]: ...

 def __getitem__(
 self, item: Union[int, slice]
-) -> Union[AsyncIterator[T], Tuple[AsyncIterator[T], ...]]:
+) -> Union[AsyncIterator[T], tuple[AsyncIterator[T], ...]]:
 return self._children[item]

 def __iter__(self) -> Iterator[AsyncIterator[T]]:
@@ -267,7 +264,7 @@ class aclosing(AbstractAsyncContextManager):

 async def __aexit__(
 self,
-exc_type: Optional[Type[BaseException]],
+exc_type: Optional[type[BaseException]],
 exc_value: Optional[BaseException],
 traceback: Optional[TracebackType],
 ) -> None:
@@ -277,7 +274,7 @@ class aclosing(AbstractAsyncContextManager):

 async def abatch_iterate(
 size: int, iterable: AsyncIterable[T]
-) -> AsyncIterator[List[T]]:
+) -> AsyncIterator[list[T]]:
 """Utility batching function for async iterables.

 Args:
@@ -287,7 +284,7 @@ async def abatch_iterate(
 Returns:
 An async iterator over the batches.
 """
-batch: List[T] = []
+batch: list[T] = []
 async for element in iterable:
 if len(batch) < size:
 batch.append(element)
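
In the iterator utilities above, the deprecated `typing.AsyncContextManager`, `typing.ContextManager` and `typing.Deque` aliases give way to `contextlib.AbstractAsyncContextManager`, `contextlib.AbstractContextManager` and `collections.deque`, all of which accept subscripts on Python 3.9+. A runnable sketch of the pattern, with made-up names rather than the diff's own code:

```python
# Hypothetical names; a sketch of the replacement pattern, not code from the diff.
from collections import deque
from contextlib import AbstractContextManager, nullcontext


def drain(buffer: deque[int], lock: AbstractContextManager[object]) -> list[int]:
    # deque[int] stands in for typing.Deque[int]; AbstractContextManager
    # replaces typing.ContextManager as the annotation for "with"-able objects.
    with lock:
        items = list(buffer)
        buffer.clear()
    return items


print(drain(deque([1, 2, 3]), nullcontext()))  # [1, 2, 3]
```
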
@@ -1,7 +1,8 @@
 """Utilities for formatting strings."""

+from collections.abc import Mapping, Sequence
 from string import Formatter
-from typing import Any, List, Mapping, Sequence
+from typing import Any


 class StrictFormatter(Formatter):
@@ -31,7 +32,7 @@ class StrictFormatter(Formatter):
 return super().vformat(format_string, args, kwargs)

 def validate_input_variables(
-self, format_string: str, input_variables: List[str]
+self, format_string: str, input_variables: list[str]
 ) -> None:
 """Check that all input variables are used in the format string.

@@ -10,21 +10,17 @@ import typing
 import uuid
 from typing import (
 TYPE_CHECKING,
+Annotated,
 Any,
 Callable,
-Dict,
-List,
 Literal,
 Optional,
-Set,
-Tuple,
-Type,
 Union,
 cast,
 )

 from pydantic import BaseModel
-from typing_extensions import Annotated, TypedDict, get_args, get_origin, is_typeddict
+from typing_extensions import TypedDict, get_args, get_origin, is_typeddict

 from langchain_core._api import deprecated
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
@@ -201,7 +197,7 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript
 from pydantic.v1 import BaseModel

 model = cast(
-Type[BaseModel],
+type[BaseModel],
 _convert_any_typed_dicts_to_pydantic(typed_dict, visited=visited),
 )
 return convert_pydantic_to_openai_function(model) # type: ignore
@@ -383,15 +379,15 @@ def convert_to_openai_function(
 "parameters": function,
 }
 elif isinstance(function, type) and is_basemodel_subclass(function):
-oai_function = cast(Dict, convert_pydantic_to_openai_function(function))
+oai_function = cast(dict, convert_pydantic_to_openai_function(function))
 elif is_typeddict(function):
 oai_function = cast(
-Dict, _convert_typed_dict_to_openai_function(cast(Type, function))
+dict, _convert_typed_dict_to_openai_function(cast(type, function))
 )
 elif isinstance(function, BaseTool):
-oai_function = cast(Dict, format_tool_to_openai_function(function))
+oai_function = cast(dict, format_tool_to_openai_function(function))
 elif callable(function):
-oai_function = cast(Dict, convert_python_function_to_openai_function(function))
+oai_function = cast(dict, convert_python_function_to_openai_function(function))
 else:
 raise ValueError(
 f"Unsupported function\n\n{function}\n\nFunctions must be passed in"
@@ -598,17 +594,17 @@ def _py_38_safe_origin(origin: type) -> type:
 )

 origin_map: dict[type, Any] = {
-dict: Dict,
+dict: dict,
-list: List,
+list: list,
-tuple: Tuple,
+tuple: tuple,
-set: Set,
+set: set,
 collections.abc.Iterable: typing.Iterable,
 collections.abc.Mapping: typing.Mapping,
 collections.abc.Sequence: typing.Sequence,
 collections.abc.MutableMapping: typing.MutableMapping,
 **origin_union_type_map,
 }
-return cast(Type, origin_map.get(origin, origin))
+return cast(type, origin_map.get(origin, origin))


 def _recursive_set_additional_properties_false(
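
The `origin_map` change in the hunk above works because `typing.get_origin` already reports the builtin class for both the old aliases and the PEP 585 forms, so on Python 3.9+ the map no longer needs to route back through the `typing` aliases before re-subscripting. For example (illustrative only, not output taken from the codebase):

```pycon
>>> from typing import List, get_origin
>>> get_origin(List[str])
<class 'list'>
>>> get_origin(list[str])
<class 'list'>
>>> get_origin(dict[str, int]) is dict
True
```
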
@@ -1,6 +1,7 @@
 import logging
 import re
-from typing import List, Optional, Sequence, Union
+from collections.abc import Sequence
+from typing import Optional, Union
 from urllib.parse import urljoin, urlparse

 logger = logging.getLogger(__name__)
@@ -33,7 +34,7 @@ DEFAULT_LINK_REGEX = (

 def find_all_links(
 raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
-) -> List[str]:
+) -> list[str]:
 """Extract all links from a raw HTML string.

 Args:
@@ -56,7 +57,7 @@ def extract_sub_links(
 prevent_outside: bool = True,
 exclude_prefixes: Sequence[str] = (),
 continue_on_failure: bool = False,
-) -> List[str]:
+) -> list[str]:
 """Extract all links from a raw HTML string and convert into absolute paths.

 Args:
@@ -1,6 +1,6 @@
 """Handle chained inputs."""

-from typing import Dict, List, Optional, TextIO
+from typing import Optional, TextIO

 _TEXT_COLOR_MAPPING = {
 "blue": "36;1",
@@ -12,8 +12,8 @@ _TEXT_COLOR_MAPPING = {


 def get_color_mapping(
-items: List[str], excluded_colors: Optional[List] = None
+items: list[str], excluded_colors: Optional[list] = None
-) -> Dict[str, str]:
+) -> dict[str, str]:
 """Get mapping for items to a support color.

 Args:
@@ -1,16 +1,11 @@
 from collections import deque
+from collections.abc import Generator, Iterable, Iterator
+from contextlib import AbstractContextManager
 from itertools import islice
 from typing import (
 Any,
-ContextManager,
-Deque,
-Generator,
 Generic,
-Iterable,
-Iterator,
-List,
 Optional,
-Tuple,
 TypeVar,
 Union,
 overload,
@@ -34,10 +29,10 @@ class NoLock:
 def tee_peer(
 iterator: Iterator[T],
 # the buffer specific to this peer
-buffer: Deque[T],
+buffer: deque[T],
 # the buffers of all peers, including our own
-peers: List[Deque[T]],
+peers: list[deque[T]],
-lock: ContextManager[Any],
+lock: AbstractContextManager[Any],
 ) -> Generator[T, None, None]:
 """An individual iterator of a :py:func:`~.tee`.

@@ -130,7 +125,7 @@ class Tee(Generic[T]):
 iterable: Iterator[T],
 n: int = 2,
 *,
-lock: Optional[ContextManager[Any]] = None,
+lock: Optional[AbstractContextManager[Any]] = None,
 ):
 """Create a new ``tee``.

@@ -141,7 +136,7 @@ class Tee(Generic[T]):
 Defaults to None.
 """
 self._iterator = iter(iterable)
-self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
+self._buffers: list[deque[T]] = [deque() for _ in range(n)]
 self._children = tuple(
 tee_peer(
 iterator=self._iterator,
@@ -159,11 +154,11 @@ class Tee(Generic[T]):
 def __getitem__(self, item: int) -> Iterator[T]: ...

 @overload
-def __getitem__(self, item: slice) -> Tuple[Iterator[T], ...]: ...
+def __getitem__(self, item: slice) -> tuple[Iterator[T], ...]: ...

 def __getitem__(
 self, item: Union[int, slice]
-) -> Union[Iterator[T], Tuple[Iterator[T], ...]]:
+) -> Union[Iterator[T], tuple[Iterator[T], ...]]:
 return self._children[item]

 def __iter__(self) -> Iterator[Iterator[T]]:
@@ -185,7 +180,7 @@ class Tee(Generic[T]):
 safetee = Tee


-def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T]]:
+def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[list[T]]:
 """Utility batching function.

 Args:
Some files were not shown because too many files have changed in this diff.