diff --git a/libs/core/langchain_core/_api/beta_decorator.py b/libs/core/langchain_core/_api/beta_decorator.py index 0fb726fa48d..8d7a6b31f01 100644 --- a/libs/core/langchain_core/_api/beta_decorator.py +++ b/libs/core/langchain_core/_api/beta_decorator.py @@ -14,8 +14,8 @@ import contextlib import functools import inspect import warnings -from collections.abc import Generator -from typing import Any, Callable, TypeVar, Union, cast +from collections.abc import Callable, Generator +from typing import Any, TypeVar, cast from langchain_core._api.internal import is_caller_internal @@ -27,7 +27,7 @@ class LangChainBetaWarning(DeprecationWarning): # PUBLIC API -T = TypeVar("T", bound=Union[Callable[..., Any], type]) +T = TypeVar("T", bound=Callable[..., Any] | type) def beta( diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index 1253351648f..55cc493f154 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -14,13 +14,11 @@ import contextlib import functools import inspect import warnings -from collections.abc import Generator +from collections.abc import Callable, Generator from typing import ( Any, - Callable, ParamSpec, TypeVar, - Union, cast, ) @@ -42,7 +40,7 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning): # Last Any should be FieldInfoV1 but this leads to circular imports -T = TypeVar("T", bound=Union[type, Callable[..., Any], Any]) +T = TypeVar("T", bound=type | Callable[..., Any] | Any) def _validate_deprecation_params( @@ -276,7 +274,7 @@ def deprecated( if not _obj_type: _obj_type = "attribute" wrapped = None - _name = _name or cast("Union[type, Callable]", obj.fget).__qualname__ + _name = _name or cast("type | Callable", obj.fget).__qualname__ old_doc = obj.__doc__ class _DeprecatedProperty(property): @@ -284,19 +282,17 @@ def deprecated( def __init__( self, - fget: Union[Callable[[Any], Any], None] = None, - fset: Union[Callable[[Any, Any], None], None] = None, - fdel: Union[Callable[[Any], None], None] = None, - doc: Union[str, None] = None, + fget: Callable[[Any], Any] | None = None, + fset: Callable[[Any, Any], None] | None = None, + fdel: Callable[[Any], None] | None = None, + doc: str | None = None, ) -> None: super().__init__(fget, fset, fdel, doc) self.__orig_fget = fget self.__orig_fset = fset self.__orig_fdel = fdel - def __get__( - self, instance: Any, owner: Union[type, None] = None - ) -> Any: + def __get__(self, instance: Any, owner: type | None = None) -> Any: if instance is not None or owner is not None: emit_warning() if self.fget is None: @@ -315,7 +311,7 @@ def deprecated( if self.fdel is not None: self.fdel(instance) - def __set_name__(self, owner: Union[type, None], set_name: str) -> None: + def __set_name__(self, owner: type | None, set_name: str) -> None: nonlocal _name if _name == "": _name = set_name @@ -330,7 +326,7 @@ def deprecated( ) else: - _name = _name or cast("Union[type, Callable]", obj).__qualname__ + _name = _name or cast("type | Callable", obj).__qualname__ if not _obj_type: # edge case: when a function is within another function # within a test, this will call it a "method" not a "function" diff --git a/libs/core/langchain_core/_api/path.py b/libs/core/langchain_core/_api/path.py index 0d70b5a23a3..5b597523eb9 100644 --- a/libs/core/langchain_core/_api/path.py +++ b/libs/core/langchain_core/_api/path.py @@ -1,6 +1,5 @@ import os from pathlib import Path -from typing import Optional, Union HERE = Path(__file__).parent @@ -9,9 +8,7 @@ PACKAGE_DIR = HERE.parent SEPARATOR = os.sep -def get_relative_path( - file: Union[Path, str], *, relative_to: Path = PACKAGE_DIR -) -> str: +def get_relative_path(file: Path | str, *, relative_to: Path = PACKAGE_DIR) -> str: """Get the path of the file as a relative path to the package directory. Args: @@ -27,9 +24,9 @@ def get_relative_path( def as_import_path( - file: Union[Path, str], + file: Path | str, *, - suffix: Optional[str] = None, + suffix: str | None = None, relative_to: Path = PACKAGE_DIR, ) -> str: """Path of the file as a LangChain import exclude langchain top namespace. diff --git a/libs/core/langchain_core/_import_utils.py b/libs/core/langchain_core/_import_utils.py index 6e0ce0bbee9..01c2e97f42b 100644 --- a/libs/core/langchain_core/_import_utils.py +++ b/libs/core/langchain_core/_import_utils.py @@ -1,11 +1,10 @@ from importlib import import_module -from typing import Union def import_attr( attr_name: str, - module_name: Union[str, None], - package: Union[str, None], + module_name: str | None, + package: str | None, ) -> object: """Import an attribute from a module located in a package. diff --git a/libs/core/langchain_core/agents.py b/libs/core/langchain_core/agents.py index a111e5c6d10..dcd64c0f601 100644 --- a/libs/core/langchain_core/agents.py +++ b/libs/core/langchain_core/agents.py @@ -29,7 +29,7 @@ from __future__ import annotations import json from collections.abc import Sequence -from typing import Any, Literal, Union +from typing import Any, Literal from langchain_core.load.serializable import Serializable from langchain_core.messages import ( @@ -49,7 +49,7 @@ class AgentAction(Serializable): tool: str """The name of the Tool to execute.""" - tool_input: Union[str, dict] + tool_input: str | dict """The input to pass in to the Tool.""" log: str """Additional information to log about the action. @@ -62,9 +62,7 @@ class AgentAction(Serializable): type: Literal["AgentAction"] = "AgentAction" # Override init to support instantiation by position for backward compat. - def __init__( - self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any - ): + def __init__(self, tool: str, tool_input: str | dict, log: str, **kwargs: Any): """Create an AgentAction. Args: diff --git a/libs/core/langchain_core/caches.py b/libs/core/langchain_core/caches.py index b285ba115aa..12bb1bd6fb2 100644 --- a/libs/core/langchain_core/caches.py +++ b/libs/core/langchain_core/caches.py @@ -25,7 +25,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any, Optional +from typing import Any from typing_extensions import override @@ -52,7 +52,7 @@ class BaseCache(ABC): """ @abstractmethod - def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + def lookup(self, prompt: str, llm_string: str) -> RETURN_VAL_TYPE | None: """Look up based on prompt and llm_string. A cache implementation is expected to generate a key from the 2-tuple @@ -97,7 +97,7 @@ class BaseCache(ABC): def clear(self, **kwargs: Any) -> None: """Clear cache that can take additional keyword arguments.""" - async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + async def alookup(self, prompt: str, llm_string: str) -> RETURN_VAL_TYPE | None: """Async look up based on prompt and llm_string. A cache implementation is expected to generate a key from the 2-tuple @@ -149,7 +149,7 @@ class BaseCache(ABC): class InMemoryCache(BaseCache): """Cache that stores things in memory.""" - def __init__(self, *, maxsize: Optional[int] = None) -> None: + def __init__(self, *, maxsize: int | None = None) -> None: """Initialize with empty cache. Args: @@ -167,7 +167,7 @@ class InMemoryCache(BaseCache): raise ValueError(msg) self._maxsize = maxsize - def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + def lookup(self, prompt: str, llm_string: str) -> RETURN_VAL_TYPE | None: """Look up based on prompt and llm_string. Args: @@ -201,7 +201,7 @@ class InMemoryCache(BaseCache): """Clear cache.""" self._cache = {} - async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + async def alookup(self, prompt: str, llm_string: str) -> RETURN_VAL_TYPE | None: """Async look up based on prompt and llm_string. Args: diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index 4b82a9dab84..6d3cf7b9629 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from typing_extensions import Self @@ -29,7 +29,7 @@ class RetrieverManagerMixin: error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run when Retriever errors. @@ -46,7 +46,7 @@ class RetrieverManagerMixin: documents: Sequence[Document], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run when Retriever ends running. @@ -66,9 +66,9 @@ class LLMManagerMixin: self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run on new output token. Only available when streaming is enabled. @@ -89,7 +89,7 @@ class LLMManagerMixin: response: LLMResult, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run when LLM ends running. @@ -106,7 +106,7 @@ class LLMManagerMixin: error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run when LLM errors. @@ -127,7 +127,7 @@ class ChainManagerMixin: outputs: dict[str, Any], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run when chain ends running. @@ -144,7 +144,7 @@ class ChainManagerMixin: error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run when chain errors. @@ -161,7 +161,7 @@ class ChainManagerMixin: action: AgentAction, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run on agent action. @@ -178,7 +178,7 @@ class ChainManagerMixin: finish: AgentFinish, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run on the agent end. @@ -199,7 +199,7 @@ class ToolManagerMixin: output: Any, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run when the tool ends running. @@ -216,7 +216,7 @@ class ToolManagerMixin: error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run when tool errors. @@ -238,9 +238,9 @@ class CallbackManagerMixin: prompts: list[str], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """Run when LLM starts running. @@ -266,9 +266,9 @@ class CallbackManagerMixin: messages: list[list[BaseMessage]], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """Run when a chat model starts running. @@ -297,9 +297,9 @@ class CallbackManagerMixin: query: str, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """Run when the Retriever starts running. @@ -320,9 +320,9 @@ class CallbackManagerMixin: inputs: dict[str, Any], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """Run when a chain starts running. @@ -343,10 +343,10 @@ class CallbackManagerMixin: input_str: str, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - inputs: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + inputs: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """Run when the tool starts running. @@ -371,7 +371,7 @@ class RunManagerMixin: text: str, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run on an arbitrary text. @@ -388,7 +388,7 @@ class RunManagerMixin: retry_state: RetryCallState, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run on a retry event. @@ -406,8 +406,8 @@ class RunManagerMixin: data: Any, *, run_id: UUID, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """Override to define a handler for a custom event. @@ -487,9 +487,9 @@ class AsyncCallbackHandler(BaseCallbackHandler): prompts: list[str], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Run when the model starts running. @@ -515,9 +515,9 @@ class AsyncCallbackHandler(BaseCallbackHandler): messages: list[list[BaseMessage]], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """Run when a chat model starts running. @@ -544,10 +544,10 @@ class AsyncCallbackHandler(BaseCallbackHandler): self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run on new output token. Only available when streaming is enabled. @@ -569,8 +569,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): response: LLMResult, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run when the model ends running. @@ -588,8 +588,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run when LLM errors. @@ -610,9 +610,9 @@ class AsyncCallbackHandler(BaseCallbackHandler): inputs: dict[str, Any], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Run when a chain starts running. @@ -632,8 +632,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): outputs: dict[str, Any], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run when a chain ends running. @@ -651,8 +651,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run when chain errors. @@ -671,10 +671,10 @@ class AsyncCallbackHandler(BaseCallbackHandler): input_str: str, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - inputs: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + inputs: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Run when the tool starts running. @@ -695,8 +695,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): output: Any, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run when the tool ends running. @@ -714,8 +714,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run when tool errors. @@ -733,8 +733,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): text: str, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run on an arbitrary text. @@ -752,7 +752,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): retry_state: RetryCallState, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: """Run on a retry event. @@ -769,8 +769,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): action: AgentAction, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run on agent action. @@ -788,8 +788,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): finish: AgentFinish, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run on the agent end. @@ -808,9 +808,9 @@ class AsyncCallbackHandler(BaseCallbackHandler): query: str, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Run on the retriever start. @@ -830,8 +830,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): documents: Sequence[Document], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run on the retriever end. @@ -849,8 +849,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run on retriever error. @@ -869,8 +869,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): data: Any, *, run_id: UUID, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Override to define a handler for custom events. @@ -895,13 +895,13 @@ class BaseCallbackManager(CallbackManagerMixin): def __init__( self, handlers: list[BaseCallbackHandler], - inheritable_handlers: Optional[list[BaseCallbackHandler]] = None, - parent_run_id: Optional[UUID] = None, + inheritable_handlers: list[BaseCallbackHandler] | None = None, + parent_run_id: UUID | None = None, *, - tags: Optional[list[str]] = None, - inheritable_tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - inheritable_metadata: Optional[dict[str, Any]] = None, + tags: list[str] | None = None, + inheritable_tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + inheritable_metadata: dict[str, Any] | None = None, ) -> None: """Initialize callback manager. @@ -921,7 +921,7 @@ class BaseCallbackManager(CallbackManagerMixin): self.inheritable_handlers: list[BaseCallbackHandler] = ( inheritable_handlers or [] ) - self.parent_run_id: Optional[UUID] = parent_run_id + self.parent_run_id: UUID | None = parent_run_id self.tags = tags or [] self.inheritable_tags = inheritable_tags or [] self.metadata = metadata or {} @@ -1115,4 +1115,4 @@ class BaseCallbackManager(CallbackManagerMixin): self.inheritable_metadata.pop(key, None) -Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]] +Callbacks = list[BaseCallbackHandler] | BaseCallbackManager | None diff --git a/libs/core/langchain_core/callbacks/file.py b/libs/core/langchain_core/callbacks/file.py index 7f1c8423039..bdc12f5f548 100644 --- a/libs/core/langchain_core/callbacks/file.py +++ b/libs/core/langchain_core/callbacks/file.py @@ -3,7 +3,7 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, TextIO, cast +from typing import TYPE_CHECKING, Any, TextIO, cast from typing_extensions import Self, override @@ -57,7 +57,7 @@ class FileCallbackHandler(BaseCallbackHandler): """ def __init__( - self, filename: str, mode: str = "a", color: Optional[str] = None + self, filename: str, mode: str = "a", color: str | None = None ) -> None: """Initialize the file callback handler. @@ -124,7 +124,7 @@ class FileCallbackHandler(BaseCallbackHandler): def _write( self, text: str, - color: Optional[str] = None, + color: str | None = None, end: str = "", ) -> None: """Write text to the file with deprecation warning if needed. @@ -190,7 +190,7 @@ class FileCallbackHandler(BaseCallbackHandler): @override def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any + self, action: AgentAction, color: str | None = None, **kwargs: Any ) -> Any: """Handle agent action by writing the action log. @@ -207,9 +207,9 @@ class FileCallbackHandler(BaseCallbackHandler): def on_tool_end( self, output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, + color: str | None = None, + observation_prefix: str | None = None, + llm_prefix: str | None = None, **kwargs: Any, ) -> None: """Handle tool end by writing the output with optional prefixes. @@ -231,7 +231,7 @@ class FileCallbackHandler(BaseCallbackHandler): @override def on_text( - self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Any + self, text: str, color: str | None = None, end: str = "", **kwargs: Any ) -> None: """Handle text output. @@ -247,7 +247,7 @@ class FileCallbackHandler(BaseCallbackHandler): @override def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any + self, finish: AgentFinish, color: str | None = None, **kwargs: Any ) -> None: """Handle agent finish by writing the finish log. diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index e9d40b94688..72d4e5a79cd 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -8,10 +8,11 @@ import functools import logging import uuid from abc import ABC, abstractmethod +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager, contextmanager from contextvars import copy_context -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast from uuid import UUID from langsmith.run_helpers import get_tracing_context @@ -62,14 +63,14 @@ def _get_debug() -> bool: @contextmanager def trace_as_chain_group( group_name: str, - callback_manager: Optional[CallbackManager] = None, + callback_manager: CallbackManager | None = None, *, - inputs: Optional[dict[str, Any]] = None, - project_name: Optional[str] = None, - example_id: Optional[Union[str, UUID]] = None, - run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + inputs: dict[str, Any] | None = None, + project_name: str | None = None, + example_id: str | UUID | None = None, + run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, ) -> Generator[CallbackManagerForChainGroup, None, None]: """Get a callback manager for a chain group in a context manager. @@ -146,14 +147,14 @@ def trace_as_chain_group( @asynccontextmanager async def atrace_as_chain_group( group_name: str, - callback_manager: Optional[AsyncCallbackManager] = None, + callback_manager: AsyncCallbackManager | None = None, *, - inputs: Optional[dict[str, Any]] = None, - project_name: Optional[str] = None, - example_id: Optional[Union[str, UUID]] = None, - run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + inputs: dict[str, Any] | None = None, + project_name: str | None = None, + example_id: str | UUID | None = None, + run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, ) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]: """Get an async callback manager for a chain group in a context manager. @@ -251,7 +252,7 @@ def shielded(func: Func) -> Func: def handle_event( handlers: list[BaseCallbackHandler], event_name: str, - ignore_condition_name: Optional[str], + ignore_condition_name: str | None, *args: Any, **kwargs: Any, ) -> None: @@ -272,7 +273,7 @@ def handle_event( coros: list[Coroutine[Any, Any, Any]] = [] try: - message_strings: Optional[list[str]] = None + message_strings: list[str] | None = None for handler in handlers: try: if ignore_condition_name is None or not getattr( @@ -366,7 +367,7 @@ def _run_coros(coros: list[Coroutine[Any, Any, Any]]) -> None: async def _ahandle_event_for_handler( handler: BaseCallbackHandler, event_name: str, - ignore_condition_name: Optional[str], + ignore_condition_name: str | None, *args: Any, **kwargs: Any, ) -> None: @@ -418,7 +419,7 @@ async def _ahandle_event_for_handler( async def ahandle_event( handlers: list[BaseCallbackHandler], event_name: str, - ignore_condition_name: Optional[str], + ignore_condition_name: str | None, *args: Any, **kwargs: Any, ) -> None: @@ -464,11 +465,11 @@ class BaseRunManager(RunManagerMixin): run_id: UUID, handlers: list[BaseCallbackHandler], inheritable_handlers: list[BaseCallbackHandler], - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - inheritable_tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - inheritable_metadata: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + inheritable_tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + inheritable_metadata: dict[str, Any] | None = None, ) -> None: """Initialize the run manager. @@ -572,7 +573,7 @@ class RunManager(BaseRunManager): class ParentRunManager(RunManager): """Sync Parent Run Manager.""" - def get_child(self, tag: Optional[str] = None) -> CallbackManager: + def get_child(self, tag: str | None = None) -> CallbackManager: """Get a child callback manager. Args: @@ -657,7 +658,7 @@ class AsyncRunManager(BaseRunManager, ABC): class AsyncParentRunManager(AsyncRunManager): """Async Parent Run Manager.""" - def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager: + def get_child(self, tag: str | None = None) -> AsyncCallbackManager: """Get a child callback manager. Args: @@ -684,7 +685,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: GenerationChunk | ChatGenerationChunk | None = None, **kwargs: Any, ) -> None: """Run when LLM generates a new token. @@ -783,7 +784,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: GenerationChunk | ChatGenerationChunk | None = None, **kwargs: Any, ) -> None: """Run when LLM generates a new token. @@ -865,7 +866,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): """Callback manager for chain run.""" - def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None: + def on_chain_end(self, outputs: dict[str, Any] | Any, **kwargs: Any) -> None: """Run when chain ends running. Args: @@ -973,9 +974,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): ) @shielded - async def on_chain_end( - self, outputs: Union[dict[str, Any], Any], **kwargs: Any - ) -> None: + async def on_chain_end(self, outputs: dict[str, Any] | Any, **kwargs: Any) -> None: """Run when a chain ends running. Args: @@ -1320,7 +1319,7 @@ class CallbackManager(BaseCallbackManager): self, serialized: dict[str, Any], prompts: list[str], - run_id: Optional[UUID] = None, + run_id: UUID | None = None, **kwargs: Any, ) -> list[CallbackManagerForLLMRun]: """Run when LLM starts running. @@ -1372,7 +1371,7 @@ class CallbackManager(BaseCallbackManager): self, serialized: dict[str, Any], messages: list[list[BaseMessage]], - run_id: Optional[UUID] = None, + run_id: UUID | None = None, **kwargs: Any, ) -> list[CallbackManagerForLLMRun]: """Run when chat model starts running. @@ -1425,9 +1424,9 @@ class CallbackManager(BaseCallbackManager): def on_chain_start( self, - serialized: Optional[dict[str, Any]], - inputs: Union[dict[str, Any], Any], - run_id: Optional[UUID] = None, + serialized: dict[str, Any] | None, + inputs: dict[str, Any] | Any, + run_id: UUID | None = None, **kwargs: Any, ) -> CallbackManagerForChainRun: """Run when chain starts running. @@ -1471,11 +1470,11 @@ class CallbackManager(BaseCallbackManager): @override def on_tool_start( self, - serialized: Optional[dict[str, Any]], + serialized: dict[str, Any] | None, input_str: str, - run_id: Optional[UUID] = None, - parent_run_id: Optional[UUID] = None, - inputs: Optional[dict[str, Any]] = None, + run_id: UUID | None = None, + parent_run_id: UUID | None = None, + inputs: dict[str, Any] | None = None, **kwargs: Any, ) -> CallbackManagerForToolRun: """Run when tool starts running. @@ -1528,10 +1527,10 @@ class CallbackManager(BaseCallbackManager): @override def on_retriever_start( self, - serialized: Optional[dict[str, Any]], + serialized: dict[str, Any] | None, query: str, - run_id: Optional[UUID] = None, - parent_run_id: Optional[UUID] = None, + run_id: UUID | None = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> CallbackManagerForRetrieverRun: """Run when the retriever starts running. @@ -1577,7 +1576,7 @@ class CallbackManager(BaseCallbackManager): self, name: str, data: Any, - run_id: Optional[UUID] = None, + run_id: UUID | None = None, **kwargs: Any, ) -> None: """Dispatch an adhoc event to the handlers (async version). @@ -1626,10 +1625,10 @@ class CallbackManager(BaseCallbackManager): inheritable_callbacks: Callbacks = None, local_callbacks: Callbacks = None, verbose: bool = False, # noqa: FBT001,FBT002 - inheritable_tags: Optional[list[str]] = None, - local_tags: Optional[list[str]] = None, - inheritable_metadata: Optional[dict[str, Any]] = None, - local_metadata: Optional[dict[str, Any]] = None, + inheritable_tags: list[str] | None = None, + local_tags: list[str] | None = None, + inheritable_metadata: dict[str, Any] | None = None, + local_metadata: dict[str, Any] | None = None, ) -> CallbackManager: """Configure the callback manager. @@ -1670,8 +1669,8 @@ class CallbackManagerForChainGroup(CallbackManager): def __init__( self, handlers: list[BaseCallbackHandler], - inheritable_handlers: Optional[list[BaseCallbackHandler]] = None, - parent_run_id: Optional[UUID] = None, + inheritable_handlers: list[BaseCallbackHandler] | None = None, + parent_run_id: UUID | None = None, *, parent_run_manager: CallbackManagerForChainRun, **kwargs: Any, @@ -1775,7 +1774,7 @@ class CallbackManagerForChainGroup(CallbackManager): manager.add_handler(handler, inherit=True) return manager - def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None: + def on_chain_end(self, outputs: dict[str, Any] | Any, **kwargs: Any) -> None: """Run when traced chain group ends. Args: @@ -1814,7 +1813,7 @@ class AsyncCallbackManager(BaseCallbackManager): self, serialized: dict[str, Any], prompts: list[str], - run_id: Optional[UUID] = None, + run_id: UUID | None = None, **kwargs: Any, ) -> list[AsyncCallbackManagerForLLMRun]: """Run when LLM starts running. @@ -1903,7 +1902,7 @@ class AsyncCallbackManager(BaseCallbackManager): self, serialized: dict[str, Any], messages: list[list[BaseMessage]], - run_id: Optional[UUID] = None, + run_id: UUID | None = None, **kwargs: Any, ) -> list[AsyncCallbackManagerForLLMRun]: """Async run when LLM starts running. @@ -1973,9 +1972,9 @@ class AsyncCallbackManager(BaseCallbackManager): async def on_chain_start( self, - serialized: Optional[dict[str, Any]], - inputs: Union[dict[str, Any], Any], - run_id: Optional[UUID] = None, + serialized: dict[str, Any] | None, + inputs: dict[str, Any] | Any, + run_id: UUID | None = None, **kwargs: Any, ) -> AsyncCallbackManagerForChainRun: """Async run when chain starts running. @@ -2020,10 +2019,10 @@ class AsyncCallbackManager(BaseCallbackManager): @override async def on_tool_start( self, - serialized: Optional[dict[str, Any]], + serialized: dict[str, Any] | None, input_str: str, - run_id: Optional[UUID] = None, - parent_run_id: Optional[UUID] = None, + run_id: UUID | None = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> AsyncCallbackManagerForToolRun: """Run when the tool starts running. @@ -2071,7 +2070,7 @@ class AsyncCallbackManager(BaseCallbackManager): self, name: str, data: Any, - run_id: Optional[UUID] = None, + run_id: UUID | None = None, **kwargs: Any, ) -> None: """Dispatch an adhoc event to the handlers (async version). @@ -2116,10 +2115,10 @@ class AsyncCallbackManager(BaseCallbackManager): @override async def on_retriever_start( self, - serialized: Optional[dict[str, Any]], + serialized: dict[str, Any] | None, query: str, - run_id: Optional[UUID] = None, - parent_run_id: Optional[UUID] = None, + run_id: UUID | None = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> AsyncCallbackManagerForRetrieverRun: """Run when the retriever starts running. @@ -2168,10 +2167,10 @@ class AsyncCallbackManager(BaseCallbackManager): inheritable_callbacks: Callbacks = None, local_callbacks: Callbacks = None, verbose: bool = False, # noqa: FBT001,FBT002 - inheritable_tags: Optional[list[str]] = None, - local_tags: Optional[list[str]] = None, - inheritable_metadata: Optional[dict[str, Any]] = None, - local_metadata: Optional[dict[str, Any]] = None, + inheritable_tags: list[str] | None = None, + local_tags: list[str] | None = None, + inheritable_metadata: dict[str, Any] | None = None, + local_metadata: dict[str, Any] | None = None, ) -> AsyncCallbackManager: """Configure the async callback manager. @@ -2211,8 +2210,8 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): def __init__( self, handlers: list[BaseCallbackHandler], - inheritable_handlers: Optional[list[BaseCallbackHandler]] = None, - parent_run_id: Optional[UUID] = None, + inheritable_handlers: list[BaseCallbackHandler] | None = None, + parent_run_id: UUID | None = None, *, parent_run_manager: AsyncCallbackManagerForChainRun, **kwargs: Any, @@ -2316,9 +2315,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): manager.add_handler(handler, inherit=True) return manager - async def on_chain_end( - self, outputs: Union[dict[str, Any], Any], **kwargs: Any - ) -> None: + async def on_chain_end(self, outputs: dict[str, Any] | Any, **kwargs: Any) -> None: """Run when traced chain group ends. Args: @@ -2350,10 +2347,10 @@ def _configure( callback_manager_cls: type[T], inheritable_callbacks: Callbacks = None, local_callbacks: Callbacks = None, - inheritable_tags: Optional[list[str]] = None, - local_tags: Optional[list[str]] = None, - inheritable_metadata: Optional[dict[str, Any]] = None, - local_metadata: Optional[dict[str, Any]] = None, + inheritable_tags: list[str] | None = None, + local_tags: list[str] | None = None, + inheritable_metadata: dict[str, Any] | None = None, + local_metadata: dict[str, Any] | None = None, *, verbose: bool = False, ) -> T: @@ -2383,7 +2380,7 @@ def _configure( tracing_context = get_tracing_context() tracing_metadata = tracing_context["metadata"] tracing_tags = tracing_context["tags"] - run_tree: Optional[Run] = tracing_context["parent"] + run_tree: Run | None = tracing_context["parent"] parent_run_id = None if run_tree is None else run_tree.id callback_manager = callback_manager_cls( handlers=[], @@ -2528,7 +2525,7 @@ def _configure( async def adispatch_custom_event( - name: str, data: Any, *, config: Optional[RunnableConfig] = None + name: str, data: Any, *, config: RunnableConfig | None = None ) -> None: """Dispatch an adhoc event to the handlers. @@ -2654,7 +2651,7 @@ async def adispatch_custom_event( def dispatch_custom_event( - name: str, data: Any, *, config: Optional[RunnableConfig] = None + name: str, data: Any, *, config: RunnableConfig | None = None ) -> None: """Dispatch an adhoc event. diff --git a/libs/core/langchain_core/callbacks/stdout.py b/libs/core/langchain_core/callbacks/stdout.py index f72ea6564f9..e240cd9f245 100644 --- a/libs/core/langchain_core/callbacks/stdout.py +++ b/libs/core/langchain_core/callbacks/stdout.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from typing_extensions import override @@ -16,7 +16,7 @@ if TYPE_CHECKING: class StdOutCallbackHandler(BaseCallbackHandler): """Callback Handler that prints to std out.""" - def __init__(self, color: Optional[str] = None) -> None: + def __init__(self, color: str | None = None) -> None: """Initialize callback handler. Args: @@ -55,7 +55,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): @override def on_agent_action( - self, action: AgentAction, color: Optional[str] = None, **kwargs: Any + self, action: AgentAction, color: str | None = None, **kwargs: Any ) -> Any: """Run on agent action. @@ -70,9 +70,9 @@ class StdOutCallbackHandler(BaseCallbackHandler): def on_tool_end( self, output: Any, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, + color: str | None = None, + observation_prefix: str | None = None, + llm_prefix: str | None = None, **kwargs: Any, ) -> None: """If not the final action, print out observation. @@ -96,7 +96,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): def on_text( self, text: str, - color: Optional[str] = None, + color: str | None = None, end: str = "", **kwargs: Any, ) -> None: @@ -112,7 +112,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): @override def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any + self, finish: AgentFinish, color: str | None = None, **kwargs: Any ) -> None: """Run on the agent end. diff --git a/libs/core/langchain_core/callbacks/usage.py b/libs/core/langchain_core/callbacks/usage.py index 56217c25a05..f498c6c53f1 100644 --- a/libs/core/langchain_core/callbacks/usage.py +++ b/libs/core/langchain_core/callbacks/usage.py @@ -4,7 +4,7 @@ import threading from collections.abc import Generator from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Optional +from typing import Any from typing_extensions import override @@ -134,7 +134,7 @@ def get_usage_metadata_callback( !!! version-added "Added in version 0.3.49" """ - usage_metadata_callback_var: ContextVar[Optional[UsageMetadataCallbackHandler]] = ( + usage_metadata_callback_var: ContextVar[UsageMetadataCallbackHandler | None] = ( ContextVar(name, default=None) ) register_configure_hook(usage_metadata_callback_var, inheritable=True) diff --git a/libs/core/langchain_core/document_loaders/base.py b/libs/core/langchain_core/document_loaders/base.py index deeb9569ad2..6f46b99b71c 100644 --- a/libs/core/langchain_core/document_loaders/base.py +++ b/libs/core/langchain_core/document_loaders/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from langchain_core.runnables import run_in_executor @@ -51,7 +51,7 @@ class BaseLoader(ABC): # noqa: B024 return [document async for document in self.alazy_load()] def load_and_split( - self, text_splitter: Optional[TextSplitter] = None + self, text_splitter: TextSplitter | None = None ) -> list[Document]: """Load Documents and split into chunks. Chunks are returned as Documents. diff --git a/libs/core/langchain_core/document_loaders/langsmith.py b/libs/core/langchain_core/document_loaders/langsmith.py index 1871bc18f55..0c55791dc7c 100644 --- a/libs/core/langchain_core/document_loaders/langsmith.py +++ b/libs/core/langchain_core/document_loaders/langsmith.py @@ -3,8 +3,8 @@ import datetime import json import uuid -from collections.abc import Iterator, Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterator, Sequence +from typing import Any from langsmith import Client as LangSmithClient from typing_extensions import override @@ -42,19 +42,19 @@ class LangSmithLoader(BaseLoader): def __init__( self, *, - dataset_id: Optional[Union[uuid.UUID, str]] = None, - dataset_name: Optional[str] = None, - example_ids: Optional[Sequence[Union[uuid.UUID, str]]] = None, - as_of: Optional[Union[datetime.datetime, str]] = None, - splits: Optional[Sequence[str]] = None, + dataset_id: uuid.UUID | str | None = None, + dataset_name: str | None = None, + example_ids: Sequence[uuid.UUID | str] | None = None, + as_of: datetime.datetime | str | None = None, + splits: Sequence[str] | None = None, inline_s3_urls: bool = True, offset: int = 0, - limit: Optional[int] = None, - metadata: Optional[dict] = None, - filter: Optional[str] = None, # noqa: A002 + limit: int | None = None, + metadata: dict | None = None, + filter: str | None = None, # noqa: A002 content_key: str = "", - format_content: Optional[Callable[..., str]] = None, - client: Optional[LangSmithClient] = None, + format_content: Callable[..., str] | None = None, + client: LangSmithClient | None = None, **client_kwargs: Any, ) -> None: """Create a LangSmith loader. @@ -129,7 +129,7 @@ class LangSmithLoader(BaseLoader): yield Document(content_str, metadata=metadata) -def _stringify(x: Union[str, dict]) -> str: +def _stringify(x: str | dict) -> str: if isinstance(x, str): return x try: diff --git a/libs/core/langchain_core/documents/base.py b/libs/core/langchain_core/documents/base.py index 1ad56770fa1..1a14218ca42 100644 --- a/libs/core/langchain_core/documents/base.py +++ b/libs/core/langchain_core/documents/base.py @@ -6,7 +6,7 @@ import contextlib import mimetypes from io import BufferedReader, BytesIO from pathlib import Path, PurePath -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Literal, cast from pydantic import ConfigDict, Field, model_validator @@ -15,7 +15,7 @@ from langchain_core.load.serializable import Serializable if TYPE_CHECKING: from collections.abc import Generator -PathLike = Union[str, PurePath] +PathLike = str | PurePath class BaseMedia(Serializable): @@ -33,7 +33,7 @@ class BaseMedia(Serializable): # The ID field is optional at the moment. # It will likely become required in a future major release after # it has been adopted by enough vectorstore implementations. - id: Optional[str] = Field(default=None, coerce_numbers_to_str=True) + id: str | None = Field(default=None, coerce_numbers_to_str=True) """An optional identifier for the document. Ideally this should be unique across the document collection and formatted @@ -105,16 +105,16 @@ class Blob(BaseMedia): """ - data: Union[bytes, str, None] = None + data: bytes | str | None = None """Raw data associated with the blob.""" - mimetype: Optional[str] = None + mimetype: str | None = None """MimeType not to be confused with a file extension.""" encoding: str = "utf-8" """Encoding to use if decoding the bytes into a string. Use utf-8 as default encoding, if decoding to string. """ - path: Optional[PathLike] = None + path: PathLike | None = None """Location where the original content was found.""" model_config = ConfigDict( @@ -123,7 +123,7 @@ class Blob(BaseMedia): ) @property - def source(self) -> Optional[str]: + def source(self) -> str | None: """The source location of the blob as string if known otherwise none. If a path is associated with the blob, it will default to the path location. @@ -132,7 +132,7 @@ class Blob(BaseMedia): case that value will be used instead. """ if self.metadata and "source" in self.metadata: - return cast("Optional[str]", self.metadata["source"]) + return cast("str | None", self.metadata["source"]) return str(self.path) if self.path else None @model_validator(mode="before") @@ -181,7 +181,7 @@ class Blob(BaseMedia): raise ValueError(msg) @contextlib.contextmanager - def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]: + def as_bytes_io(self) -> Generator[BytesIO | BufferedReader, None, None]: """Read data as a byte stream. Raises: @@ -205,9 +205,9 @@ class Blob(BaseMedia): path: PathLike, *, encoding: str = "utf-8", - mime_type: Optional[str] = None, + mime_type: str | None = None, guess_type: bool = True, - metadata: Optional[dict] = None, + metadata: dict | None = None, ) -> Blob: """Load the blob from a path like object. @@ -239,12 +239,12 @@ class Blob(BaseMedia): @classmethod def from_data( cls, - data: Union[str, bytes], + data: str | bytes, *, encoding: str = "utf-8", - mime_type: Optional[str] = None, - path: Optional[str] = None, - metadata: Optional[dict] = None, + mime_type: str | None = None, + path: str | None = None, + metadata: dict | None = None, ) -> Blob: """Initialize the blob from in-memory data. diff --git a/libs/core/langchain_core/documents/compressor.py b/libs/core/langchain_core/documents/compressor.py index 84c58b87018..b18728eb9d3 100644 --- a/libs/core/langchain_core/documents/compressor.py +++ b/libs/core/langchain_core/documents/compressor.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pydantic import BaseModel @@ -38,7 +38,7 @@ class BaseDocumentCompressor(BaseModel, ABC): self, documents: Sequence[Document], query: str, - callbacks: Optional[Callbacks] = None, + callbacks: Callbacks | None = None, ) -> Sequence[Document]: """Compress retrieved documents given the query context. @@ -56,7 +56,7 @@ class BaseDocumentCompressor(BaseModel, ABC): self, documents: Sequence[Document], query: str, - callbacks: Optional[Callbacks] = None, + callbacks: Callbacks | None = None, ) -> Sequence[Document]: """Async compress retrieved documents given the query context. diff --git a/libs/core/langchain_core/example_selectors/length_based.py b/libs/core/langchain_core/example_selectors/length_based.py index ec9566d75ac..296db6c1c60 100644 --- a/libs/core/langchain_core/example_selectors/length_based.py +++ b/libs/core/langchain_core/example_selectors/length_based.py @@ -1,7 +1,7 @@ """Select examples based on length.""" import re -from typing import Callable +from collections.abc import Callable from pydantic import BaseModel, Field, model_validator from typing_extensions import Self diff --git a/libs/core/langchain_core/example_selectors/semantic_similarity.py b/libs/core/langchain_core/example_selectors/semantic_similarity.py index b0362728aaa..57c3ece590a 100644 --- a/libs/core/langchain_core/example_selectors/semantic_similarity.py +++ b/libs/core/langchain_core/example_selectors/semantic_similarity.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict @@ -35,12 +35,12 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC): """VectorStore that contains information about examples.""" k: int = 4 """Number of examples to select.""" - example_keys: Optional[list[str]] = None + example_keys: list[str] | None = None """Optional keys to filter examples to.""" - input_keys: Optional[list[str]] = None + input_keys: list[str] | None = None """Optional keys to filter input to. If provided, the search is based on the input variables instead of all variables.""" - vectorstore_kwargs: Optional[dict[str, Any]] = None + vectorstore_kwargs: dict[str, Any] | None = None """Extra arguments passed to similarity_search function of the vectorstore.""" model_config = ConfigDict( @@ -49,9 +49,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC): ) @staticmethod - def _example_to_text( - example: dict[str, str], input_keys: Optional[list[str]] - ) -> str: + def _example_to_text(example: dict[str, str], input_keys: list[str] | None) -> str: if input_keys: return " ".join(sorted_values({key: example[key] for key in input_keys})) return " ".join(sorted_values(example)) @@ -142,10 +140,10 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector): embeddings: Embeddings, vectorstore_cls: type[VectorStore], k: int = 4, - input_keys: Optional[list[str]] = None, + input_keys: list[str] | None = None, *, - example_keys: Optional[list[str]] = None, - vectorstore_kwargs: Optional[dict] = None, + example_keys: list[str] | None = None, + vectorstore_kwargs: dict | None = None, **vectorstore_cls_kwargs: Any, ) -> SemanticSimilarityExampleSelector: """Create k-shot example selector using example list and embeddings. @@ -186,10 +184,10 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector): embeddings: Embeddings, vectorstore_cls: type[VectorStore], k: int = 4, - input_keys: Optional[list[str]] = None, + input_keys: list[str] | None = None, *, - example_keys: Optional[list[str]] = None, - vectorstore_kwargs: Optional[dict] = None, + example_keys: list[str] | None = None, + vectorstore_kwargs: dict | None = None, **vectorstore_cls_kwargs: Any, ) -> SemanticSimilarityExampleSelector: """Async create k-shot example selector using example list and embeddings. @@ -273,10 +271,10 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector): embeddings: Embeddings, vectorstore_cls: type[VectorStore], k: int = 4, - input_keys: Optional[list[str]] = None, + input_keys: list[str] | None = None, fetch_k: int = 20, - example_keys: Optional[list[str]] = None, - vectorstore_kwargs: Optional[dict] = None, + example_keys: list[str] | None = None, + vectorstore_kwargs: dict | None = None, **vectorstore_cls_kwargs: Any, ) -> MaxMarginalRelevanceExampleSelector: """Create k-shot example selector using example list and embeddings. @@ -321,10 +319,10 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector): vectorstore_cls: type[VectorStore], *, k: int = 4, - input_keys: Optional[list[str]] = None, + input_keys: list[str] | None = None, fetch_k: int = 20, - example_keys: Optional[list[str]] = None, - vectorstore_kwargs: Optional[dict] = None, + example_keys: list[str] | None = None, + vectorstore_kwargs: dict | None = None, **vectorstore_cls_kwargs: Any, ) -> MaxMarginalRelevanceExampleSelector: """Create k-shot example selector using example list and embeddings. diff --git a/libs/core/langchain_core/exceptions.py b/libs/core/langchain_core/exceptions.py index 6aa8d89bc1f..6816a395dc7 100644 --- a/libs/core/langchain_core/exceptions.py +++ b/libs/core/langchain_core/exceptions.py @@ -1,7 +1,7 @@ """Custom **exceptions** for LangChain.""" from enum import Enum -from typing import Any, Optional +from typing import Any class LangChainException(Exception): # noqa: N818 @@ -24,8 +24,8 @@ class OutputParserException(ValueError, LangChainException): # noqa: N818 def __init__( self, error: Any, - observation: Optional[str] = None, - llm_output: Optional[str] = None, + observation: str | None = None, + llm_output: str | None = None, send_to_llm: bool = False, # noqa: FBT001,FBT002 ): """Create an OutputParserException. diff --git a/libs/core/langchain_core/indexing/api.py b/libs/core/langchain_core/indexing/api.py index 7e866107ae0..2b5fe719ae6 100644 --- a/libs/core/langchain_core/indexing/api.py +++ b/libs/core/langchain_core/indexing/api.py @@ -6,16 +6,20 @@ import hashlib import json import uuid import warnings -from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator, Sequence +from collections.abc import ( + AsyncIterable, + AsyncIterator, + Callable, + Iterable, + Iterator, + Sequence, +) from itertools import islice from typing import ( Any, - Callable, Literal, - Optional, TypedDict, TypeVar, - Union, cast, ) @@ -107,8 +111,8 @@ async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T def _get_source_id_assigner( - source_id_key: Union[str, Callable[[Document], str], None], -) -> Callable[[Document], Union[str, None]]: + source_id_key: str | Callable[[Document], str] | None, +) -> Callable[[Document], str | None]: """Get the source id from the document.""" if source_id_key is None: return lambda _doc: None @@ -162,9 +166,8 @@ def _calculate_hash( def _get_document_with_hash( document: Document, *, - key_encoder: Union[ - Callable[[Document], str], Literal["sha1", "sha256", "sha512", "blake2b"] - ], + key_encoder: Callable[[Document], str] + | Literal["sha1", "sha256", "sha512", "blake2b"], ) -> Document: """Calculate a hash of the document, and assign it to the uid. @@ -233,7 +236,7 @@ class _HashedDocument: def _delete( - vector_store: Union[VectorStore, DocumentIndex], + vector_store: VectorStore | DocumentIndex, ids: list[str], ) -> None: if isinstance(vector_store, VectorStore): @@ -271,19 +274,18 @@ class IndexingResult(TypedDict): def index( - docs_source: Union[BaseLoader, Iterable[Document]], + docs_source: BaseLoader | Iterable[Document], record_manager: RecordManager, - vector_store: Union[VectorStore, DocumentIndex], + vector_store: VectorStore | DocumentIndex, *, batch_size: int = 100, - cleanup: Optional[Literal["incremental", "full", "scoped_full"]] = None, - source_id_key: Union[str, Callable[[Document], str], None] = None, + cleanup: Literal["incremental", "full", "scoped_full"] | None = None, + source_id_key: str | Callable[[Document], str] | None = None, cleanup_batch_size: int = 1_000, force_update: bool = False, - key_encoder: Union[ - Literal["sha1", "sha256", "sha512", "blake2b"], Callable[[Document], str] - ] = "sha1", - upsert_kwargs: Optional[dict[str, Any]] = None, + key_encoder: Literal["sha1", "sha256", "sha512", "blake2b"] + | Callable[[Document], str] = "sha1", + upsert_kwargs: dict[str, Any] | None = None, ) -> IndexingResult: """Index data from the loader into the vector store. @@ -462,13 +464,13 @@ def index( # Count documents removed by within-batch deduplication num_skipped += original_batch_size - len(hashed_docs) - source_ids: Sequence[Optional[str]] = [ + source_ids: Sequence[str | None] = [ source_id_assigner(hashed_doc) for hashed_doc in hashed_docs ] if cleanup in {"incremental", "scoped_full"}: # source ids are required. - for source_id, hashed_doc in zip(source_ids, hashed_docs): + for source_id, hashed_doc in zip(source_ids, hashed_docs, strict=False): if source_id is None: msg = ( f"Source ids are required when cleanup mode is " @@ -492,7 +494,7 @@ def index( docs_to_index = [] uids_to_refresh = [] seen_docs: set[str] = set() - for hashed_doc, doc_exists in zip(hashed_docs, exists_batch): + for hashed_doc, doc_exists in zip(hashed_docs, exists_batch, strict=False): hashed_id = cast("str", hashed_doc.id) if doc_exists: if force_update: @@ -563,7 +565,7 @@ def index( if cleanup == "full" or ( cleanup == "scoped_full" and scoped_full_cleanup_source_ids ): - delete_group_ids: Optional[Sequence[str]] = None + delete_group_ids: Sequence[str] | None = None if cleanup == "scoped_full": delete_group_ids = list(scoped_full_cleanup_source_ids) while uids_to_delete := record_manager.list_keys( @@ -591,7 +593,7 @@ async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]: async def _adelete( - vector_store: Union[VectorStore, DocumentIndex], + vector_store: VectorStore | DocumentIndex, ids: list[str], ) -> None: if isinstance(vector_store, VectorStore): @@ -613,19 +615,18 @@ async def _adelete( async def aindex( - docs_source: Union[BaseLoader, Iterable[Document], AsyncIterator[Document]], + docs_source: BaseLoader | Iterable[Document] | AsyncIterator[Document], record_manager: RecordManager, - vector_store: Union[VectorStore, DocumentIndex], + vector_store: VectorStore | DocumentIndex, *, batch_size: int = 100, - cleanup: Optional[Literal["incremental", "full", "scoped_full"]] = None, - source_id_key: Union[str, Callable[[Document], str], None] = None, + cleanup: Literal["incremental", "full", "scoped_full"] | None = None, + source_id_key: str | Callable[[Document], str] | None = None, cleanup_batch_size: int = 1_000, force_update: bool = False, - key_encoder: Union[ - Literal["sha1", "sha256", "sha512", "blake2b"], Callable[[Document], str] - ] = "sha1", - upsert_kwargs: Optional[dict[str, Any]] = None, + key_encoder: Literal["sha1", "sha256", "sha512", "blake2b"] + | Callable[[Document], str] = "sha1", + upsert_kwargs: dict[str, Any] | None = None, ) -> IndexingResult: """Async index data from the loader into the vector store. @@ -815,13 +816,13 @@ async def aindex( # Count documents removed by within-batch deduplication num_skipped += original_batch_size - len(hashed_docs) - source_ids: Sequence[Optional[str]] = [ + source_ids: Sequence[str | None] = [ source_id_assigner(doc) for doc in hashed_docs ] if cleanup in {"incremental", "scoped_full"}: # If the cleanup mode is incremental, source ids are required. - for source_id, hashed_doc in zip(source_ids, hashed_docs): + for source_id, hashed_doc in zip(source_ids, hashed_docs, strict=False): if source_id is None: msg = ( f"Source ids are required when cleanup mode is " @@ -845,7 +846,7 @@ async def aindex( docs_to_index: list[Document] = [] uids_to_refresh = [] seen_docs: set[str] = set() - for hashed_doc, doc_exists in zip(hashed_docs, exists_batch): + for hashed_doc, doc_exists in zip(hashed_docs, exists_batch, strict=False): hashed_id = cast("str", hashed_doc.id) if doc_exists: if force_update: @@ -917,7 +918,7 @@ async def aindex( if cleanup == "full" or ( cleanup == "scoped_full" and scoped_full_cleanup_source_ids ): - delete_group_ids: Optional[Sequence[str]] = None + delete_group_ids: Sequence[str] | None = None if cleanup == "scoped_full": delete_group_ids = list(scoped_full_cleanup_source_ids) while uids_to_delete := await record_manager.alist_keys( diff --git a/libs/core/langchain_core/indexing/base.py b/libs/core/langchain_core/indexing/base.py index de8156a6f41..8fd56170e5d 100644 --- a/libs/core/langchain_core/indexing/base.py +++ b/libs/core/langchain_core/indexing/base.py @@ -5,7 +5,7 @@ from __future__ import annotations import abc import time from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict from typing_extensions import override @@ -100,8 +100,8 @@ class RecordManager(ABC): self, keys: Sequence[str], *, - group_ids: Optional[Sequence[Optional[str]]] = None, - time_at_least: Optional[float] = None, + group_ids: Sequence[str | None] | None = None, + time_at_least: float | None = None, ) -> None: """Upsert records into the database. @@ -128,8 +128,8 @@ class RecordManager(ABC): self, keys: Sequence[str], *, - group_ids: Optional[Sequence[Optional[str]]] = None, - time_at_least: Optional[float] = None, + group_ids: Sequence[str | None] | None = None, + time_at_least: float | None = None, ) -> None: """Asynchronously upsert records into the database. @@ -177,10 +177,10 @@ class RecordManager(ABC): def list_keys( self, *, - before: Optional[float] = None, - after: Optional[float] = None, - group_ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, + before: float | None = None, + after: float | None = None, + group_ids: Sequence[str] | None = None, + limit: int | None = None, ) -> list[str]: """List records in the database based on the provided filters. @@ -198,10 +198,10 @@ class RecordManager(ABC): async def alist_keys( self, *, - before: Optional[float] = None, - after: Optional[float] = None, - group_ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, + before: float | None = None, + after: float | None = None, + group_ids: Sequence[str] | None = None, + limit: int | None = None, ) -> list[str]: """Asynchronously list records in the database based on the provided filters. @@ -233,7 +233,7 @@ class RecordManager(ABC): class _Record(TypedDict): - group_id: Optional[str] + group_id: str | None updated_at: float @@ -270,8 +270,8 @@ class InMemoryRecordManager(RecordManager): self, keys: Sequence[str], *, - group_ids: Optional[Sequence[Optional[str]]] = None, - time_at_least: Optional[float] = None, + group_ids: Sequence[str | None] | None = None, + time_at_least: float | None = None, ) -> None: """Upsert records into the database. @@ -307,8 +307,8 @@ class InMemoryRecordManager(RecordManager): self, keys: Sequence[str], *, - group_ids: Optional[Sequence[Optional[str]]] = None, - time_at_least: Optional[float] = None, + group_ids: Sequence[str | None] | None = None, + time_at_least: float | None = None, ) -> None: """Async upsert records into the database. @@ -352,10 +352,10 @@ class InMemoryRecordManager(RecordManager): def list_keys( self, *, - before: Optional[float] = None, - after: Optional[float] = None, - group_ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, + before: float | None = None, + after: float | None = None, + group_ids: Sequence[str] | None = None, + limit: int | None = None, ) -> list[str]: """List records in the database based on the provided filters. @@ -388,10 +388,10 @@ class InMemoryRecordManager(RecordManager): async def alist_keys( self, *, - before: Optional[float] = None, - after: Optional[float] = None, - group_ids: Optional[Sequence[str]] = None, - limit: Optional[int] = None, + before: float | None = None, + after: float | None = None, + group_ids: Sequence[str] | None = None, + limit: int | None = None, ) -> list[str]: """Async list records in the database based on the provided filters. @@ -564,7 +564,7 @@ class DocumentIndex(BaseRetriever): ) @abc.abstractmethod - def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse: + def delete(self, ids: list[str] | None = None, **kwargs: Any) -> DeleteResponse: """Delete by IDs or other criteria. Calling delete without any input parameters should raise a ValueError! @@ -581,7 +581,7 @@ class DocumentIndex(BaseRetriever): """ async def adelete( - self, ids: Optional[list[str]] = None, **kwargs: Any + self, ids: list[str] | None = None, **kwargs: Any ) -> DeleteResponse: """Delete by IDs or other criteria. Async variant. diff --git a/libs/core/langchain_core/indexing/in_memory.py b/libs/core/langchain_core/indexing/in_memory.py index 63585968b98..1ef40f9947e 100644 --- a/libs/core/langchain_core/indexing/in_memory.py +++ b/libs/core/langchain_core/indexing/in_memory.py @@ -3,7 +3,7 @@ import operator import uuid from collections.abc import Sequence -from typing import Any, Optional, cast +from typing import Any, cast from pydantic import Field from typing_extensions import override @@ -60,7 +60,7 @@ class InMemoryDocumentIndex(DocumentIndex): return UpsertResponse(succeeded=ok_ids, failed=[]) @override - def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse: + def delete(self, ids: list[str] | None = None, **kwargs: Any) -> DeleteResponse: """Delete by IDs. Args: diff --git a/libs/core/langchain_core/language_models/_utils.py b/libs/core/langchain_core/language_models/_utils.py index a6fc1ff8c52..d7e7ebffc4b 100644 --- a/libs/core/langchain_core/language_models/_utils.py +++ b/libs/core/langchain_core/language_models/_utils.py @@ -3,10 +3,8 @@ from collections.abc import Sequence from typing import ( TYPE_CHECKING, Literal, - Optional, TypedDict, TypeVar, - Union, ) if TYPE_CHECKING: @@ -17,7 +15,7 @@ from langchain_core.messages.content import ( def is_openai_data_block( - block: dict, filter_: Union[Literal["image", "audio", "file"], None] = None + block: dict, filter_: Literal["image", "audio", "file"] | None = None ) -> bool: """Check whether a block contains multimodal data in OpenAI Chat Completions format. @@ -88,7 +86,7 @@ class ParsedDataUri(TypedDict): mime_type: str -def _parse_data_uri(uri: str) -> Optional[ParsedDataUri]: +def _parse_data_uri(uri: str) -> ParsedDataUri | None: """Parse a data URI into its components. If parsing fails, return None. If either MIME type or data is missing, return None. @@ -304,7 +302,7 @@ def _ensure_message_copy(message: T, formatted_message: T) -> T: def _update_content_block( - formatted_message: "BaseMessage", idx: int, new_block: Union[ContentBlock, dict] + formatted_message: "BaseMessage", idx: int, new_block: ContentBlock | dict ) -> None: """Update a content block at the given index, handling type issues.""" # Type ignore needed because: diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index 2321af2f05d..460716c1a43 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -4,17 +4,14 @@ from __future__ import annotations import warnings from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from functools import cache from typing import ( TYPE_CHECKING, Any, - Callable, Literal, - Optional, TypeAlias, TypeVar, - Union, ) from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -57,11 +54,11 @@ class LangSmithParams(TypedDict, total=False): """Name of the model.""" ls_model_type: Literal["chat", "llm"] """Type of the model. Should be 'chat' or 'llm'.""" - ls_temperature: Optional[float] + ls_temperature: float | None """Temperature for generation.""" - ls_max_tokens: Optional[int] + ls_max_tokens: int | None """Max tokens for generation.""" - ls_stop: Optional[list[str]] + ls_stop: list[str] | None """Stop words for generation.""" @@ -98,8 +95,8 @@ def _get_token_ids_default_method(text: str) -> list[int]: return tokenizer.encode(text) -LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]] -LanguageModelOutput = Union[BaseMessage, str] +LanguageModelInput = PromptValue | str | Sequence[MessageLikeRepresentation] +LanguageModelOutput = BaseMessage | str LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput] LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", AIMessage, str) @@ -117,7 +114,7 @@ class BaseLanguageModel( """ - cache: Union[BaseCache, bool, None] = Field(default=None, exclude=True) + cache: BaseCache | bool | None = Field(default=None, exclude=True) """Whether to cache the response. * If true, will use the global cache. @@ -132,11 +129,11 @@ class BaseLanguageModel( """Whether to print out response text.""" callbacks: Callbacks = Field(default=None, exclude=True) """Callbacks to add to the run trace.""" - tags: Optional[list[str]] = Field(default=None, exclude=True) + tags: list[str] | None = Field(default=None, exclude=True) """Tags to add to the run trace.""" - metadata: Optional[dict[str, Any]] = Field(default=None, exclude=True) + metadata: dict[str, Any] | None = Field(default=None, exclude=True) """Metadata to add to the run trace.""" - custom_get_token_ids: Optional[Callable[[str], list[int]]] = Field( + custom_get_token_ids: Callable[[str], list[int]] | None = Field( default=None, exclude=True ) """Optional encoder to use for counting tokens.""" @@ -146,7 +143,7 @@ class BaseLanguageModel( ) @field_validator("verbose", mode="before") - def set_verbose(cls, verbose: Optional[bool]) -> bool: # noqa: FBT001 + def set_verbose(cls, verbose: bool | None) -> bool: # noqa: FBT001 """If verbose is None, set it. This allows users to pass in None as verbose to access the global setting. @@ -169,17 +166,13 @@ class BaseLanguageModel( # This is a version of LanguageModelInput which replaces the abstract # base class BaseMessage with a union of its subclasses, which makes # for a much better schema. - return Union[ - str, - Union[StringPromptValue, ChatPromptValueConcrete], - list[AnyMessage], - ] + return str | StringPromptValue | ChatPromptValueConcrete | list[AnyMessage] @abstractmethod def generate_prompt( self, prompts: list[PromptValue], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, callbacks: Callbacks = None, **kwargs: Any, ) -> LLMResult: @@ -216,7 +209,7 @@ class BaseLanguageModel( async def agenerate_prompt( self, prompts: list[PromptValue], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, callbacks: Callbacks = None, **kwargs: Any, ) -> LLMResult: @@ -250,8 +243,8 @@ class BaseLanguageModel( """ def with_structured_output( - self, schema: Union[dict, type], **kwargs: Any - ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: + self, schema: dict | type, **kwargs: Any + ) -> Runnable[LanguageModelInput, dict | BaseModel]: """Not implemented on this class.""" # Implement this on child class if there is a way of steering the model to # generate responses that match a given schema. @@ -294,7 +287,7 @@ class BaseLanguageModel( def get_num_tokens_from_messages( self, messages: list[BaseMessage], - tools: Optional[Sequence] = None, + tools: Sequence | None = None, ) -> int: """Get the number of tokens in the messages. diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 070d71f8482..a4acd7683d7 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -7,10 +7,10 @@ import inspect import json import typing from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterator, Sequence +from collections.abc import AsyncIterator, Callable, Iterator, Sequence from functools import cached_property from operator import itemgetter -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Literal, cast from pydantic import BaseModel, ConfigDict, Field from typing_extensions import override @@ -219,7 +219,7 @@ async def agenerate_from_stream( return await run_in_executor(None, generate_from_stream, iter(chunks)) -def _format_ls_structured_output(ls_structured_output_format: Optional[dict]) -> dict: +def _format_ls_structured_output(ls_structured_output_format: dict | None) -> dict: if ls_structured_output_format: try: ls_structured_output_format_dict = { @@ -315,10 +315,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): """ # noqa: E501 - rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True) + rate_limiter: BaseRateLimiter | None = Field(default=None, exclude=True) "An optional rate limiter to use for limiting the number of requests." - disable_streaming: Union[bool, Literal["tool_calling"]] = False + disable_streaming: bool | Literal["tool_calling"] = False """Whether to disable streaming for this model. If streaming is bypassed, then ``stream()``/``astream()``/``astream_events()`` will @@ -337,7 +337,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): """ - output_version: Optional[str] = Field( + output_version: str | None = Field( default_factory=from_env("LC_OUTPUT_VERSION", default=None) ) """Version of ``AIMessage`` output format to store in message content. @@ -392,9 +392,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): def invoke( self, input: LanguageModelInput, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> AIMessage: config = ensure_config(config) @@ -419,9 +419,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): async def ainvoke( self, input: LanguageModelInput, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> AIMessage: config = ensure_config(config) @@ -443,9 +443,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): self, *, async_api: bool, - run_manager: Optional[ - Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun] - ] = None, + run_manager: CallbackManagerForLLMRun + | AsyncCallbackManagerForLLMRun + | None = None, **kwargs: Any, ) -> bool: """Determine if a given model call should hit the streaming API.""" @@ -478,9 +478,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): def stream( self, input: LanguageModelInput, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> Iterator[AIMessageChunk]: if not self._should_stream(async_api=False, **{**kwargs, "stream": True}): @@ -556,7 +556,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): and isinstance(chunk.message, AIMessageChunk) and not chunk.message.chunk_position ): - empty_content: Union[str, list] = ( + empty_content: str | list = ( "" if isinstance(chunk.message.content, str) else [] ) msg_chunk = AIMessageChunk( @@ -594,9 +594,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): async def astream( self, input: LanguageModelInput, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> AsyncIterator[AIMessageChunk]: if not self._should_stream(async_api=True, **{**kwargs, "stream": True}): @@ -677,7 +677,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): and isinstance(chunk.message, AIMessageChunk) and not chunk.message.chunk_position ): - empty_content: Union[str, list] = ( + empty_content: str | list = ( "" if isinstance(chunk.message.content, str) else [] ) msg_chunk = AIMessageChunk( @@ -712,7 +712,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): # --- Custom methods --- - def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: # noqa: ARG002 + def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict: # noqa: ARG002 return {} def _convert_cached_generations(self, cache_val: list) -> list[ChatGeneration]: @@ -756,7 +756,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): def _get_invocation_params( self, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> dict: params = self.dict() @@ -765,7 +765,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): def _get_ls_params( self, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> LangSmithParams: """Get standard params for tracing.""" @@ -803,7 +803,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): return ls_params - def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str: + def _get_llm_string(self, stop: list[str] | None = None, **kwargs: Any) -> str: if self.is_lc_serializable(): params = {**kwargs, "stop": stop} param_string = str(sorted(params.items())) @@ -820,13 +820,13 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): def generate( self, messages: list[list[BaseMessage]], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, callbacks: Callbacks = None, *, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - run_name: Optional[str] = None, - run_id: Optional[uuid.UUID] = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + run_name: str | None = None, + run_id: uuid.UUID | None = None, **kwargs: Any, ) -> LLMResult: """Pass a sequence of prompts to the model and return model generations. @@ -927,7 +927,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): output = LLMResult(generations=generations, llm_output=llm_output) if run_managers: run_infos = [] - for manager, flattened_output in zip(run_managers, flattened_outputs): + for manager, flattened_output in zip( + run_managers, flattened_outputs, strict=False + ): manager.on_llm_end(flattened_output) run_infos.append(RunInfo(run_id=manager.run_id)) output.run = run_infos @@ -936,13 +938,13 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): async def agenerate( self, messages: list[list[BaseMessage]], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, callbacks: Callbacks = None, *, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - run_name: Optional[str] = None, - run_id: Optional[uuid.UUID] = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + run_name: str | None = None, + run_id: uuid.UUID | None = None, **kwargs: Any, ) -> LLMResult: """Asynchronously pass a sequence of prompts to a model and return generations. @@ -1049,7 +1051,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): llm_output=res.llm_output, # type: ignore[union-attr] ) ) - for run_manager, res in zip(run_managers, results) + for run_manager, res in zip(run_managers, results, strict=False) if not isinstance(res, Exception) ] ) @@ -1065,7 +1067,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): *[ run_manager.on_llm_end(flattened_output) for run_manager, flattened_output in zip( - run_managers, flattened_outputs + run_managers, flattened_outputs, strict=False ) ] ) @@ -1079,7 +1081,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): def generate_prompt( self, prompts: list[PromptValue], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, callbacks: Callbacks = None, **kwargs: Any, ) -> LLMResult: @@ -1090,7 +1092,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): async def agenerate_prompt( self, prompts: list[PromptValue], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, callbacks: Callbacks = None, **kwargs: Any, ) -> LLMResult: @@ -1102,8 +1104,8 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): def _generate_with_cache( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache() @@ -1139,7 +1141,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): **kwargs, ): chunks: list[ChatGenerationChunk] = [] - run_id: Optional[str] = ( + run_id: str | None = ( f"{LC_ID_PREFIX}-{run_manager.run_id}" if run_manager else None ) yielded = False @@ -1165,7 +1167,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): and isinstance(chunk.message, AIMessageChunk) and not chunk.message.chunk_position ): - empty_content: Union[str, list] = ( + empty_content: str | list = ( "" if isinstance(chunk.message.content, str) else [] ) chunk = ChatGenerationChunk( @@ -1210,8 +1212,8 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): async def _agenerate_with_cache( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache() @@ -1247,7 +1249,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): **kwargs, ): chunks: list[ChatGenerationChunk] = [] - run_id: Optional[str] = ( + run_id: str | None = ( f"{LC_ID_PREFIX}-{run_manager.run_id}" if run_manager else None ) yielded = False @@ -1273,7 +1275,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): and isinstance(chunk.message, AIMessageChunk) and not chunk.message.chunk_position ): - empty_content: Union[str, list] = ( + empty_content: str | list = ( "" if isinstance(chunk.message.content, str) else [] ) chunk = ChatGenerationChunk( @@ -1319,8 +1321,8 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: """Generate the result. @@ -1338,8 +1340,8 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): async def _agenerate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: """Generate the result. @@ -1365,8 +1367,8 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: """Stream the output of the model. @@ -1385,8 +1387,8 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): async def _astream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: """Stream the output of the model. @@ -1423,7 +1425,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): async def _call_async( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, + stop: list[str] | None = None, callbacks: Callbacks = None, **kwargs: Any, ) -> BaseMessage: @@ -1451,10 +1453,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): def bind_tools( self, tools: Sequence[ - Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 + typing.Dict[str, Any] | type | Callable | BaseTool # noqa: UP006 ], *, - tool_choice: Optional[Union[str]] = None, + tool_choice: str | None = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessage]: """Bind tools to the model. @@ -1471,11 +1473,11 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): def with_structured_output( self, - schema: Union[typing.Dict, type], # noqa: UP006 + schema: typing.Dict | type, # noqa: UP006 *, include_raw: bool = False, **kwargs: Any, - ) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]: # noqa: UP006 + ) -> Runnable[LanguageModelInput, typing.Dict | BaseModel]: # noqa: UP006 """Model wrapper that returns outputs formatted to match the given schema. Args: @@ -1653,8 +1655,8 @@ class SimpleChatModel(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs) @@ -1666,8 +1668,8 @@ class SimpleChatModel(BaseChatModel): def _call( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: """Simpler interface.""" @@ -1675,8 +1677,8 @@ class SimpleChatModel(BaseChatModel): async def _agenerate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: return await run_in_executor( @@ -1690,7 +1692,7 @@ class SimpleChatModel(BaseChatModel): def _gen_info_and_msg_metadata( - generation: Union[ChatGeneration, ChatGenerationChunk], + generation: ChatGeneration | ChatGenerationChunk, ) -> dict: return { **(generation.generation_info or {}), diff --git a/libs/core/langchain_core/language_models/fake.py b/libs/core/langchain_core/language_models/fake.py index 72366302ba3..77b7cdd4ac6 100644 --- a/libs/core/langchain_core/language_models/fake.py +++ b/libs/core/langchain_core/language_models/fake.py @@ -3,7 +3,7 @@ import asyncio import time from collections.abc import AsyncIterator, Iterator, Mapping -from typing import Any, Optional +from typing import Any from typing_extensions import override @@ -23,7 +23,7 @@ class FakeListLLM(LLM): """List of responses to return in order.""" # This parameter should be removed from FakeListLLM since # it's only used by sub-classes. - sleep: Optional[float] = None + sleep: float | None = None """Sleep time in seconds between responses. Ignored by FakeListLLM, but used by sub-classes. @@ -44,8 +44,8 @@ class FakeListLLM(LLM): def _call( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: """Return next response.""" @@ -60,8 +60,8 @@ class FakeListLLM(LLM): async def _acall( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: """Return next response.""" @@ -91,16 +91,16 @@ class FakeStreamingListLLM(FakeListLLM): chunks in a streaming implementation. """ - error_on_chunk_number: Optional[int] = None + error_on_chunk_number: int | None = None """If set, will raise an exception on the specified chunk number.""" @override def stream( self, input: LanguageModelInput, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> Iterator[str]: result = self.invoke(input, config) @@ -119,9 +119,9 @@ class FakeStreamingListLLM(FakeListLLM): async def astream( self, input: LanguageModelInput, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> AsyncIterator[str]: result = await self.ainvoke(input, config) diff --git a/libs/core/langchain_core/language_models/fake_chat_models.py b/libs/core/langchain_core/language_models/fake_chat_models.py index 2479129f1da..459367eb825 100644 --- a/libs/core/langchain_core/language_models/fake_chat_models.py +++ b/libs/core/langchain_core/language_models/fake_chat_models.py @@ -4,7 +4,7 @@ import asyncio import re import time from collections.abc import AsyncIterator, Iterator -from typing import Any, Literal, Optional, Union, cast +from typing import Any, Literal, cast from typing_extensions import override @@ -23,7 +23,7 @@ class FakeMessagesListChatModel(BaseChatModel): responses: list[BaseMessage] """List of responses to **cycle** through in order.""" - sleep: Optional[float] = None + sleep: float | None = None """Sleep time in seconds between responses.""" i: int = 0 """Internally incremented after every model invocation.""" @@ -32,8 +32,8 @@ class FakeMessagesListChatModel(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: if self.sleep is not None: @@ -61,10 +61,10 @@ class FakeListChatModel(SimpleChatModel): responses: list[str] """List of responses to **cycle** through in order.""" - sleep: Optional[float] = None + sleep: float | None = None i: int = 0 """Internally incremented after every model invocation.""" - error_on_chunk_number: Optional[int] = None + error_on_chunk_number: int | None = None """If set, raise an error on the specified chunk number during streaming.""" @property @@ -95,8 +95,8 @@ class FakeListChatModel(SimpleChatModel): def _stream( self, messages: list[BaseMessage], - stop: Union[list[str], None] = None, - run_manager: Union[CallbackManagerForLLMRun, None] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: response = self.responses[self.i] @@ -113,7 +113,7 @@ class FakeListChatModel(SimpleChatModel): ): raise FakeListChatModelError - chunk_position: Optional[Literal["last"]] = ( + chunk_position: Literal["last"] | None = ( "last" if i_c == len(response) - 1 else None ) yield ChatGenerationChunk( @@ -124,8 +124,8 @@ class FakeListChatModel(SimpleChatModel): async def _astream( self, messages: list[BaseMessage], - stop: Union[list[str], None] = None, - run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: response = self.responses[self.i] @@ -141,7 +141,7 @@ class FakeListChatModel(SimpleChatModel): and i_c == self.error_on_chunk_number ): raise FakeListChatModelError - chunk_position: Optional[Literal["last"]] = ( + chunk_position: Literal["last"] | None = ( "last" if i_c == len(response) - 1 else None ) yield ChatGenerationChunk( @@ -158,27 +158,33 @@ class FakeListChatModel(SimpleChatModel): def batch( self, inputs: list[Any], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, **kwargs: Any, ) -> list[AIMessage]: if isinstance(config, list): - return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)] + return [ + self.invoke(m, c, **kwargs) + for m, c in zip(inputs, config, strict=False) + ] return [self.invoke(m, config, **kwargs) for m in inputs] @override async def abatch( self, inputs: list[Any], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, **kwargs: Any, ) -> list[AIMessage]: if isinstance(config, list): # do Not use an async iterator here because need explicit ordering - return [await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)] + return [ + await self.ainvoke(m, c, **kwargs) + for m, c in zip(inputs, config, strict=False) + ] # do Not use an async iterator here because need explicit ordering return [await self.ainvoke(m, config, **kwargs) for m in inputs] @@ -190,8 +196,8 @@ class FakeChatModel(SimpleChatModel): def _call( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: return "fake response" @@ -200,8 +206,8 @@ class FakeChatModel(SimpleChatModel): async def _agenerate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: output_str = "fake response" @@ -229,7 +235,7 @@ class GenericFakeChatModel(BaseChatModel): """ - messages: Iterator[Union[AIMessage, str]] + messages: Iterator[AIMessage | str] """Get an iterator over messages. This can be expanded to accept other types like Callables / dicts / strings @@ -248,8 +254,8 @@ class GenericFakeChatModel(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: message = next(self.messages) @@ -260,8 +266,8 @@ class GenericFakeChatModel(BaseChatModel): def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: chat_result = self._generate( @@ -376,8 +382,8 @@ class ParrotFakeChatModel(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: return ChatResult(generations=[ChatGeneration(message=messages[-1])]) diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index a98dac6dedf..26ca47b74a0 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -8,14 +8,11 @@ import inspect import json import logging from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterator, Sequence +from collections.abc import AsyncIterator, Callable, Iterator, Sequence from pathlib import Path from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, - Union, cast, ) @@ -71,9 +68,7 @@ def _log_error_once(msg: str) -> None: def create_base_retry_decorator( error_types: list[type[BaseException]], max_retries: int = 1, - run_manager: Optional[ - Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] - ] = None, + run_manager: AsyncCallbackManagerForLLMRun | CallbackManagerForLLMRun | None = None, ) -> Callable[[Any], Any]: """Create a retry decorator for a given LLM and provided a list of error types. @@ -124,9 +119,9 @@ def create_base_retry_decorator( ) -def _resolve_cache(*, cache: Union[BaseCache, bool, None]) -> Optional[BaseCache]: +def _resolve_cache(*, cache: BaseCache | bool | None) -> BaseCache | None: """Resolve the cache.""" - llm_cache: Optional[BaseCache] + llm_cache: BaseCache | None if isinstance(cache, BaseCache): llm_cache = cache elif cache is None: @@ -151,7 +146,7 @@ def _resolve_cache(*, cache: Union[BaseCache, bool, None]) -> Optional[BaseCache def get_prompts( params: dict[str, Any], prompts: list[str], - cache: Union[BaseCache, bool, None] = None, # noqa: FBT001 + cache: BaseCache | bool | None = None, # noqa: FBT001 ) -> tuple[dict[int, list], str, list[int], list[str]]: """Get prompts that are already cached. @@ -187,7 +182,7 @@ def get_prompts( async def aget_prompts( params: dict[str, Any], prompts: list[str], - cache: Union[BaseCache, bool, None] = None, # noqa: FBT001 + cache: BaseCache | bool | None = None, # noqa: FBT001 ) -> tuple[dict[int, list], str, list[int], list[str]]: """Get prompts that are already cached. Async version. @@ -220,13 +215,13 @@ async def aget_prompts( def update_cache( - cache: Union[BaseCache, bool, None], # noqa: FBT001 + cache: BaseCache | bool | None, # noqa: FBT001 existing_prompts: dict[int, list], llm_string: str, missing_prompt_idxs: list[int], new_results: LLMResult, prompts: list[str], -) -> Optional[dict]: +) -> dict | None: """Update the cache and get the LLM output. Args: @@ -253,13 +248,13 @@ def update_cache( async def aupdate_cache( - cache: Union[BaseCache, bool, None], # noqa: FBT001 + cache: BaseCache | bool | None, # noqa: FBT001 existing_prompts: dict[int, list], llm_string: str, missing_prompt_idxs: list[int], new_results: LLMResult, prompts: list[str], -) -> Optional[dict]: +) -> dict | None: """Update the cache and get the LLM output. Async version. Args: @@ -322,7 +317,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): def _get_ls_params( self, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> LangSmithParams: """Get standard params for tracing.""" @@ -361,9 +356,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): def invoke( self, input: LanguageModelInput, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> str: config = ensure_config(config) @@ -386,9 +381,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def ainvoke( self, input: LanguageModelInput, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> str: config = ensure_config(config) @@ -408,7 +403,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): def batch( self, inputs: list[LanguageModelInput], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, **kwargs: Any, @@ -455,7 +450,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def abatch( self, inputs: list[LanguageModelInput], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, **kwargs: Any, @@ -501,9 +496,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): def stream( self, input: LanguageModelInput, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> Iterator[str]: if type(self)._stream == BaseLLM._stream: # noqa: SLF001 @@ -538,7 +533,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): run_id=config.pop("run_id", None), batch_size=1, ) - generation: Optional[GenerationChunk] = None + generation: GenerationChunk | None = None try: for chunk in self._stream( prompt, stop=stop, run_manager=run_manager, **kwargs @@ -568,9 +563,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def astream( self, input: LanguageModelInput, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, **kwargs: Any, ) -> AsyncIterator[str]: if ( @@ -608,7 +603,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): run_id=config.pop("run_id", None), batch_size=1, ) - generation: Optional[GenerationChunk] = None + generation: GenerationChunk | None = None try: async for chunk in self._astream( prompt, @@ -641,8 +636,8 @@ class BaseLLM(BaseLanguageModel[str], ABC): def _generate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompts. @@ -661,8 +656,8 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def _agenerate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompts. @@ -689,8 +684,8 @@ class BaseLLM(BaseLanguageModel[str], ABC): def _stream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: """Stream the LLM on the given prompt. @@ -717,8 +712,8 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def _astream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[GenerationChunk]: """An async version of the _stream method. @@ -762,8 +757,8 @@ class BaseLLM(BaseLanguageModel[str], ABC): def generate_prompt( self, prompts: list[PromptValue], - stop: Optional[list[str]] = None, - callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None, + stop: list[str] | None = None, + callbacks: Callbacks | list[Callbacks] | None = None, **kwargs: Any, ) -> LLMResult: prompt_strings = [p.to_string() for p in prompts] @@ -773,8 +768,8 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def agenerate_prompt( self, prompts: list[PromptValue], - stop: Optional[list[str]] = None, - callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None, + stop: list[str] | None = None, + callbacks: Callbacks | list[Callbacks] | None = None, **kwargs: Any, ) -> LLMResult: prompt_strings = [p.to_string() for p in prompts] @@ -785,7 +780,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): def _generate_helper( self, prompts: list[str], - stop: Optional[list[str]], + stop: list[str] | None, run_managers: list[CallbackManagerForLLMRun], *, new_arg_supported: bool, @@ -808,7 +803,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): run_manager.on_llm_error(e, response=LLMResult(generations=[])) raise flattened_outputs = output.flatten() - for manager, flattened_output in zip(run_managers, flattened_outputs): + for manager, flattened_output in zip( + run_managers, flattened_outputs, strict=False + ): manager.on_llm_end(flattened_output) if run_managers: output.run = [ @@ -819,13 +816,13 @@ class BaseLLM(BaseLanguageModel[str], ABC): def generate( self, prompts: list[str], - stop: Optional[list[str]] = None, - callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None, + stop: list[str] | None = None, + callbacks: Callbacks | list[Callbacks] | None = None, *, - tags: Optional[Union[list[str], list[list[str]]]] = None, - metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None, - run_name: Optional[Union[str, list[str]]] = None, - run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None, + tags: list[str] | list[list[str]] | None = None, + metadata: dict[str, Any] | list[dict[str, Any]] | None = None, + run_name: str | list[str] | None = None, + run_id: uuid.UUID | list[uuid.UUID | None] | None = None, **kwargs: Any, ) -> LLMResult: """Pass a sequence of prompts to a model and return generations. @@ -915,14 +912,12 @@ class BaseLLM(BaseLanguageModel[str], ABC): msg = "run_name must be a list of the same length as prompts" raise ValueError(msg) callbacks = cast("list[Callbacks]", callbacks) - tags_list = cast( - "list[Optional[list[str]]]", tags or ([None] * len(prompts)) - ) + tags_list = cast("list[list[str] | None]", tags or ([None] * len(prompts))) metadata_list = cast( - "list[Optional[dict[str, Any]]]", metadata or ([{}] * len(prompts)) + "list[dict[str, Any] | None]", metadata or ([{}] * len(prompts)) ) run_name_list = run_name or cast( - "list[Optional[str]]", ([None] * len(prompts)) + "list[str | None]", ([None] * len(prompts)) ) callback_managers = [ CallbackManager.configure( @@ -934,7 +929,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): meta, self.metadata, ) - for callback, tag, meta in zip(callbacks, tags_list, metadata_list) + for callback, tag, meta in zip( + callbacks, tags_list, metadata_list, strict=False + ) ] else: # We've received a single callbacks arg to apply to all inputs @@ -949,7 +946,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): self.metadata, ) ] * len(prompts) - run_name_list = [cast("Optional[str]", run_name)] * len(prompts) + run_name_list = [cast("str | None", run_name)] * len(prompts) run_ids_list = self._get_run_ids_list(run_id, prompts) params = self.dict() params["stop"] = stop @@ -975,7 +972,11 @@ class BaseLLM(BaseLanguageModel[str], ABC): run_id=run_id_, )[0] for callback_manager, prompt, run_name, run_id_ in zip( - callback_managers, prompts, run_name_list, run_ids_list + callback_managers, + prompts, + run_name_list, + run_ids_list, + strict=False, ) ] return self._generate_helper( @@ -1025,7 +1026,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): @staticmethod def _get_run_ids_list( - run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]], prompts: list + run_id: uuid.UUID | list[uuid.UUID | None] | None, prompts: list ) -> list: if run_id is None: return [None] * len(prompts) @@ -1042,7 +1043,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def _agenerate_helper( self, prompts: list[str], - stop: Optional[list[str]], + stop: list[str] | None, run_managers: list[AsyncCallbackManagerForLLMRun], *, new_arg_supported: bool, @@ -1072,7 +1073,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): *[ run_manager.on_llm_end(flattened_output) for run_manager, flattened_output in zip( - run_managers, flattened_outputs + run_managers, flattened_outputs, strict=False ) ] ) @@ -1085,13 +1086,13 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def agenerate( self, prompts: list[str], - stop: Optional[list[str]] = None, - callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None, + stop: list[str] | None = None, + callbacks: Callbacks | list[Callbacks] | None = None, *, - tags: Optional[Union[list[str], list[list[str]]]] = None, - metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None, - run_name: Optional[Union[str, list[str]]] = None, - run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None, + tags: list[str] | list[list[str]] | None = None, + metadata: dict[str, Any] | list[dict[str, Any]] | None = None, + run_name: str | list[str] | None = None, + run_id: uuid.UUID | list[uuid.UUID | None] | None = None, **kwargs: Any, ) -> LLMResult: """Asynchronously pass a sequence of prompts to a model and return generations. @@ -1170,14 +1171,12 @@ class BaseLLM(BaseLanguageModel[str], ABC): msg = "run_name must be a list of the same length as prompts" raise ValueError(msg) callbacks = cast("list[Callbacks]", callbacks) - tags_list = cast( - "list[Optional[list[str]]]", tags or ([None] * len(prompts)) - ) + tags_list = cast("list[list[str] | None]", tags or ([None] * len(prompts))) metadata_list = cast( - "list[Optional[dict[str, Any]]]", metadata or ([{}] * len(prompts)) + "list[dict[str, Any] | None]", metadata or ([{}] * len(prompts)) ) run_name_list = run_name or cast( - "list[Optional[str]]", ([None] * len(prompts)) + "list[str | None]", ([None] * len(prompts)) ) callback_managers = [ AsyncCallbackManager.configure( @@ -1189,7 +1188,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): meta, self.metadata, ) - for callback, tag, meta in zip(callbacks, tags_list, metadata_list) + for callback, tag, meta in zip( + callbacks, tags_list, metadata_list, strict=False + ) ] else: # We've received a single callbacks arg to apply to all inputs @@ -1204,7 +1205,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): self.metadata, ) ] * len(prompts) - run_name_list = [cast("Optional[str]", run_name)] * len(prompts) + run_name_list = [cast("str | None", run_name)] * len(prompts) run_ids_list = self._get_run_ids_list(run_id, prompts) params = self.dict() params["stop"] = stop @@ -1234,7 +1235,11 @@ class BaseLLM(BaseLanguageModel[str], ABC): run_id=run_id_, ) for callback_manager, prompt, run_name, run_id_ in zip( - callback_managers, prompts, run_name_list, run_ids_list + callback_managers, + prompts, + run_name_list, + run_ids_list, + strict=False, ) ] ) @@ -1290,11 +1295,11 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def _call_async( self, prompt: str, - stop: Optional[list[str]] = None, + stop: list[str] | None = None, callbacks: Callbacks = None, *, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> str: """Check Cache and run the LLM on the given prompt and input.""" @@ -1325,7 +1330,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): starter_dict["_type"] = self._llm_type return starter_dict - def save(self, file_path: Union[Path, str]) -> None: + def save(self, file_path: Path | str) -> None: """Save the LLM. Args: @@ -1395,8 +1400,8 @@ class LLM(BaseLLM): def _call( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: """Run the LLM on the given input. @@ -1419,8 +1424,8 @@ class LLM(BaseLLM): async def _acall( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: """Async version of the _call method. @@ -1453,8 +1458,8 @@ class LLM(BaseLLM): def _generate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: # TODO: add caching here. @@ -1472,8 +1477,8 @@ class LLM(BaseLLM): async def _agenerate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: generations = [] diff --git a/libs/core/langchain_core/load/load.py b/libs/core/langchain_core/load/load.py index 52f4e2080b8..c288c4e8977 100644 --- a/libs/core/langchain_core/load/load.py +++ b/libs/core/langchain_core/load/load.py @@ -3,7 +3,7 @@ import importlib import json import os -from typing import Any, Optional +from typing import Any from langchain_core._api import beta from langchain_core.load.mapping import ( @@ -50,12 +50,11 @@ class Reviver: def __init__( self, - secrets_map: Optional[dict[str, str]] = None, - valid_namespaces: Optional[list[str]] = None, + secrets_map: dict[str, str] | None = None, + valid_namespaces: list[str] | None = None, secrets_from_env: bool = True, # noqa: FBT001,FBT002 - additional_import_mappings: Optional[ - dict[tuple[str, ...], tuple[str, ...]] - ] = None, + additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] + | None = None, *, ignore_unserializable_fields: bool = False, ) -> None: @@ -187,10 +186,10 @@ class Reviver: def loads( text: str, *, - secrets_map: Optional[dict[str, str]] = None, - valid_namespaces: Optional[list[str]] = None, + secrets_map: dict[str, str] | None = None, + valid_namespaces: list[str] | None = None, secrets_from_env: bool = True, - additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None, + additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] | None = None, ignore_unserializable_fields: bool = False, ) -> Any: """Revive a LangChain class from a JSON string. @@ -231,10 +230,10 @@ def loads( def load( obj: Any, *, - secrets_map: Optional[dict[str, str]] = None, - valid_namespaces: Optional[list[str]] = None, + secrets_map: dict[str, str] | None = None, + valid_namespaces: list[str] | None = None, secrets_from_env: bool = True, - additional_import_mappings: Optional[dict[tuple[str, ...], tuple[str, ...]]] = None, + additional_import_mappings: dict[tuple[str, ...], tuple[str, ...]] | None = None, ignore_unserializable_fields: bool = False, ) -> Any: """Revive a LangChain class from a JSON object. diff --git a/libs/core/langchain_core/load/serializable.py b/libs/core/langchain_core/load/serializable.py index 9196416a71c..55c7ce927a5 100644 --- a/libs/core/langchain_core/load/serializable.py +++ b/libs/core/langchain_core/load/serializable.py @@ -6,9 +6,7 @@ from abc import ABC from typing import ( Any, Literal, - Optional, TypedDict, - Union, cast, ) @@ -53,7 +51,7 @@ class SerializedNotImplemented(BaseSerialized): type: Literal["not_implemented"] """The type of the object. Must be ``'not_implemented'``.""" - repr: Optional[str] + repr: str | None """The representation of the object. Optional.""" @@ -188,7 +186,7 @@ class Serializable(BaseModel, ABC): if (k not in type(self).model_fields or try_neq_default(v, k, self)) ] - def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]: + def to_json(self) -> SerializedConstructor | SerializedNotImplemented: """Serialize the object to JSON. Raises: diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index b6f1597154f..c96f4586775 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -4,7 +4,7 @@ import json import logging import operator from collections.abc import Sequence -from typing import Any, Literal, Optional, Union, cast, overload +from typing import Any, Literal, cast, overload from pydantic import model_validator from typing_extensions import NotRequired, Self, TypedDict, override @@ -160,7 +160,7 @@ class AIMessage(BaseMessage): """If provided, tool calls associated with the message.""" invalid_tool_calls: list[InvalidToolCall] = [] """If provided, tool calls with parsing errors associated with the message.""" - usage_metadata: Optional[UsageMetadata] = None + usage_metadata: UsageMetadata | None = None """If provided, usage metadata for a message, such as token counts. This is a standard representation of token usage that is consistent across models. @@ -173,22 +173,22 @@ class AIMessage(BaseMessage): @overload def __init__( self, - content: Union[str, list[Union[str, dict]]], + content: str | list[str | dict], **kwargs: Any, ) -> None: ... @overload def __init__( self, - content: Optional[Union[str, list[Union[str, dict]]]] = None, - content_blocks: Optional[list[types.ContentBlock]] = None, + content: str | list[str | dict] | None = None, + content_blocks: list[types.ContentBlock] | None = None, **kwargs: Any, ) -> None: ... def __init__( self, - content: Optional[Union[str, list[Union[str, dict]]]] = None, - content_blocks: Optional[list[types.ContentBlock]] = None, + content: str | list[str | dict] | None = None, + content_blocks: list[types.ContentBlock] | None = None, **kwargs: Any, ) -> None: """Initialize ``AIMessage``. @@ -209,7 +209,7 @@ class AIMessage(BaseMessage): kwargs["tool_calls"] = content_tool_calls super().__init__( - content=cast("Union[str, list[Union[str, dict]]]", content_blocks), + content=cast("str | list[str | dict]", content_blocks), **kwargs, ) else: @@ -344,7 +344,7 @@ class AIMessage(BaseMessage): base = super().pretty_repr(html=html) lines = [] - def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> list[str]: + def _format_tool_args(tc: ToolCall | InvalidToolCall) -> list[str]: lines = [ f" {tc.get('name', 'Tool')} ({tc.get('id')})", f" Call ID: {tc.get('id')}", @@ -387,7 +387,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk): tool_call_chunks: list[ToolCallChunk] = [] """If provided, tool call chunks associated with the message.""" - chunk_position: Optional[Literal["last"]] = None + chunk_position: Literal["last"] | None = None """Optional span represented by an aggregated AIMessageChunk. If a chunk with ``chunk_position="last"`` is aggregated into a stream, @@ -632,7 +632,7 @@ def add_ai_message_chunks( # Token usage if left.usage_metadata or any(o.usage_metadata is not None for o in others): - usage_metadata: Optional[UsageMetadata] = left.usage_metadata + usage_metadata: UsageMetadata | None = left.usage_metadata for other in others: usage_metadata = add_usage(usage_metadata, other.usage_metadata) else: @@ -662,7 +662,7 @@ def add_ai_message_chunks( chunk_id = id_ break - chunk_position: Optional[Literal["last"]] = ( + chunk_position: Literal["last"] | None = ( "last" if any(x.chunk_position == "last" for x in [left, *others]) else None ) @@ -677,9 +677,7 @@ def add_ai_message_chunks( ) -def add_usage( - left: Optional[UsageMetadata], right: Optional[UsageMetadata] -) -> UsageMetadata: +def add_usage(left: UsageMetadata | None, right: UsageMetadata | None) -> UsageMetadata: """Recursively add two UsageMetadata objects. Example: @@ -740,7 +738,7 @@ def add_usage( def subtract_usage( - left: Optional[UsageMetadata], right: Optional[UsageMetadata] + left: UsageMetadata | None, right: UsageMetadata | None ) -> UsageMetadata: """Recursively subtract two ``UsageMetadata`` objects. diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index 84f2b6217b4..1df5b375b47 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, Union, cast, overload +from typing import TYPE_CHECKING, Any, cast, overload from pydantic import ConfigDict, Field from typing_extensions import Self @@ -22,7 +22,7 @@ if TYPE_CHECKING: def _extract_reasoning_from_additional_kwargs( message: BaseMessage, -) -> Optional[types.ReasoningContentBlock]: +) -> types.ReasoningContentBlock | None: """Extract `reasoning_content` from `additional_kwargs`. Handles reasoning content stored in various formats: @@ -95,7 +95,7 @@ class BaseMessage(Serializable): Messages are the inputs and outputs of a ``ChatModel``. """ - content: Union[str, list[Union[str, dict]]] + content: str | list[str | dict] """The string contents of the message.""" additional_kwargs: dict = Field(default_factory=dict) @@ -117,7 +117,7 @@ class BaseMessage(Serializable): """ - name: Optional[str] = None + name: str | None = None """An optional name for the message. This can be used to provide a human-readable name for the message. @@ -127,7 +127,7 @@ class BaseMessage(Serializable): """ - id: Optional[str] = Field(default=None, coerce_numbers_to_str=True) + id: str | None = Field(default=None, coerce_numbers_to_str=True) """An optional unique identifier for the message. This should ideally be provided by the provider/model which created the message. @@ -141,22 +141,22 @@ class BaseMessage(Serializable): @overload def __init__( self, - content: Union[str, list[Union[str, dict]]], + content: str | list[str | dict], **kwargs: Any, ) -> None: ... @overload def __init__( self, - content: Optional[Union[str, list[Union[str, dict]]]] = None, - content_blocks: Optional[list[types.ContentBlock]] = None, + content: str | list[str | dict] | None = None, + content_blocks: list[types.ContentBlock] | None = None, **kwargs: Any, ) -> None: ... def __init__( self, - content: Optional[Union[str, list[Union[str, dict]]]] = None, - content_blocks: Optional[list[types.ContentBlock]] = None, + content: str | list[str | dict] | None = None, + content_blocks: list[types.ContentBlock] | None = None, **kwargs: Any, ) -> None: """Initialize ``BaseMessage``. @@ -325,9 +325,9 @@ class BaseMessage(Serializable): def merge_content( - first_content: Union[str, list[Union[str, dict]]], - *contents: Union[str, list[Union[str, dict]]], -) -> Union[str, list[Union[str, dict]]]: + first_content: str | list[str | dict], + *contents: str | list[str | dict], +) -> str | list[str | dict]: """Merge multiple message contents. Args: @@ -338,7 +338,7 @@ def merge_content( The merged content. """ - merged: Union[str, list[Union[str, dict]]] + merged: str | list[str | dict] merged = "" if first_content is None else first_content for content in contents: diff --git a/libs/core/langchain_core/messages/block_translators/__init__.py b/libs/core/langchain_core/messages/block_translators/__init__.py index 01a8cd73189..90156154e7a 100644 --- a/libs/core/langchain_core/messages/block_translators/__init__.py +++ b/libs/core/langchain_core/messages/block_translators/__init__.py @@ -12,7 +12,8 @@ the implementation in ``BaseMessage``. from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING if TYPE_CHECKING: from langchain_core.messages import AIMessage, AIMessageChunk diff --git a/libs/core/langchain_core/messages/block_translators/anthropic.py b/libs/core/langchain_core/messages/block_translators/anthropic.py index f61fc461ffb..6d1b803ba7d 100644 --- a/libs/core/langchain_core/messages/block_translators/anthropic.py +++ b/libs/core/langchain_core/messages/block_translators/anthropic.py @@ -2,7 +2,7 @@ import json from collections.abc import Iterable -from typing import Any, Optional, Union, cast +from typing import Any, cast from langchain_core.messages import AIMessage, AIMessageChunk from langchain_core.messages import content as types @@ -200,7 +200,7 @@ def _convert_citation_to_v1(citation: dict[str, Any]) -> types.Annotation: def _convert_to_v1_from_anthropic(message: AIMessage) -> list[types.ContentBlock]: """Convert Anthropic message content to v1 format.""" if isinstance(message.content, str): - content: list[Union[str, dict]] = [{"type": "text", "text": message.content}] + content: list[str | dict] = [{"type": "text", "text": message.content}] else: content = message.content @@ -252,7 +252,7 @@ def _convert_to_v1_from_anthropic(message: AIMessage) -> list[types.ContentBlock tool_call_chunk["type"] = "tool_call_chunk" yield tool_call_chunk else: - tool_call_block: Optional[types.ToolCall] = None + tool_call_block: types.ToolCall | None = None # Non-streaming or gathered chunk if len(message.tool_calls) == 1: tool_call_block = { diff --git a/libs/core/langchain_core/messages/block_translators/bedrock_converse.py b/libs/core/langchain_core/messages/block_translators/bedrock_converse.py index 106aa32dcbf..dfbc993db82 100644 --- a/libs/core/langchain_core/messages/block_translators/bedrock_converse.py +++ b/libs/core/langchain_core/messages/block_translators/bedrock_converse.py @@ -2,7 +2,7 @@ import base64 from collections.abc import Iterable -from typing import Any, Optional, cast +from typing import Any, cast from langchain_core.messages import AIMessage, AIMessageChunk from langchain_core.messages import content as types @@ -216,7 +216,7 @@ def _convert_to_v1_from_converse(message: AIMessage) -> list[types.ContentBlock] tool_call_chunk["type"] = "tool_call_chunk" yield tool_call_chunk else: - tool_call_block: Optional[types.ToolCall] = None + tool_call_block: types.ToolCall | None = None # Non-streaming or gathered chunk if len(message.tool_calls) == 1: tool_call_block = { diff --git a/libs/core/langchain_core/messages/block_translators/langchain_v0.py b/libs/core/langchain_core/messages/block_translators/langchain_v0.py index f2bfa632836..9056c26c375 100644 --- a/libs/core/langchain_core/messages/block_translators/langchain_v0.py +++ b/libs/core/langchain_core/messages/block_translators/langchain_v0.py @@ -1,6 +1,6 @@ """Derivations of standard content blocks from LangChain v0 multimodal content.""" -from typing import Any, Union, cast +from typing import Any, cast from langchain_core.messages import content as types @@ -45,7 +45,7 @@ def _convert_v0_multimodal_input_to_v1( def _convert_legacy_v0_content_block_to_v1( block: dict, -) -> Union[types.ContentBlock, dict]: +) -> types.ContentBlock | dict: """Convert a LangChain v0 content block to v1 format. Preserves unknown keys as extras to avoid data loss. diff --git a/libs/core/langchain_core/messages/block_translators/openai.py b/libs/core/langchain_core/messages/block_translators/openai.py index 0fe1e697b97..e74ba68922d 100644 --- a/libs/core/langchain_core/messages/block_translators/openai.py +++ b/libs/core/langchain_core/messages/block_translators/openai.py @@ -5,7 +5,7 @@ from __future__ import annotations import json import warnings from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Literal, cast from langchain_core.language_models._utils import ( _parse_data_uri, @@ -401,7 +401,7 @@ def _convert_from_v03_ai_message(message: AIMessage) -> AIMessage: def _convert_openai_format_to_data_block( block: dict, -) -> Union[types.ContentBlock, dict[Any, Any]]: +) -> types.ContentBlock | dict[Any, Any]: """Convert OpenAI image/audio/file content block to respective v1 multimodal block. We expect that the incoming block is verified to be in OpenAI Chat Completions @@ -677,9 +677,9 @@ def _convert_to_v1_from_responses(message: AIMessage) -> list[types.ContentBlock yield cast("types.ImageContentBlock", new_block) elif block_type == "function_call": - tool_call_block: Optional[ - Union[types.ToolCall, types.InvalidToolCall, types.ToolCallChunk] - ] = None + tool_call_block: ( + types.ToolCall | types.InvalidToolCall | types.ToolCallChunk | None + ) = None call_id = block.get("call_id", "") from langchain_core.messages import AIMessageChunk # noqa: PLC0415 @@ -726,7 +726,7 @@ def _convert_to_v1_from_responses(message: AIMessage) -> list[types.ContentBlock if "index" in block: web_search_call["index"] = f"lc_wsc_{block['index']}" - sources: Optional[dict[str, Any]] = None + sources: dict[str, Any] | None = None if "action" in block and isinstance(block["action"], dict): if "sources" in block["action"]: sources = block["action"]["sources"] diff --git a/libs/core/langchain_core/messages/content.py b/libs/core/langchain_core/messages/content.py index 5b1fde2aaf6..abe633d9ab8 100644 --- a/libs/core/langchain_core/messages/content.py +++ b/libs/core/langchain_core/messages/content.py @@ -129,7 +129,7 @@ Factory functions offer benefits such as: """ -from typing import Any, Literal, Optional, Union, get_args, get_type_hints +from typing import Any, Literal, get_args, get_type_hints from typing_extensions import NotRequired, TypedDict @@ -211,7 +211,7 @@ class NonStandardAnnotation(TypedDict): """Provider-specific annotation data.""" -Annotation = Union[Citation, NonStandardAnnotation] +Annotation = Citation | NonStandardAnnotation class TextContentBlock(TypedDict): @@ -247,7 +247,7 @@ class TextContentBlock(TypedDict): annotations: NotRequired[list[Annotation]] """``Citation``s and other annotations.""" - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" extras: NotRequired[dict[str, Any]] @@ -278,7 +278,7 @@ class ToolCall(TypedDict): type: Literal["tool_call"] """Used for discrimination.""" - id: Optional[str] + id: str | None """An identifier associated with the tool call. An identifier is needed to associate a tool call request with a tool @@ -293,7 +293,7 @@ class ToolCall(TypedDict): args: dict[str, Any] """The arguments to the tool call.""" - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" extras: NotRequired[dict[str, Any]] @@ -326,7 +326,7 @@ class ToolCallChunk(TypedDict): type: Literal["tool_call_chunk"] """Used for serialization.""" - id: Optional[str] + id: str | None """An identifier associated with the tool call. An identifier is needed to associate a tool call request with a tool @@ -334,13 +334,13 @@ class ToolCallChunk(TypedDict): """ - name: Optional[str] + name: str | None """The name of the tool to be called.""" - args: Optional[str] + args: str | None """The arguments to the tool call.""" - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """The index of the tool call in a sequence.""" extras: NotRequired[dict[str, Any]] @@ -360,7 +360,7 @@ class InvalidToolCall(TypedDict): type: Literal["invalid_tool_call"] """Used for discrimination.""" - id: Optional[str] + id: str | None """An identifier associated with the tool call. An identifier is needed to associate a tool call request with a tool @@ -368,16 +368,16 @@ class InvalidToolCall(TypedDict): """ - name: Optional[str] + name: str | None """The name of the tool to be called.""" - args: Optional[str] + args: str | None """The arguments to the tool call.""" - error: Optional[str] + error: str | None """An error message associated with the tool call.""" - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" extras: NotRequired[dict[str, Any]] @@ -399,7 +399,7 @@ class ServerToolCall(TypedDict): args: dict[str, Any] """The arguments to the tool call.""" - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" extras: NotRequired[dict[str, Any]] @@ -421,7 +421,7 @@ class ServerToolCallChunk(TypedDict): id: NotRequired[str] """An identifier associated with the tool call.""" - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" extras: NotRequired[dict[str, Any]] @@ -446,7 +446,7 @@ class ServerToolResult(TypedDict): output: NotRequired[Any] """Output of the executed tool.""" - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" extras: NotRequired[dict[str, Any]] @@ -485,7 +485,7 @@ class ReasoningContentBlock(TypedDict): """ - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" extras: NotRequired[dict[str, Any]] @@ -529,7 +529,7 @@ class ImageContentBlock(TypedDict): """ - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" url: NotRequired[str] @@ -576,7 +576,7 @@ class VideoContentBlock(TypedDict): """ - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" url: NotRequired[str] @@ -622,7 +622,7 @@ class AudioContentBlock(TypedDict): """ - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" url: NotRequired[str] @@ -675,7 +675,7 @@ class PlainTextContentBlock(TypedDict): mime_type: Literal["text/plain"] """MIME type of the file. Required for base64.""" - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" url: NotRequired[str] @@ -738,7 +738,7 @@ class FileContentBlock(TypedDict): """ - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" url: NotRequired[str] @@ -793,35 +793,31 @@ class NonStandardContentBlock(TypedDict): value: dict[str, Any] """Provider-specific data.""" - index: NotRequired[Union[int, str]] + index: NotRequired[int | str] """Index of block in aggregate response. Used during streaming.""" # --- Aliases --- -DataContentBlock = Union[ - ImageContentBlock, - VideoContentBlock, - AudioContentBlock, - PlainTextContentBlock, - FileContentBlock, -] +DataContentBlock = ( + ImageContentBlock + | VideoContentBlock + | AudioContentBlock + | PlainTextContentBlock + | FileContentBlock +) -ToolContentBlock = Union[ - ToolCall, - ToolCallChunk, - ServerToolCall, - ServerToolCallChunk, - ServerToolResult, -] +ToolContentBlock = ( + ToolCall | ToolCallChunk | ServerToolCall | ServerToolCallChunk | ServerToolResult +) -ContentBlock = Union[ - TextContentBlock, - InvalidToolCall, - ReasoningContentBlock, - NonStandardContentBlock, - DataContentBlock, - ToolContentBlock, -] +ContentBlock = ( + TextContentBlock + | InvalidToolCall + | ReasoningContentBlock + | NonStandardContentBlock + | DataContentBlock + | ToolContentBlock +) KNOWN_BLOCK_TYPES = { @@ -922,9 +918,9 @@ def is_data_content_block(block: dict) -> bool: def create_text_block( text: str, *, - id: Optional[str] = None, - annotations: Optional[list[Annotation]] = None, - index: Optional[Union[int, str]] = None, + id: str | None = None, + annotations: list[Annotation] | None = None, + index: int | str | None = None, **kwargs: Any, ) -> TextContentBlock: """Create a ``TextContentBlock``. @@ -962,12 +958,12 @@ def create_text_block( def create_image_block( *, - url: Optional[str] = None, - base64: Optional[str] = None, - file_id: Optional[str] = None, - mime_type: Optional[str] = None, - id: Optional[str] = None, - index: Optional[Union[int, str]] = None, + url: str | None = None, + base64: str | None = None, + file_id: str | None = None, + mime_type: str | None = None, + id: str | None = None, + index: int | str | None = None, **kwargs: Any, ) -> ImageContentBlock: """Create an ``ImageContentBlock``. @@ -1018,12 +1014,12 @@ def create_image_block( def create_video_block( *, - url: Optional[str] = None, - base64: Optional[str] = None, - file_id: Optional[str] = None, - mime_type: Optional[str] = None, - id: Optional[str] = None, - index: Optional[Union[int, str]] = None, + url: str | None = None, + base64: str | None = None, + file_id: str | None = None, + mime_type: str | None = None, + id: str | None = None, + index: int | str | None = None, **kwargs: Any, ) -> VideoContentBlock: """Create a ``VideoContentBlock``. @@ -1078,12 +1074,12 @@ def create_video_block( def create_audio_block( *, - url: Optional[str] = None, - base64: Optional[str] = None, - file_id: Optional[str] = None, - mime_type: Optional[str] = None, - id: Optional[str] = None, - index: Optional[Union[int, str]] = None, + url: str | None = None, + base64: str | None = None, + file_id: str | None = None, + mime_type: str | None = None, + id: str | None = None, + index: int | str | None = None, **kwargs: Any, ) -> AudioContentBlock: """Create an ``AudioContentBlock``. @@ -1138,12 +1134,12 @@ def create_audio_block( def create_file_block( *, - url: Optional[str] = None, - base64: Optional[str] = None, - file_id: Optional[str] = None, - mime_type: Optional[str] = None, - id: Optional[str] = None, - index: Optional[Union[int, str]] = None, + url: str | None = None, + base64: str | None = None, + file_id: str | None = None, + mime_type: str | None = None, + id: str | None = None, + index: int | str | None = None, **kwargs: Any, ) -> FileContentBlock: """Create a ``FileContentBlock``. @@ -1197,14 +1193,14 @@ def create_file_block( def create_plaintext_block( - text: Optional[str] = None, - url: Optional[str] = None, - base64: Optional[str] = None, - file_id: Optional[str] = None, - title: Optional[str] = None, - context: Optional[str] = None, - id: Optional[str] = None, - index: Optional[Union[int, str]] = None, + text: str | None = None, + url: str | None = None, + base64: str | None = None, + file_id: str | None = None, + title: str | None = None, + context: str | None = None, + id: str | None = None, + index: int | str | None = None, **kwargs: Any, ) -> PlainTextContentBlock: """Create a ``PlainTextContentBlock``. @@ -1259,8 +1255,8 @@ def create_tool_call( name: str, args: dict[str, Any], *, - id: Optional[str] = None, - index: Optional[Union[int, str]] = None, + id: str | None = None, + index: int | str | None = None, **kwargs: Any, ) -> ToolCall: """Create a ``ToolCall``. @@ -1297,9 +1293,9 @@ def create_tool_call( def create_reasoning_block( - reasoning: Optional[str] = None, - id: Optional[str] = None, - index: Optional[Union[int, str]] = None, + reasoning: str | None = None, + id: str | None = None, + index: int | str | None = None, **kwargs: Any, ) -> ReasoningContentBlock: """Create a ``ReasoningContentBlock``. @@ -1335,12 +1331,12 @@ def create_reasoning_block( def create_citation( *, - url: Optional[str] = None, - title: Optional[str] = None, - start_index: Optional[int] = None, - end_index: Optional[int] = None, - cited_text: Optional[str] = None, - id: Optional[str] = None, + url: str | None = None, + title: str | None = None, + start_index: int | None = None, + end_index: int | None = None, + cited_text: str | None = None, + id: str | None = None, **kwargs: Any, ) -> Citation: """Create a ``Citation``. @@ -1384,8 +1380,8 @@ def create_citation( def create_non_standard_block( value: dict[str, Any], *, - id: Optional[str] = None, - index: Optional[Union[int, str]] = None, + id: str | None = None, + index: int | str | None = None, ) -> NonStandardContentBlock: """Create a ``NonStandardContentBlock``. diff --git a/libs/core/langchain_core/messages/human.py b/libs/core/langchain_core/messages/human.py index 8af282b2a95..aa7b71796d6 100644 --- a/libs/core/langchain_core/messages/human.py +++ b/libs/core/langchain_core/messages/human.py @@ -1,6 +1,6 @@ """Human message.""" -from typing import Any, Literal, Optional, Union, cast, overload +from typing import Any, Literal, cast, overload from langchain_core.messages import content as types from langchain_core.messages.base import BaseMessage, BaseMessageChunk @@ -38,28 +38,28 @@ class HumanMessage(BaseMessage): @overload def __init__( self, - content: Union[str, list[Union[str, dict]]], + content: str | list[str | dict], **kwargs: Any, ) -> None: ... @overload def __init__( self, - content: Optional[Union[str, list[Union[str, dict]]]] = None, - content_blocks: Optional[list[types.ContentBlock]] = None, + content: str | list[str | dict] | None = None, + content_blocks: list[types.ContentBlock] | None = None, **kwargs: Any, ) -> None: ... def __init__( self, - content: Optional[Union[str, list[Union[str, dict]]]] = None, - content_blocks: Optional[list[types.ContentBlock]] = None, + content: str | list[str | dict] | None = None, + content_blocks: list[types.ContentBlock] | None = None, **kwargs: Any, ) -> None: """Specify ``content`` as positional arg or ``content_blocks`` for typing.""" if content_blocks is not None: super().__init__( - content=cast("Union[str, list[Union[str, dict]]]", content_blocks), + content=cast("str | list[str | dict]", content_blocks), **kwargs, ) else: diff --git a/libs/core/langchain_core/messages/system.py b/libs/core/langchain_core/messages/system.py index 2c214189621..54c66a04005 100644 --- a/libs/core/langchain_core/messages/system.py +++ b/libs/core/langchain_core/messages/system.py @@ -1,6 +1,6 @@ """System message.""" -from typing import Any, Literal, Optional, Union, cast, overload +from typing import Any, Literal, cast, overload from langchain_core.messages import content as types from langchain_core.messages.base import BaseMessage, BaseMessageChunk @@ -38,28 +38,28 @@ class SystemMessage(BaseMessage): @overload def __init__( self, - content: Union[str, list[Union[str, dict]]], + content: str | list[str | dict], **kwargs: Any, ) -> None: ... @overload def __init__( self, - content: Optional[Union[str, list[Union[str, dict]]]] = None, - content_blocks: Optional[list[types.ContentBlock]] = None, + content: str | list[str | dict] | None = None, + content_blocks: list[types.ContentBlock] | None = None, **kwargs: Any, ) -> None: ... def __init__( self, - content: Optional[Union[str, list[Union[str, dict]]]] = None, - content_blocks: Optional[list[types.ContentBlock]] = None, + content: str | list[str | dict] | None = None, + content_blocks: list[types.ContentBlock] | None = None, **kwargs: Any, ) -> None: """Specify ``content`` as positional arg or ``content_blocks`` for typing.""" if content_blocks is not None: super().__init__( - content=cast("Union[str, list[Union[str, dict]]]", content_blocks), + content=cast("str | list[str | dict]", content_blocks), **kwargs, ) else: diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 0a400f372f0..5d2ef8fe88b 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -1,7 +1,7 @@ """Messages for tools.""" import json -from typing import Any, Literal, Optional, Union, cast, overload +from typing import Any, Literal, cast, overload from uuid import UUID from pydantic import Field, model_validator @@ -147,22 +147,22 @@ class ToolMessage(BaseMessage, ToolOutputMixin): @overload def __init__( self, - content: Union[str, list[Union[str, dict]]], + content: str | list[str | dict], **kwargs: Any, ) -> None: ... @overload def __init__( self, - content: Optional[Union[str, list[Union[str, dict]]]] = None, - content_blocks: Optional[list[types.ContentBlock]] = None, + content: str | list[str | dict] | None = None, + content_blocks: list[types.ContentBlock] | None = None, **kwargs: Any, ) -> None: ... def __init__( self, - content: Optional[Union[str, list[Union[str, dict]]]] = None, - content_blocks: Optional[list[types.ContentBlock]] = None, + content: str | list[str | dict] | None = None, + content_blocks: list[types.ContentBlock] | None = None, **kwargs: Any, ) -> None: """Initialize ``ToolMessage``. @@ -176,7 +176,7 @@ class ToolMessage(BaseMessage, ToolOutputMixin): """ if content_blocks is not None: super().__init__( - content=cast("Union[str, list[Union[str, dict]]]", content_blocks), + content=cast("str | list[str | dict]", content_blocks), **kwargs, ) else: @@ -233,7 +233,7 @@ class ToolCall(TypedDict): """The name of the tool to be called.""" args: dict[str, Any] """The arguments to the tool call.""" - id: Optional[str] + id: str | None """An identifier associated with the tool call. An identifier is needed to associate a tool call request with a tool @@ -247,7 +247,7 @@ def tool_call( *, name: str, args: dict[str, Any], - id: Optional[str], + id: str | None, ) -> ToolCall: """Create a tool call. @@ -283,23 +283,23 @@ class ToolCallChunk(TypedDict): """ - name: Optional[str] + name: str | None """The name of the tool to be called.""" - args: Optional[str] + args: str | None """The arguments to the tool call.""" - id: Optional[str] + id: str | None """An identifier associated with the tool call.""" - index: Optional[int] + index: int | None """The index of the tool call in a sequence.""" type: NotRequired[Literal["tool_call_chunk"]] def tool_call_chunk( *, - name: Optional[str] = None, - args: Optional[str] = None, - id: Optional[str] = None, - index: Optional[int] = None, + name: str | None = None, + args: str | None = None, + id: str | None = None, + index: int | None = None, ) -> ToolCallChunk: """Create a tool call chunk. @@ -319,10 +319,10 @@ def tool_call_chunk( def invalid_tool_call( *, - name: Optional[str] = None, - args: Optional[str] = None, - id: Optional[str] = None, - error: Optional[str] = None, + name: str | None = None, + args: str | None = None, + id: str | None = None, + error: str | None = None, ) -> InvalidToolCall: """Create an invalid tool call. diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 7e2bd238c1b..422f6af3eca 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -15,16 +15,13 @@ import inspect import json import logging import math -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from functools import partial from typing import ( TYPE_CHECKING, Annotated, Any, - Callable, Literal, - Optional, - Union, cast, overload, ) @@ -76,20 +73,18 @@ def _get_type(v: Any) -> str: AnyMessage = Annotated[ - Union[ - Annotated[AIMessage, Tag(tag="ai")], - Annotated[HumanMessage, Tag(tag="human")], - Annotated[ChatMessage, Tag(tag="chat")], - Annotated[SystemMessage, Tag(tag="system")], - Annotated[FunctionMessage, Tag(tag="function")], - Annotated[ToolMessage, Tag(tag="tool")], - Annotated[AIMessageChunk, Tag(tag="AIMessageChunk")], - Annotated[HumanMessageChunk, Tag(tag="HumanMessageChunk")], - Annotated[ChatMessageChunk, Tag(tag="ChatMessageChunk")], - Annotated[SystemMessageChunk, Tag(tag="SystemMessageChunk")], - Annotated[FunctionMessageChunk, Tag(tag="FunctionMessageChunk")], - Annotated[ToolMessageChunk, Tag(tag="ToolMessageChunk")], - ], + Annotated[AIMessage, Tag(tag="ai")] + | Annotated[HumanMessage, Tag(tag="human")] + | Annotated[ChatMessage, Tag(tag="chat")] + | Annotated[SystemMessage, Tag(tag="system")] + | Annotated[FunctionMessage, Tag(tag="function")] + | Annotated[ToolMessage, Tag(tag="tool")] + | Annotated[AIMessageChunk, Tag(tag="AIMessageChunk")] + | Annotated[HumanMessageChunk, Tag(tag="HumanMessageChunk")] + | Annotated[ChatMessageChunk, Tag(tag="ChatMessageChunk")] + | Annotated[SystemMessageChunk, Tag(tag="SystemMessageChunk")] + | Annotated[FunctionMessageChunk, Tag(tag="FunctionMessageChunk")] + | Annotated[ToolMessageChunk, Tag(tag="ToolMessageChunk")], Field(discriminator=Discriminator(_get_type)), ] @@ -215,18 +210,18 @@ def message_chunk_to_message(chunk: BaseMessage) -> BaseMessage: ) -MessageLikeRepresentation = Union[ - BaseMessage, list[str], tuple[str, str], str, dict[str, Any] -] +MessageLikeRepresentation = ( + BaseMessage | list[str] | tuple[str, str] | str | dict[str, Any] +) def _create_message_from_message_type( message_type: str, content: str, - name: Optional[str] = None, - tool_call_id: Optional[str] = None, - tool_calls: Optional[list[dict[str, Any]]] = None, - id: Optional[str] = None, + name: str | None = None, + tool_call_id: str | None = None, + tool_calls: list[dict[str, Any]] | None = None, + id: str | None = None, **additional_kwargs: Any, ) -> BaseMessage: """Create a message from a ``Message`` type and content string. @@ -368,7 +363,7 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage: def convert_to_messages( - messages: Union[Iterable[MessageLikeRepresentation], PromptValue], + messages: Iterable[MessageLikeRepresentation] | PromptValue, ) -> list[BaseMessage]: """Convert a sequence of messages to a list of messages. @@ -399,12 +394,12 @@ def _runnable_support(func: Callable) -> Callable: ) -> list[BaseMessage]: ... def wrapped( - messages: Union[Sequence[MessageLikeRepresentation], None] = None, + messages: Sequence[MessageLikeRepresentation] | None = None, **kwargs: Any, - ) -> Union[ - list[BaseMessage], - Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]], - ]: + ) -> ( + list[BaseMessage] + | Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]] + ): # Import locally to prevent circular import. from langchain_core.runnables.base import RunnableLambda # noqa: PLC0415 @@ -418,15 +413,15 @@ def _runnable_support(func: Callable) -> Callable: @_runnable_support def filter_messages( - messages: Union[Iterable[MessageLikeRepresentation], PromptValue], + messages: Iterable[MessageLikeRepresentation] | PromptValue, *, - include_names: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None, - exclude_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None, - include_ids: Optional[Sequence[str]] = None, - exclude_ids: Optional[Sequence[str]] = None, - exclude_tool_calls: Optional[Sequence[str] | bool] = None, + include_names: Sequence[str] | None = None, + exclude_names: Sequence[str] | None = None, + include_types: Sequence[str | type[BaseMessage]] | None = None, + exclude_types: Sequence[str | type[BaseMessage]] | None = None, + include_ids: Sequence[str] | None = None, + exclude_ids: Sequence[str] | None = None, + exclude_tool_calls: Sequence[str] | bool | None = None, ) -> list[BaseMessage]: """Filter messages based on ``name``, ``type`` or ``id``. @@ -563,7 +558,7 @@ def filter_messages( @_runnable_support def merge_message_runs( - messages: Union[Iterable[MessageLikeRepresentation], PromptValue], + messages: Iterable[MessageLikeRepresentation] | PromptValue, *, chunk_separator: str = "\n", ) -> list[BaseMessage]: @@ -696,24 +691,18 @@ def merge_message_runs( # init not at runtime. @_runnable_support def trim_messages( - messages: Union[Iterable[MessageLikeRepresentation], PromptValue], + messages: Iterable[MessageLikeRepresentation] | PromptValue, *, max_tokens: int, - token_counter: Union[ - Callable[[list[BaseMessage]], int], - Callable[[BaseMessage], int], - BaseLanguageModel, - ], + token_counter: Callable[[list[BaseMessage]], int] + | Callable[[BaseMessage], int] + | BaseLanguageModel, strategy: Literal["first", "last"] = "last", allow_partial: bool = False, - end_on: Optional[ - Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]] - ] = None, - start_on: Optional[ - Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]] - ] = None, + end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None, + start_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None, include_system: bool = False, - text_splitter: Optional[Union[Callable[[str], list[str]], TextSplitter]] = None, + text_splitter: Callable[[str], list[str]] | TextSplitter | None = None, ) -> list[BaseMessage]: r"""Trim messages to be below a token count. @@ -1042,11 +1031,11 @@ def trim_messages( def convert_to_openai_messages( - messages: Union[MessageLikeRepresentation, Sequence[MessageLikeRepresentation]], + messages: MessageLikeRepresentation | Sequence[MessageLikeRepresentation], *, text_format: Literal["string", "block"] = "string", include_id: bool = False, -) -> Union[dict, list[dict]]: +) -> dict | list[dict]: """Convert LangChain messages into OpenAI message dicts. Args: @@ -1143,7 +1132,7 @@ def convert_to_openai_messages( for i, message in enumerate(messages): oai_msg: dict = {"role": _get_message_openai_role(message)} tool_messages: list = [] - content: Union[str, list[dict]] + content: str | list[dict] if message.name: oai_msg["name"] = message.name @@ -1426,10 +1415,8 @@ def _first_max_tokens( max_tokens: int, token_counter: Callable[[list[BaseMessage]], int], text_splitter: Callable[[str], list[str]], - partial_strategy: Optional[Literal["first", "last"]] = None, - end_on: Optional[ - Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]] - ] = None, + partial_strategy: Literal["first", "last"] | None = None, + end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None, ) -> list[BaseMessage]: messages = list(messages) if not messages: @@ -1546,12 +1533,8 @@ def _last_max_tokens( text_splitter: Callable[[str], list[str]], allow_partial: bool = False, include_system: bool = False, - start_on: Optional[ - Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]] - ] = None, - end_on: Optional[ - Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]] - ] = None, + start_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None, + end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None, ) -> list[BaseMessage]: messages = list(messages) if len(messages) == 0: @@ -1652,7 +1635,7 @@ def _default_text_splitter(text: str) -> list[str]: def _is_message_type( message: BaseMessage, - type_: Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]], + type_: str | type[BaseMessage] | Sequence[str | type[BaseMessage]], ) -> bool: types = [type_] if isinstance(type_, (str, type)) else type_ types_str = [t for t in types if isinstance(t, str)] diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index 6cefd3ef62f..f1582c3f95d 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -8,9 +8,7 @@ from typing import ( TYPE_CHECKING, Any, Generic, - Optional, TypeVar, - Union, ) from typing_extensions import override @@ -71,7 +69,7 @@ class BaseGenerationOutputParser( @override def InputType(self) -> Any: """Return the input type for the parser.""" - return Union[str, AnyMessage] + return str | AnyMessage @property @override @@ -84,8 +82,8 @@ class BaseGenerationOutputParser( @override def invoke( self, - input: Union[str, BaseMessage], - config: Optional[RunnableConfig] = None, + input: str | BaseMessage, + config: RunnableConfig | None = None, **kwargs: Any, ) -> T: if isinstance(input, BaseMessage): @@ -107,9 +105,9 @@ class BaseGenerationOutputParser( @override async def ainvoke( self, - input: Union[str, BaseMessage], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + input: str | BaseMessage, + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> T: if isinstance(input, BaseMessage): return await self._acall_with_config( @@ -165,7 +163,7 @@ class BaseOutputParser( @override def InputType(self) -> Any: """Return the input type for the parser.""" - return Union[str, AnyMessage] + return str | AnyMessage @property @override @@ -192,8 +190,8 @@ class BaseOutputParser( @override def invoke( self, - input: Union[str, BaseMessage], - config: Optional[RunnableConfig] = None, + input: str | BaseMessage, + config: RunnableConfig | None = None, **kwargs: Any, ) -> T: if isinstance(input, BaseMessage): @@ -215,9 +213,9 @@ class BaseOutputParser( @override async def ainvoke( self, - input: Union[str, BaseMessage], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + input: str | BaseMessage, + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> T: if isinstance(input, BaseMessage): return await self._acall_with_config( diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index 0d1513e6d86..3fb7499591b 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -4,7 +4,7 @@ from __future__ import annotations import json from json import JSONDecodeError -from typing import Annotated, Any, Optional, TypeVar, Union +from typing import Annotated, Any, TypeVar import jsonpatch # type: ignore[import-untyped] import pydantic @@ -23,7 +23,7 @@ from langchain_core.utils.json import ( ) # Union type needs to be last assignment to PydanticBaseModel to make mypy happy. -PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] +PydanticBaseModel = BaseModel | pydantic.BaseModel TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel) @@ -38,12 +38,12 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): describing the difference between the previous and the current object. """ - pydantic_object: Annotated[Optional[type[TBaseModel]], SkipValidation()] = None # type: ignore[valid-type] + pydantic_object: Annotated[type[TBaseModel] | None, SkipValidation()] = None # type: ignore[valid-type] """The Pydantic object to use for validation. If None, no validation is performed.""" @override - def _diff(self, prev: Optional[Any], next: Any) -> Any: + def _diff(self, prev: Any | None, next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch @staticmethod diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py index b0d1ba4bfd4..908098ae15c 100644 --- a/libs/core/langchain_core/output_parsers/list.py +++ b/libs/core/langchain_core/output_parsers/list.py @@ -7,7 +7,7 @@ import re from abc import abstractmethod from collections import deque from io import StringIO -from typing import TYPE_CHECKING, TypeVar, Union +from typing import TYPE_CHECKING, TypeVar from typing_extensions import override @@ -70,9 +70,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]): raise NotImplementedError @override - def _transform( - self, input: Iterator[Union[str, BaseMessage]] - ) -> Iterator[list[str]]: + def _transform(self, input: Iterator[str | BaseMessage]) -> Iterator[list[str]]: buffer = "" for chunk in input: if isinstance(chunk, BaseMessage): @@ -105,7 +103,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]): @override async def _atransform( - self, input: AsyncIterator[Union[str, BaseMessage]] + self, input: AsyncIterator[str | BaseMessage] ) -> AsyncIterator[list[str]]: buffer = "" async for chunk in input: diff --git a/libs/core/langchain_core/output_parsers/openai_functions.py b/libs/core/langchain_core/output_parsers/openai_functions.py index 5518f36a78e..ed1fd22a635 100644 --- a/libs/core/langchain_core/output_parsers/openai_functions.py +++ b/libs/core/langchain_core/output_parsers/openai_functions.py @@ -3,7 +3,7 @@ import copy import json from types import GenericAlias -from typing import Any, Optional, Union +from typing import Any import jsonpatch # type: ignore[import-untyped] from pydantic import BaseModel, model_validator @@ -74,7 +74,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): return "json_functions" @override - def _diff(self, prev: Optional[Any], next: Any) -> Any: + def _diff(self, prev: Any | None, next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: @@ -217,7 +217,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser): """ - pydantic_schema: Union[type[BaseModel], dict[str, type[BaseModel]]] + pydantic_schema: type[BaseModel] | dict[str, type[BaseModel]] """The pydantic schema to parse the output with. If multiple schemas are provided, then the function name will be used to diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py index 44563e9a3c2..084b1273a3b 100644 --- a/libs/core/langchain_core/output_parsers/openai_tools.py +++ b/libs/core/langchain_core/output_parsers/openai_tools.py @@ -4,7 +4,7 @@ import copy import json import logging from json import JSONDecodeError -from typing import Annotated, Any, Optional +from typing import Annotated, Any from pydantic import SkipValidation, ValidationError @@ -26,7 +26,7 @@ def parse_tool_call( partial: bool = False, strict: bool = False, return_id: bool = True, -) -> Optional[dict[str, Any]]: +) -> dict[str, Any] | None: """Parse a single tool call. Args: @@ -75,7 +75,7 @@ def parse_tool_call( def make_invalid_tool_call( raw_tool_call: dict[str, Any], - error_msg: Optional[str], + error_msg: str | None, ) -> InvalidToolCall: """Create an InvalidToolCall from a raw tool call. diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index 1785cbc60fa..cc6190eb7ba 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -1,7 +1,7 @@ """Output parsers using Pydantic.""" import json -from typing import Annotated, Generic, Optional +from typing import Annotated, Generic import pydantic from pydantic import SkipValidation @@ -44,7 +44,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): def parse_result( self, result: list[Generation], *, partial: bool = False - ) -> Optional[TBaseModel]: + ) -> TBaseModel | None: """Parse the result of an LLM call to a pydantic object. Args: diff --git a/libs/core/langchain_core/output_parsers/transform.py b/libs/core/langchain_core/output_parsers/transform.py index 0c864805b93..877f4d97f3b 100644 --- a/libs/core/langchain_core/output_parsers/transform.py +++ b/libs/core/langchain_core/output_parsers/transform.py @@ -5,8 +5,6 @@ from __future__ import annotations from typing import ( TYPE_CHECKING, Any, - Optional, - Union, ) from typing_extensions import override @@ -32,7 +30,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]): def _transform( self, - input: Iterator[Union[str, BaseMessage]], + input: Iterator[str | BaseMessage], ) -> Iterator[T]: for chunk in input: if isinstance(chunk, BaseMessage): @@ -42,7 +40,7 @@ class BaseTransformOutputParser(BaseOutputParser[T]): async def _atransform( self, - input: AsyncIterator[Union[str, BaseMessage]], + input: AsyncIterator[str | BaseMessage], ) -> AsyncIterator[T]: async for chunk in input: if isinstance(chunk, BaseMessage): @@ -57,8 +55,8 @@ class BaseTransformOutputParser(BaseOutputParser[T]): @override def transform( self, - input: Iterator[Union[str, BaseMessage]], - config: Optional[RunnableConfig] = None, + input: Iterator[str | BaseMessage], + config: RunnableConfig | None = None, **kwargs: Any, ) -> Iterator[T]: """Transform the input into the output format. @@ -78,8 +76,8 @@ class BaseTransformOutputParser(BaseOutputParser[T]): @override async def atransform( self, - input: AsyncIterator[Union[str, BaseMessage]], - config: Optional[RunnableConfig] = None, + input: AsyncIterator[str | BaseMessage], + config: RunnableConfig | None = None, **kwargs: Any, ) -> AsyncIterator[T]: """Async transform the input into the output format. @@ -108,7 +106,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): def _diff( self, - prev: Optional[T], + prev: T | None, next: T, # noqa: A002 ) -> T: """Convert parsed outputs into a diff format. @@ -125,11 +123,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): raise NotImplementedError @override - def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: + def _transform(self, input: Iterator[str | BaseMessage]) -> Iterator[Any]: prev_parsed = None - acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None + acc_gen: GenerationChunk | ChatGenerationChunk | None = None for chunk in input: - chunk_gen: Union[GenerationChunk, ChatGenerationChunk] + chunk_gen: GenerationChunk | ChatGenerationChunk if isinstance(chunk, BaseMessageChunk): chunk_gen = ChatGenerationChunk(message=chunk) elif isinstance(chunk, BaseMessage): @@ -151,12 +149,12 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): @override async def _atransform( - self, input: AsyncIterator[Union[str, BaseMessage]] + self, input: AsyncIterator[str | BaseMessage] ) -> AsyncIterator[T]: prev_parsed = None - acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None + acc_gen: GenerationChunk | ChatGenerationChunk | None = None async for chunk in input: - chunk_gen: Union[GenerationChunk, ChatGenerationChunk] + chunk_gen: GenerationChunk | ChatGenerationChunk if isinstance(chunk, BaseMessageChunk): chunk_gen = ChatGenerationChunk(message=chunk) elif isinstance(chunk, BaseMessage): diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index be12760fade..3e0f17791a5 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -5,7 +5,7 @@ import re import xml import xml.etree.ElementTree as ET from collections.abc import AsyncIterator, Iterator -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from xml.etree.ElementTree import TreeBuilder from typing_extensions import override @@ -75,7 +75,7 @@ class _StreamingParser: self.buffer = "" self.xml_started = False - def parse(self, chunk: Union[str, BaseMessage]) -> Iterator[AddableDict]: + def parse(self, chunk: str | BaseMessage) -> Iterator[AddableDict]: """Parse a chunk of text. Args: @@ -149,7 +149,7 @@ class _StreamingParser: class XMLOutputParser(BaseTransformOutputParser): """Parse an output using xml format.""" - tags: Optional[list[str]] = None + tags: list[str] | None = None """Tags to tell the LLM to expect in the XML output. Note this may not be perfect depending on the LLM implementation. @@ -193,7 +193,7 @@ class XMLOutputParser(BaseTransformOutputParser): """Return the format instructions for the XML output.""" return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) - def parse(self, text: str) -> dict[str, Union[str, list[Any]]]: + def parse(self, text: str) -> dict[str, str | list[Any]]: """Parse the output of an LLM call. Args: @@ -240,9 +240,7 @@ class XMLOutputParser(BaseTransformOutputParser): raise OutputParserException(msg, llm_output=text) from e @override - def _transform( - self, input: Iterator[Union[str, BaseMessage]] - ) -> Iterator[AddableDict]: + def _transform(self, input: Iterator[str | BaseMessage]) -> Iterator[AddableDict]: streaming_parser = _StreamingParser(self.parser) for chunk in input: yield from streaming_parser.parse(chunk) @@ -250,7 +248,7 @@ class XMLOutputParser(BaseTransformOutputParser): @override async def _atransform( - self, input: AsyncIterator[Union[str, BaseMessage]] + self, input: AsyncIterator[str | BaseMessage] ) -> AsyncIterator[AddableDict]: streaming_parser = _StreamingParser(self.parser) async for chunk in input: @@ -258,7 +256,7 @@ class XMLOutputParser(BaseTransformOutputParser): yield output streaming_parser.close() - def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]: + def _root_to_dict(self, root: ET.Element) -> dict[str, str | list[Any]]: """Converts xml tree to python dictionary.""" if root.text and bool(re.search(r"\S", root.text)): # If root text contains any non-whitespace character it diff --git a/libs/core/langchain_core/outputs/chat_generation.py b/libs/core/langchain_core/outputs/chat_generation.py index 2d3f8c54e7b..9955d576593 100644 --- a/libs/core/langchain_core/outputs/chat_generation.py +++ b/libs/core/langchain_core/outputs/chat_generation.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Literal, Union +from typing import Literal from pydantic import model_validator from typing_extensions import Self @@ -82,7 +82,7 @@ class ChatGenerationChunk(ChatGeneration): """Type is used exclusively for serialization purposes.""" def __add__( - self, other: Union[ChatGenerationChunk, list[ChatGenerationChunk]] + self, other: ChatGenerationChunk | list[ChatGenerationChunk] ) -> ChatGenerationChunk: """Concatenate two ``ChatGenerationChunk``s. @@ -123,7 +123,7 @@ class ChatGenerationChunk(ChatGeneration): def merge_chat_generation_chunks( chunks: list[ChatGenerationChunk], -) -> Union[ChatGenerationChunk, None]: +) -> ChatGenerationChunk | None: """Merge a list of ``ChatGenerationChunk``s into a single ``ChatGenerationChunk``. Args: diff --git a/libs/core/langchain_core/outputs/chat_result.py b/libs/core/langchain_core/outputs/chat_result.py index 3e6a5076c8f..7a53830067a 100644 --- a/libs/core/langchain_core/outputs/chat_result.py +++ b/libs/core/langchain_core/outputs/chat_result.py @@ -1,7 +1,5 @@ """Chat result schema.""" -from typing import Optional - from pydantic import BaseModel from langchain_core.outputs.chat_generation import ChatGeneration @@ -26,7 +24,7 @@ class ChatResult(BaseModel): Generations is a list to allow for multiple candidate generations for a single input prompt. """ - llm_output: Optional[dict] = None + llm_output: dict | None = None """For arbitrary LLM provider specific output. This dictionary is a free-form dictionary that can contain any information that the diff --git a/libs/core/langchain_core/outputs/generation.py b/libs/core/langchain_core/outputs/generation.py index cd06c20e673..2b42d1e8720 100644 --- a/libs/core/langchain_core/outputs/generation.py +++ b/libs/core/langchain_core/outputs/generation.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Literal, Optional +from typing import Any, Literal from langchain_core.load import Serializable from langchain_core.utils._merge import merge_dicts @@ -28,7 +28,7 @@ class Generation(Serializable): text: str """Generated text output.""" - generation_info: Optional[dict[str, Any]] = None + generation_info: dict[str, Any] | None = None """Raw response from the provider. May include things like the reason for finishing or token log probabilities. diff --git a/libs/core/langchain_core/outputs/llm_result.py b/libs/core/langchain_core/outputs/llm_result.py index 05c41657c3a..b0bfad33d01 100644 --- a/libs/core/langchain_core/outputs/llm_result.py +++ b/libs/core/langchain_core/outputs/llm_result.py @@ -3,7 +3,7 @@ from __future__ import annotations from copy import deepcopy -from typing import Literal, Optional, Union +from typing import Literal from pydantic import BaseModel @@ -21,7 +21,7 @@ class LLMResult(BaseModel): """ generations: list[ - list[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]] + list[Generation | ChatGeneration | GenerationChunk | ChatGenerationChunk] ] """Generated outputs. @@ -36,7 +36,7 @@ class LLMResult(BaseModel): ChatGeneration is a subclass of Generation that has a field for a structured chat message. """ - llm_output: Optional[dict] = None + llm_output: dict | None = None """For arbitrary LLM provider specific output. This dictionary is a free-form dictionary that can contain any information that the @@ -45,7 +45,7 @@ class LLMResult(BaseModel): Users should generally avoid relying on this field and instead rely on accessing relevant information from standardized fields present in AIMessage. """ - run: Optional[list[RunInfo]] = None + run: list[RunInfo] | None = None """List of metadata info for model call for each input. See `langchain_core.outputs.run_info.RunInfo` for details. diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index 4b6a92ab826..b3e656f3b66 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -6,17 +6,14 @@ import contextlib import json import typing from abc import ABC, abstractmethod -from collections.abc import Mapping +from collections.abc import Callable, Mapping from functools import cached_property from pathlib import Path from typing import ( TYPE_CHECKING, Any, - Callable, Generic, - Optional, TypeVar, - Union, ) import yaml @@ -57,16 +54,16 @@ class BasePromptTemplate( input_types: typing.Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006 """A dictionary of the types of the variables the prompt template expects. If not provided, all variables are assumed to be strings.""" - output_parser: Optional[BaseOutputParser] = None + output_parser: BaseOutputParser | None = None """How to parse the output of calling an LLM on this formatted prompt.""" partial_variables: Mapping[str, Any] = Field(default_factory=dict) """A dictionary of the partial variables the prompt template carries. Partial variables populate the template so that you don't need to pass them in every time you call the prompt.""" - metadata: Optional[typing.Dict[str, Any]] = None # noqa: UP006 + metadata: typing.Dict[str, Any] | None = None # noqa: UP006 """Metadata to be used for tracing.""" - tags: Optional[list[str]] = None + tags: list[str] | None = None """Tags to be used for tracing.""" @model_validator(mode="after") @@ -123,12 +120,10 @@ class BasePromptTemplate( @override def OutputType(self) -> Any: """Return the output type of the prompt.""" - return Union[StringPromptValue, ChatPromptValueConcrete] + return StringPromptValue | ChatPromptValueConcrete @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: """Get the input schema for the prompt. Args: @@ -195,7 +190,7 @@ class BasePromptTemplate( @override def invoke( - self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: dict, config: RunnableConfig | None = None, **kwargs: Any ) -> PromptValue: """Invoke the prompt. @@ -221,7 +216,7 @@ class BasePromptTemplate( @override async def ainvoke( - self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: dict, config: RunnableConfig | None = None, **kwargs: Any ) -> PromptValue: """Async invoke the prompt. @@ -267,7 +262,7 @@ class BasePromptTemplate( """ return self.format_prompt(**kwargs) - def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: + def partial(self, **kwargs: str | Callable[[], str]) -> BasePromptTemplate: """Return a partial of the prompt template. Args: @@ -345,7 +340,7 @@ class BasePromptTemplate( prompt_dict["_type"] = self._prompt_type return prompt_dict - def save(self, file_path: Union[Path, str]) -> None: + def save(self, file_path: Path | str) -> None: """Save the prompt. Args: diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index dc9c45f8274..3ff97de2fbb 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -8,10 +8,8 @@ from typing import ( TYPE_CHECKING, Annotated, Any, - Optional, TypedDict, TypeVar, - Union, cast, overload, ) @@ -136,7 +134,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): list. If False then a named argument with name `variable_name` must be passed in, even if the value is an empty list.""" - n_messages: Optional[PositiveInt] = None + n_messages: PositiveInt | None = None """Maximum number of messages to include. If None, then will include all. Defaults to None.""" @@ -231,7 +229,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): cls, template: str, template_format: PromptTemplateFormat = "f-string", - partial_variables: Optional[dict[str, Any]] = None, + partial_variables: dict[str, Any] | None = None, **kwargs: Any, ) -> Self: """Create a class from a string template. @@ -260,7 +258,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): @classmethod def from_template_file( cls, - template_file: Union[str, Path], + template_file: str | Path, **kwargs: Any, ) -> Self: """Create a class from a template file. @@ -380,20 +378,20 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate): class _TextTemplateParam(TypedDict, total=False): - text: Union[str, dict] + text: str | dict class _ImageTemplateParam(TypedDict, total=False): - image_url: Union[str, dict] + image_url: str | dict class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): """Human message prompt template. This is a message sent from the user.""" - prompt: Union[ - StringPromptTemplate, - list[Union[StringPromptTemplate, ImagePromptTemplate, DictPromptTemplate]], - ] + prompt: ( + StringPromptTemplate + | list[StringPromptTemplate | ImagePromptTemplate | DictPromptTemplate] + ) """Prompt template.""" additional_kwargs: dict = Field(default_factory=dict) """Additional keyword arguments to pass to the prompt template.""" @@ -403,13 +401,11 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): @classmethod def from_template( cls: type[Self], - template: Union[ - str, - list[Union[str, _TextTemplateParam, _ImageTemplateParam, dict[str, Any]]], - ], + template: str + | list[str | _TextTemplateParam | _ImageTemplateParam | dict[str, Any]], template_format: PromptTemplateFormat = "f-string", *, - partial_variables: Optional[dict[str, Any]] = None, + partial_variables: dict[str, Any] | None = None, **kwargs: Any, ) -> Self: """Create a class from a string template. @@ -429,7 +425,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): ValueError: If the template is not a string or list of strings. """ if isinstance(template, str): - prompt: Union[StringPromptTemplate, list] = PromptTemplate.from_template( + prompt: StringPromptTemplate | list = PromptTemplate.from_template( template, template_format=template_format, partial_variables=partial_variables, @@ -526,7 +522,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): @classmethod def from_template_file( cls: type[Self], - template_file: Union[str, Path], + template_file: str | Path, input_variables: list[str], **kwargs: Any, ) -> Self: @@ -593,9 +589,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): for prompt in self.prompt: inputs = {var: kwargs[var] for var in prompt.input_variables} if isinstance(prompt, StringPromptTemplate): - formatted: Union[str, ImageURL, dict[str, Any]] = prompt.format( - **inputs - ) + formatted: str | ImageURL | dict[str, Any] = prompt.format(**inputs) content.append({"type": "text", "text": formatted}) elif isinstance(prompt, ImagePromptTemplate): formatted = prompt.format(**inputs) @@ -625,7 +619,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): for prompt in self.prompt: inputs = {var: kwargs[var] for var in prompt.input_variables} if isinstance(prompt, StringPromptTemplate): - formatted: Union[str, ImageURL, dict[str, Any]] = await prompt.aformat( + formatted: str | ImageURL | dict[str, Any] = await prompt.aformat( **inputs ) content.append({"type": "text", "text": formatted}) @@ -769,17 +763,14 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC): print(self.pretty_repr(html=is_interactive_env())) # noqa: T201 -MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate] +MessageLike = BaseMessagePromptTemplate | BaseMessage | BaseChatPromptTemplate -MessageLikeRepresentation = Union[ - MessageLike, - tuple[ - Union[str, type], - Union[str, list[dict], list[object]], - ], - str, - dict[str, Any], -] +MessageLikeRepresentation = ( + MessageLike + | tuple[str | type, str | list[dict] | list[object]] + | str + | dict[str, Any] +) class ChatPromptTemplate(BaseChatPromptTemplate): @@ -1266,9 +1257,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): @overload def __getitem__(self, index: slice) -> ChatPromptTemplate: ... - def __getitem__( - self, index: Union[int, slice] - ) -> Union[MessageLike, ChatPromptTemplate]: + def __getitem__(self, index: int | slice) -> MessageLike | ChatPromptTemplate: """Use to index into the chat template. Returns: @@ -1291,7 +1280,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): """Name of prompt type. Used for serialization.""" return "chat" - def save(self, file_path: Union[Path, str]) -> None: + def save(self, file_path: Path | str) -> None: """Save prompt to file. Args: @@ -1315,7 +1304,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): def _create_template_from_message_type( message_type: str, - template: Union[str, list], + template: str | list, template_format: PromptTemplateFormat = "f-string", ) -> BaseMessagePromptTemplate: """Create a message prompt template from a message type and template string. @@ -1387,7 +1376,7 @@ def _create_template_from_message_type( def _convert_to_message_template( message: MessageLikeRepresentation, template_format: PromptTemplateFormat = "f-string", -) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]: +) -> BaseMessage | BaseMessagePromptTemplate | BaseChatPromptTemplate: """Instantiate a message from a variety of message formats. The message format can be one of the following: @@ -1410,9 +1399,9 @@ def _convert_to_message_template( ValueError: If 2-tuple does not have 2 elements. """ if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)): - message_: Union[ - BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate - ] = message + message_: BaseMessage | BaseMessagePromptTemplate | BaseChatPromptTemplate = ( + message + ) elif isinstance(message, BaseMessage): message_ = message elif isinstance(message, str): diff --git a/libs/core/langchain_core/prompts/dict.py b/libs/core/langchain_core/prompts/dict.py index 135f8021b9d..74fb4c335d8 100644 --- a/libs/core/langchain_core/prompts/dict.py +++ b/libs/core/langchain_core/prompts/dict.py @@ -2,7 +2,7 @@ import warnings from functools import cached_property -from typing import Any, Literal, Optional +from typing import Any, Literal from typing_extensions import override @@ -48,7 +48,7 @@ class DictPromptTemplate(RunnableSerializable[dict, dict]): @override def invoke( - self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: dict, config: RunnableConfig | None = None, **kwargs: Any ) -> dict: return self._call_with_config( lambda x: self.format(**x), diff --git a/libs/core/langchain_core/prompts/few_shot.py b/libs/core/langchain_core/prompts/few_shot.py index 1708a79b0f5..0f5fe857801 100644 --- a/libs/core/langchain_core/prompts/few_shot.py +++ b/libs/core/langchain_core/prompts/few_shot.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal from pydantic import ( BaseModel, @@ -33,11 +33,11 @@ if TYPE_CHECKING: class _FewShotPromptTemplateMixin(BaseModel): """Prompt template that contains few shot examples.""" - examples: Optional[list[dict]] = None + examples: list[dict] | None = None """Examples to format into the prompt. Either this or example_selector should be provided.""" - example_selector: Optional[BaseExampleSelector] = None + example_selector: BaseExampleSelector | None = None """ExampleSelector to choose the examples to format into the prompt. Either this or examples should be provided.""" @@ -229,7 +229,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate): """Return the prompt type key.""" return "few_shot" - def save(self, file_path: Union[Path, str]) -> None: + def save(self, file_path: Path | str) -> None: """Save the prompt template to a file. Args: @@ -365,7 +365,7 @@ class FewShotChatMessagePromptTemplate( """A list of the names of the variables the prompt template will use to pass to the example_selector, if provided.""" - example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate] + example_prompt: BaseMessagePromptTemplate | BaseChatPromptTemplate """The class to format each example.""" @classmethod diff --git a/libs/core/langchain_core/prompts/few_shot_with_templates.py b/libs/core/langchain_core/prompts/few_shot_with_templates.py index 5f3ad91fe48..73d4fc8903b 100644 --- a/libs/core/langchain_core/prompts/few_shot_with_templates.py +++ b/libs/core/langchain_core/prompts/few_shot_with_templates.py @@ -1,7 +1,7 @@ """Prompt template that contains few shot examples.""" from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from pydantic import ConfigDict, model_validator from typing_extensions import Self @@ -17,7 +17,7 @@ from langchain_core.prompts.string import ( class FewShotPromptWithTemplates(StringPromptTemplate): """Prompt template that contains few shot examples.""" - examples: Optional[list[dict]] = None + examples: list[dict] | None = None """Examples to format into the prompt. Either this or example_selector should be provided.""" @@ -34,7 +34,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate): example_separator: str = "\n\n" """String separator used to join the prefix, the examples, and suffix.""" - prefix: Optional[StringPromptTemplate] = None + prefix: StringPromptTemplate | None = None """A PromptTemplate to put before the examples.""" template_format: PromptTemplateFormat = "f-string" @@ -210,7 +210,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate): """Return the prompt type key.""" return "few_shot_with_templates" - def save(self, file_path: Union[Path, str]) -> None: + def save(self, file_path: Path | str) -> None: """Save the prompt to a file. Args: diff --git a/libs/core/langchain_core/prompts/loading.py b/libs/core/langchain_core/prompts/loading.py index 5ca0878c82b..63e587c8ed2 100644 --- a/libs/core/langchain_core/prompts/loading.py +++ b/libs/core/langchain_core/prompts/loading.py @@ -2,8 +2,8 @@ import json import logging +from collections.abc import Callable from pathlib import Path -from typing import Callable, Optional, Union import yaml @@ -134,9 +134,7 @@ def _load_prompt(config: dict) -> PromptTemplate: return PromptTemplate(**config) -def load_prompt( - path: Union[str, Path], encoding: Optional[str] = None -) -> BasePromptTemplate: +def load_prompt(path: str | Path, encoding: str | None = None) -> BasePromptTemplate: """Unified method for loading a prompt from LangChainHub or local fs. Args: @@ -160,7 +158,7 @@ def load_prompt( def _load_prompt_from_file( - file: Union[str, Path], encoding: Optional[str] = None + file: str | Path, encoding: str | None = None ) -> BasePromptTemplate: """Load prompt from file.""" # Convert file to a Path object. diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index b5ab56d7a43..5a16400c8db 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -3,7 +3,7 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, model_validator from typing_extensions import override @@ -233,8 +233,8 @@ class PromptTemplate(StringPromptTemplate): @classmethod def from_file( cls, - template_file: Union[str, Path], - encoding: Optional[str] = None, + template_file: str | Path, + encoding: str | None = None, **kwargs: Any, ) -> PromptTemplate: """Load a prompt from a file. @@ -256,7 +256,7 @@ class PromptTemplate(StringPromptTemplate): template: str, *, template_format: PromptTemplateFormat = "f-string", - partial_variables: Optional[dict[str, Any]] = None, + partial_variables: dict[str, Any] | None = None, **kwargs: Any, ) -> PromptTemplate: """Load a prompt template from a template. diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index 8b147c3c0bc..6841f38c2b1 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -4,8 +4,9 @@ from __future__ import annotations import warnings from abc import ABC +from collections.abc import Callable from string import Formatter -from typing import Any, Callable, Literal +from typing import Any, Literal from pydantic import BaseModel, create_model diff --git a/libs/core/langchain_core/prompts/structured.py b/libs/core/langchain_core/prompts/structured.py index 0d2c5d90b0f..8a4981147c4 100644 --- a/libs/core/langchain_core/prompts/structured.py +++ b/libs/core/langchain_core/prompts/structured.py @@ -1,11 +1,8 @@ """Structured prompt template for a language model.""" -from collections.abc import AsyncIterator, Iterator, Mapping, Sequence +from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence from typing import ( Any, - Callable, - Optional, - Union, ) from pydantic import BaseModel, Field @@ -31,16 +28,16 @@ from langchain_core.utils import get_pydantic_field_names class StructuredPrompt(ChatPromptTemplate): """Structured prompt template for a language model.""" - schema_: Union[dict, type] + schema_: dict | type """Schema for the structured prompt.""" structured_output_kwargs: dict[str, Any] = Field(default_factory=dict) def __init__( self, messages: Sequence[MessageLikeRepresentation], - schema_: Optional[Union[dict, type[BaseModel]]] = None, + schema_: dict | type[BaseModel] | None = None, *, - structured_output_kwargs: Optional[dict[str, Any]] = None, + structured_output_kwargs: dict[str, Any] | None = None, template_format: PromptTemplateFormat = "f-string", **kwargs: Any, ) -> None: @@ -80,7 +77,7 @@ class StructuredPrompt(ChatPromptTemplate): def from_messages_and_schema( cls, messages: Sequence[MessageLikeRepresentation], - schema: Union[dict, type], + schema: dict | type, **kwargs: Any, ) -> ChatPromptTemplate: """Create a chat prompt template from a variety of message formats. @@ -127,26 +124,22 @@ class StructuredPrompt(ChatPromptTemplate): @override def __or__( self, - other: Union[ - Runnable[Any, Other], - Callable[[Iterator[Any]], Iterator[Other]], - Callable[[AsyncIterator[Any]], AsyncIterator[Other]], - Callable[[Any], Other], - Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], - ], + other: Runnable[Any, Other] + | Callable[[Iterator[Any]], Iterator[Other]] + | Callable[[AsyncIterator[Any]], AsyncIterator[Other]] + | Callable[[Any], Other] + | Mapping[str, Runnable[Any, Other] | Callable[[Any], Other] | Any], ) -> RunnableSerializable[dict, Other]: return self.pipe(other) def pipe( self, - *others: Union[ - Runnable[Any, Other], - Callable[[Iterator[Any]], Iterator[Other]], - Callable[[AsyncIterator[Any]], AsyncIterator[Other]], - Callable[[Any], Other], - Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], - ], - name: Optional[str] = None, + *others: Runnable[Any, Other] + | Callable[[Iterator[Any]], Iterator[Other]] + | Callable[[AsyncIterator[Any]], AsyncIterator[Other]] + | Callable[[Any], Other] + | Mapping[str, Runnable[Any, Other] | Callable[[Any], Other] | Any], + name: str | None = None, ) -> RunnableSerializable[dict, Other]: """Pipe the structured prompt to a language model. diff --git a/libs/core/langchain_core/rate_limiters.py b/libs/core/langchain_core/rate_limiters.py index 072887f9d9b..5f4ae153f90 100644 --- a/libs/core/langchain_core/rate_limiters.py +++ b/libs/core/langchain_core/rate_limiters.py @@ -6,7 +6,6 @@ import abc import asyncio import threading import time -from typing import Optional class BaseRateLimiter(abc.ABC): @@ -163,7 +162,7 @@ class InMemoryRateLimiter(BaseRateLimiter): # at a given time. self._consume_lock = threading.Lock() # The last time we tried to consume tokens. - self.last: Optional[float] = None + self.last: float | None = None self.check_every_n_seconds = check_every_n_seconds def _consume(self) -> bool: diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index c015747f138..ceb5eb84973 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -24,7 +24,7 @@ from __future__ import annotations import warnings from abc import ABC, abstractmethod from inspect import signature -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from pydantic import ConfigDict from typing_extensions import Self, TypedDict, override @@ -58,11 +58,11 @@ class LangSmithRetrieverParams(TypedDict, total=False): ls_retriever_name: str """Retriever name.""" - ls_vector_store_provider: Optional[str] + ls_vector_store_provider: str | None """Vector store provider.""" - ls_embedding_provider: Optional[str] + ls_embedding_provider: str | None """Embedding provider.""" - ls_embedding_model: Optional[str] + ls_embedding_model: str | None """Embedding model.""" @@ -137,14 +137,14 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): _new_arg_supported: bool = False _expects_other_args: bool = False - tags: Optional[list[str]] = None + tags: list[str] | None = None """Optional list of tags associated with the retriever. Defaults to None. These tags will be associated with each call to this retriever, and passed as arguments to the handlers defined in `callbacks`. You can use these to eg identify a specific instance of a retriever with its use case. """ - metadata: Optional[dict[str, Any]] = None + metadata: dict[str, Any] | None = None """Optional metadata associated with the retriever. Defaults to None. This metadata will be associated with each call to this retriever, and passed as arguments to the handlers defined in `callbacks`. @@ -216,7 +216,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): @override def invoke( - self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: str, config: RunnableConfig | None = None, **kwargs: Any ) -> list[Document]: """Invoke the retriever to get relevant documents. @@ -278,7 +278,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): async def ainvoke( self, input: str, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> list[Document]: """Asynchronously invoke the retriever to get relevant documents. @@ -376,9 +376,9 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): query: str, *, callbacks: Callbacks = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - run_name: Optional[str] = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + run_name: str | None = None, **kwargs: Any, ) -> list[Document]: """Retrieve documents relevant to a query. @@ -420,9 +420,9 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): query: str, *, callbacks: Callbacks = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - run_name: Optional[str] = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + run_name: str | None = None, **kwargs: Any, ) -> list[Document]: """Asynchronously get documents relevant to a query. diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index a1e6258e2e8..de09a1be992 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -13,6 +13,7 @@ from collections.abc import ( AsyncGenerator, AsyncIterator, Awaitable, + Callable, Coroutine, Iterator, Mapping, @@ -26,13 +27,10 @@ from types import GenericAlias from typing import ( TYPE_CHECKING, Any, - Callable, Generic, Literal, - Optional, Protocol, TypeVar, - Union, cast, get_args, get_type_hints, @@ -255,12 +253,10 @@ class Runnable(ABC, Generic[Input, Output]): """ - name: Optional[str] + name: str | None """The name of the ``Runnable``. Used for debugging and tracing.""" - def get_name( - self, suffix: Optional[str] = None, *, name: Optional[str] = None - ) -> str: + def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str: """Get the name of the ``Runnable``. Args: @@ -365,7 +361,7 @@ class Runnable(ABC, Generic[Input, Output]): def get_input_schema( self, - config: Optional[RunnableConfig] = None, # noqa: ARG002 + config: RunnableConfig | None = None, # noqa: ARG002 ) -> type[BaseModel]: """Get a pydantic model that can be used to validate input to the Runnable. @@ -404,7 +400,7 @@ class Runnable(ABC, Generic[Input, Output]): ) def get_input_jsonschema( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> dict[str, Any]: """Get a JSON schema that represents the input to the ``Runnable``. @@ -443,7 +439,7 @@ class Runnable(ABC, Generic[Input, Output]): def get_output_schema( self, - config: Optional[RunnableConfig] = None, # noqa: ARG002 + config: RunnableConfig | None = None, # noqa: ARG002 ) -> type[BaseModel]: """Get a pydantic model that can be used to validate output to the ``Runnable``. @@ -482,7 +478,7 @@ class Runnable(ABC, Generic[Input, Output]): ) def get_output_jsonschema( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> dict[str, Any]: """Get a JSON schema that represents the output of the ``Runnable``. @@ -516,9 +512,7 @@ class Runnable(ABC, Generic[Input, Output]): """List configurable fields for this ``Runnable``.""" return [] - def config_schema( - self, *, include: Optional[Sequence[str]] = None - ) -> type[BaseModel]: + def config_schema(self, *, include: Sequence[str] | None = None) -> type[BaseModel]: """The type of config this ``Runnable`` accepts specified as a pydantic model. To mark a field as configurable, see the ``configurable_fields`` @@ -562,7 +556,7 @@ class Runnable(ABC, Generic[Input, Output]): return create_model_v2(self.get_name("Config"), field_definitions=all_fields) def get_config_jsonschema( - self, *, include: Optional[Sequence[str]] = None + self, *, include: Sequence[str] | None = None ) -> dict[str, Any]: """Get a JSON schema that represents the config of the ``Runnable``. @@ -577,7 +571,7 @@ class Runnable(ABC, Generic[Input, Output]): """ return self.config_schema(include=include).model_json_schema() - def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: + def get_graph(self, config: RunnableConfig | None = None) -> Graph: """Return a graph representation of this ``Runnable``.""" # Import locally to prevent circular import from langchain_core.runnables.graph import Graph # noqa: PLC0415 @@ -599,7 +593,7 @@ class Runnable(ABC, Generic[Input, Output]): return graph def get_prompts( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> list[BasePromptTemplate]: """Return a list of prompts used by this ``Runnable``.""" # Import locally to prevent circular import @@ -613,13 +607,11 @@ class Runnable(ABC, Generic[Input, Output]): def __or__( self, - other: Union[ - Runnable[Any, Other], - Callable[[Iterator[Any]], Iterator[Other]], - Callable[[AsyncIterator[Any]], AsyncIterator[Other]], - Callable[[Any], Other], - Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], - ], + other: Runnable[Any, Other] + | Callable[[Iterator[Any]], Iterator[Other]] + | Callable[[AsyncIterator[Any]], AsyncIterator[Other]] + | Callable[[Any], Other] + | Mapping[str, Runnable[Any, Other] | Callable[[Any], Other] | Any], ) -> RunnableSerializable[Input, Other]: """Runnable "or" operator. @@ -636,13 +628,11 @@ class Runnable(ABC, Generic[Input, Output]): def __ror__( self, - other: Union[ - Runnable[Other, Any], - Callable[[Iterator[Other]], Iterator[Any]], - Callable[[AsyncIterator[Other]], AsyncIterator[Any]], - Callable[[Other], Any], - Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], - ], + other: Runnable[Other, Any] + | Callable[[Iterator[Other]], Iterator[Any]] + | Callable[[AsyncIterator[Other]], AsyncIterator[Any]] + | Callable[[Other], Any] + | Mapping[str, Runnable[Other, Any] | Callable[[Other], Any] | Any], ) -> RunnableSerializable[Other, Output]: """Runnable "reverse-or" operator. @@ -659,8 +649,8 @@ class Runnable(ABC, Generic[Input, Output]): def pipe( self, - *others: Union[Runnable[Any, Other], Callable[[Any], Other]], - name: Optional[str] = None, + *others: Runnable[Any, Other] | Callable[[Any], Other], + name: str | None = None, ) -> RunnableSerializable[Input, Other]: """Pipe runnables. @@ -706,7 +696,7 @@ class Runnable(ABC, Generic[Input, Output]): """ return RunnableSequence(self, *others, name=name) - def pick(self, keys: Union[str, list[str]]) -> RunnableSerializable[Any, Any]: + def pick(self, keys: str | list[str]) -> RunnableSerializable[Any, Any]: """Pick keys from the output dict of this ``Runnable``. Pick single key: @@ -771,14 +761,9 @@ class Runnable(ABC, Generic[Input, Output]): def assign( self, - **kwargs: Union[ - Runnable[dict[str, Any], Any], - Callable[[dict[str, Any]], Any], - Mapping[ - str, - Union[Runnable[dict[str, Any], Any], Callable[[dict[str, Any]], Any]], - ], - ], + **kwargs: Runnable[dict[str, Any], Any] + | Callable[[dict[str, Any]], Any] + | Mapping[str, Runnable[dict[str, Any], Any] | Callable[[dict[str, Any]], Any]], ) -> RunnableSerializable[Any, Any]: """Assigns new fields to the dict output of this ``Runnable``. @@ -827,7 +812,7 @@ class Runnable(ABC, Generic[Input, Output]): def invoke( self, input: Input, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Output: """Transform a single input into an output. @@ -847,7 +832,7 @@ class Runnable(ABC, Generic[Input, Output]): async def ainvoke( self, input: Input, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Output: """Transform a single input into an output. @@ -868,10 +853,10 @@ class Runnable(ABC, Generic[Input, Output]): def batch( self, inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> list[Output]: """Default implementation runs invoke in parallel using a thread pool executor. @@ -899,7 +884,7 @@ class Runnable(ABC, Generic[Input, Output]): configs = get_config_list(config, len(inputs)) - def invoke(input_: Input, config: RunnableConfig) -> Union[Output, Exception]: + def invoke(input_: Input, config: RunnableConfig) -> Output | Exception: if return_exceptions: try: return self.invoke(input_, config, **kwargs) @@ -919,7 +904,7 @@ class Runnable(ABC, Generic[Input, Output]): def batch_as_completed( self, inputs: Sequence[Input], - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + config: RunnableConfig | Sequence[RunnableConfig] | None = None, *, return_exceptions: Literal[False] = False, **kwargs: Any, @@ -929,20 +914,20 @@ class Runnable(ABC, Generic[Input, Output]): def batch_as_completed( self, inputs: Sequence[Input], - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + config: RunnableConfig | Sequence[RunnableConfig] | None = None, *, return_exceptions: Literal[True], **kwargs: Any, - ) -> Iterator[tuple[int, Union[Output, Exception]]]: ... + ) -> Iterator[tuple[int, Output | Exception]]: ... def batch_as_completed( self, inputs: Sequence[Input], - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + config: RunnableConfig | Sequence[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> Iterator[tuple[int, Union[Output, Exception]]]: + **kwargs: Any | None, + ) -> Iterator[tuple[int, Output | Exception]]: """Run ``invoke`` in parallel on a list of inputs. Yields results as they complete. @@ -969,12 +954,10 @@ class Runnable(ABC, Generic[Input, Output]): def invoke( i: int, input_: Input, config: RunnableConfig - ) -> tuple[int, Union[Output, Exception]]: + ) -> tuple[int, Output | Exception]: if return_exceptions: try: - out: Union[Output, Exception] = self.invoke( - input_, config, **kwargs - ) + out: Output | Exception = self.invoke(input_, config, **kwargs) except Exception as e: out = e else: @@ -989,7 +972,7 @@ class Runnable(ABC, Generic[Input, Output]): with get_executor_for_config(configs[0]) as executor: futures = { executor.submit(invoke, i, input_, config) - for i, (input_, config) in enumerate(zip(inputs, configs)) + for i, (input_, config) in enumerate(zip(inputs, configs, strict=False)) } try: @@ -1004,10 +987,10 @@ class Runnable(ABC, Generic[Input, Output]): async def abatch( self, inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> list[Output]: """Default implementation runs ``ainvoke`` in parallel using ``asyncio.gather``. @@ -1036,9 +1019,7 @@ class Runnable(ABC, Generic[Input, Output]): configs = get_config_list(config, len(inputs)) - async def ainvoke( - value: Input, config: RunnableConfig - ) -> Union[Output, Exception]: + async def ainvoke(value: Input, config: RunnableConfig) -> Output | Exception: if return_exceptions: try: return await self.ainvoke(value, config, **kwargs) @@ -1054,30 +1035,30 @@ class Runnable(ABC, Generic[Input, Output]): def abatch_as_completed( self, inputs: Sequence[Input], - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + config: RunnableConfig | Sequence[RunnableConfig] | None = None, *, return_exceptions: Literal[False] = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> AsyncIterator[tuple[int, Output]]: ... @overload def abatch_as_completed( self, inputs: Sequence[Input], - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + config: RunnableConfig | Sequence[RunnableConfig] | None = None, *, return_exceptions: Literal[True], - **kwargs: Optional[Any], - ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: ... + **kwargs: Any | None, + ) -> AsyncIterator[tuple[int, Output | Exception]]: ... async def abatch_as_completed( self, inputs: Sequence[Input], - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + config: RunnableConfig | Sequence[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: + **kwargs: Any | None, + ) -> AsyncIterator[tuple[int, Output | Exception]]: """Run ``ainvoke`` in parallel on a list of inputs. Yields results as they complete. @@ -1107,10 +1088,10 @@ class Runnable(ABC, Generic[Input, Output]): async def ainvoke_task( i: int, input_: Input, config: RunnableConfig - ) -> tuple[int, Union[Output, Exception]]: + ) -> tuple[int, Output | Exception]: if return_exceptions: try: - out: Union[Output, Exception] = await self.ainvoke( + out: Output | Exception = await self.ainvoke( input_, config, **kwargs ) except Exception as e: @@ -1123,7 +1104,7 @@ class Runnable(ABC, Generic[Input, Output]): gated_coro(semaphore, ainvoke_task(i, input_, config)) if semaphore else ainvoke_task(i, input_, config) - for i, (input_, config) in enumerate(zip(inputs, configs)) + for i, (input_, config) in enumerate(zip(inputs, configs, strict=False)) ] for coro in asyncio.as_completed(coros): @@ -1132,8 +1113,8 @@ class Runnable(ABC, Generic[Input, Output]): def stream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: """Default implementation of ``stream``, which calls ``invoke``. @@ -1153,8 +1134,8 @@ class Runnable(ABC, Generic[Input, Output]): async def astream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: """Default implementation of ``astream``, which calls ``ainvoke``. @@ -1175,16 +1156,16 @@ class Runnable(ABC, Generic[Input, Output]): def astream_log( self, input: Any, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, diff: Literal[True] = True, with_streamed_output_list: bool = True, - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, + include_names: Sequence[str] | None = None, + include_types: Sequence[str] | None = None, + include_tags: Sequence[str] | None = None, + exclude_names: Sequence[str] | None = None, + exclude_types: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, **kwargs: Any, ) -> AsyncIterator[RunLogPatch]: ... @@ -1192,34 +1173,34 @@ class Runnable(ABC, Generic[Input, Output]): def astream_log( self, input: Any, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, diff: Literal[False], with_streamed_output_list: bool = True, - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, + include_names: Sequence[str] | None = None, + include_types: Sequence[str] | None = None, + include_tags: Sequence[str] | None = None, + exclude_names: Sequence[str] | None = None, + exclude_types: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, **kwargs: Any, ) -> AsyncIterator[RunLog]: ... async def astream_log( self, input: Any, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, diff: bool = True, with_streamed_output_list: bool = True, - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, + include_names: Sequence[str] | None = None, + include_types: Sequence[str] | None = None, + include_tags: Sequence[str] | None = None, + exclude_names: Sequence[str] | None = None, + exclude_types: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, **kwargs: Any, - ) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]: + ) -> AsyncIterator[RunLogPatch] | AsyncIterator[RunLog]: """Stream all output from a ``Runnable``, as reported to the callback system. This includes all inner runs of LLMs, Retrievers, Tools, etc. @@ -1275,15 +1256,15 @@ class Runnable(ABC, Generic[Input, Output]): async def astream_events( self, input: Any, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, version: Literal["v1", "v2"] = "v2", - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, + include_names: Sequence[str] | None = None, + include_types: Sequence[str] | None = None, + include_tags: Sequence[str] | None = None, + exclude_names: Sequence[str] | None = None, + exclude_types: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, **kwargs: Any, ) -> AsyncIterator[StreamEvent]: """Generate a stream of events. @@ -1537,8 +1518,8 @@ class Runnable(ABC, Generic[Input, Output]): def transform( self, input: Iterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: """Transform inputs to outputs. @@ -1582,8 +1563,8 @@ class Runnable(ABC, Generic[Input, Output]): async def atransform( self, input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: """Transform inputs to outputs. @@ -1661,7 +1642,7 @@ class Runnable(ABC, Generic[Input, Output]): def with_config( self, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, # Sadly Unpack is not well-supported by mypy so this will have to be untyped **kwargs: Any, ) -> Runnable[Input, Output]: @@ -1687,15 +1668,15 @@ class Runnable(ABC, Generic[Input, Output]): def with_listeners( self, *, - on_start: Optional[ - Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - ] = None, - on_end: Optional[ - Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - ] = None, - on_error: Optional[ - Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - ] = None, + on_start: Callable[[Run], None] + | Callable[[Run, RunnableConfig], None] + | None = None, + on_end: Callable[[Run], None] + | Callable[[Run, RunnableConfig], None] + | None = None, + on_error: Callable[[Run], None] + | Callable[[Run, RunnableConfig], None] + | None = None, ) -> Runnable[Input, Output]: """Bind lifecycle listeners to a ``Runnable``, returning a new ``Runnable``. @@ -1759,9 +1740,9 @@ class Runnable(ABC, Generic[Input, Output]): def with_alisteners( self, *, - on_start: Optional[AsyncListener] = None, - on_end: Optional[AsyncListener] = None, - on_error: Optional[AsyncListener] = None, + on_start: AsyncListener | None = None, + on_end: AsyncListener | None = None, + on_error: AsyncListener | None = None, ) -> Runnable[Input, Output]: """Bind async lifecycle listeners to a ``Runnable``. @@ -1850,8 +1831,8 @@ class Runnable(ABC, Generic[Input, Output]): def with_types( self, *, - input_type: Optional[type[Input]] = None, - output_type: Optional[type[Output]] = None, + input_type: type[Input] | None = None, + output_type: type[Output] | None = None, ) -> Runnable[Input, Output]: """Bind input and output types to a ``Runnable``, returning a new ``Runnable``. @@ -1874,7 +1855,7 @@ class Runnable(ABC, Generic[Input, Output]): *, retry_if_exception_type: tuple[type[BaseException], ...] = (Exception,), wait_exponential_jitter: bool = True, - exponential_jitter_params: Optional[ExponentialJitterParams] = None, + exponential_jitter_params: ExponentialJitterParams | None = None, stop_after_attempt: int = 3, ) -> Runnable[Input, Output]: """Create a new Runnable that retries the original Runnable on exceptions. @@ -1962,7 +1943,7 @@ class Runnable(ABC, Generic[Input, Output]): fallbacks: Sequence[Runnable[Input, Output]], *, exceptions_to_handle: tuple[type[BaseException], ...] = (Exception,), - exception_key: Optional[str] = None, + exception_key: str | None = None, ) -> RunnableWithFallbacksT[Input, Output]: """Add fallbacks to a ``Runnable``, returning a new ``Runnable``. @@ -2037,16 +2018,14 @@ class Runnable(ABC, Generic[Input, Output]): def _call_with_config( self, - func: Union[ - Callable[[Input], Output], - Callable[[Input, CallbackManagerForChainRun], Output], - Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], - ], + func: Callable[[Input], Output] + | Callable[[Input, CallbackManagerForChainRun], Output] + | Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], input_: Input, - config: Optional[RunnableConfig], - run_type: Optional[str] = None, - serialized: Optional[dict[str, Any]] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None, + run_type: str | None = None, + serialized: dict[str, Any] | None = None, + **kwargs: Any | None, ) -> Output: """Call with config. @@ -2088,19 +2067,16 @@ class Runnable(ABC, Generic[Input, Output]): async def _acall_with_config( self, - func: Union[ - Callable[[Input], Awaitable[Output]], - Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], - Callable[ - [Input, AsyncCallbackManagerForChainRun, RunnableConfig], - Awaitable[Output], - ], + func: Callable[[Input], Awaitable[Output]] + | Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]] + | Callable[ + [Input, AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output] ], input_: Input, - config: Optional[RunnableConfig], - run_type: Optional[str] = None, - serialized: Optional[dict[str, Any]] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None, + run_type: str | None = None, + serialized: dict[str, Any] | None = None, + **kwargs: Any | None, ) -> Output: """Async call with config. @@ -2134,23 +2110,20 @@ class Runnable(ABC, Generic[Input, Output]): def _batch_with_config( self, - func: Union[ - Callable[[list[Input]], list[Union[Exception, Output]]], - Callable[ - [list[Input], list[CallbackManagerForChainRun]], - list[Union[Exception, Output]], - ], - Callable[ - [list[Input], list[CallbackManagerForChainRun], list[RunnableConfig]], - list[Union[Exception, Output]], - ], + func: Callable[[list[Input]], list[Exception | Output]] + | Callable[ + [list[Input], list[CallbackManagerForChainRun]], list[Exception | Output] + ] + | Callable[ + [list[Input], list[CallbackManagerForChainRun], list[RunnableConfig]], + list[Exception | Output], ], inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - run_type: Optional[str] = None, - **kwargs: Optional[Any], + run_type: str | None = None, + **kwargs: Any | None, ) -> list[Output]: """Transform a list of inputs to a list of outputs, with callbacks. @@ -2172,14 +2145,14 @@ class Runnable(ABC, Generic[Input, Output]): run_id=config.pop("run_id", None), ) for callback_manager, input_, config in zip( - callback_managers, inputs, configs + callback_managers, inputs, configs, strict=False ) ] try: if accepts_config(func): kwargs["config"] = [ patch_config(c, callbacks=rm.get_child()) - for c, rm in zip(configs, run_managers) + for c, rm in zip(configs, run_managers, strict=False) ] if accepts_run_manager(func): kwargs["run_manager"] = run_managers @@ -2191,8 +2164,8 @@ class Runnable(ABC, Generic[Input, Output]): return cast("list[Output]", [e for _ in inputs]) raise else: - first_exception: Optional[Exception] = None - for run_manager, out in zip(run_managers, output): + first_exception: Exception | None = None + for run_manager, out in zip(run_managers, output, strict=False): if isinstance(out, Exception): first_exception = first_exception or out run_manager.on_chain_error(out) @@ -2204,27 +2177,21 @@ class Runnable(ABC, Generic[Input, Output]): async def _abatch_with_config( self, - func: Union[ - Callable[[list[Input]], Awaitable[list[Union[Exception, Output]]]], - Callable[ - [list[Input], list[AsyncCallbackManagerForChainRun]], - Awaitable[list[Union[Exception, Output]]], - ], - Callable[ - [ - list[Input], - list[AsyncCallbackManagerForChainRun], - list[RunnableConfig], - ], - Awaitable[list[Union[Exception, Output]]], - ], + func: Callable[[list[Input]], Awaitable[list[Exception | Output]]] + | Callable[ + [list[Input], list[AsyncCallbackManagerForChainRun]], + Awaitable[list[Exception | Output]], + ] + | Callable[ + [list[Input], list[AsyncCallbackManagerForChainRun], list[RunnableConfig]], + Awaitable[list[Exception | Output]], ], inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - run_type: Optional[str] = None, - **kwargs: Optional[Any], + run_type: str | None = None, + **kwargs: Any | None, ) -> list[Output]: """Transform a list of inputs to a list of outputs, with callbacks. @@ -2249,7 +2216,7 @@ class Runnable(ABC, Generic[Input, Output]): run_id=config.pop("run_id", None), ) for callback_manager, input_, config in zip( - callback_managers, inputs, configs + callback_managers, inputs, configs, strict=False ) ) ) @@ -2257,7 +2224,7 @@ class Runnable(ABC, Generic[Input, Output]): if accepts_config(func): kwargs["config"] = [ patch_config(c, callbacks=rm.get_child()) - for c, rm in zip(configs, run_managers) + for c, rm in zip(configs, run_managers, strict=False) ] if accepts_run_manager(func): kwargs["run_manager"] = run_managers @@ -2270,9 +2237,9 @@ class Runnable(ABC, Generic[Input, Output]): return cast("list[Output]", [e for _ in inputs]) raise else: - first_exception: Optional[Exception] = None + first_exception: Exception | None = None coros: list[Awaitable[None]] = [] - for run_manager, out in zip(run_managers, output): + for run_manager, out in zip(run_managers, output, strict=False): if isinstance(out, Exception): first_exception = first_exception or out coros.append(run_manager.on_chain_error(out)) @@ -2286,21 +2253,15 @@ class Runnable(ABC, Generic[Input, Output]): def _transform_stream_with_config( self, inputs: Iterator[Input], - transformer: Union[ - Callable[[Iterator[Input]], Iterator[Output]], - Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]], - Callable[ - [ - Iterator[Input], - CallbackManagerForChainRun, - RunnableConfig, - ], - Iterator[Output], - ], + transformer: Callable[[Iterator[Input]], Iterator[Output]] + | Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]] + | Callable[ + [Iterator[Input], CallbackManagerForChainRun, RunnableConfig], + Iterator[Output], ], - config: Optional[RunnableConfig], - run_type: Optional[str] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None, + run_type: str | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: """Transform a stream with config. @@ -2313,9 +2274,9 @@ class Runnable(ABC, Generic[Input, Output]): # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = tee(inputs, 2) # Start the input iterator to ensure the input Runnable starts before this one - final_input: Optional[Input] = next(input_for_tracing, None) + final_input: Input | None = next(input_for_tracing, None) final_input_supported = True - final_output: Optional[Output] = None + final_output: Output | None = None final_output_supported = True config = ensure_config(config) @@ -2386,24 +2347,18 @@ class Runnable(ABC, Generic[Input, Output]): async def _atransform_stream_with_config( self, inputs: AsyncIterator[Input], - transformer: Union[ - Callable[[AsyncIterator[Input]], AsyncIterator[Output]], - Callable[ - [AsyncIterator[Input], AsyncCallbackManagerForChainRun], - AsyncIterator[Output], - ], - Callable[ - [ - AsyncIterator[Input], - AsyncCallbackManagerForChainRun, - RunnableConfig, - ], - AsyncIterator[Output], - ], + transformer: Callable[[AsyncIterator[Input]], AsyncIterator[Output]] + | Callable[ + [AsyncIterator[Input], AsyncCallbackManagerForChainRun], + AsyncIterator[Output], + ] + | Callable[ + [AsyncIterator[Input], AsyncCallbackManagerForChainRun, RunnableConfig], + AsyncIterator[Output], ], - config: Optional[RunnableConfig], - run_type: Optional[str] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None, + run_type: str | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: """Transform a stream with config. @@ -2416,9 +2371,9 @@ class Runnable(ABC, Generic[Input, Output]): # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = atee(inputs, 2) # Start the input iterator to ensure the input Runnable starts before this one - final_input: Optional[Input] = await py_anext(input_for_tracing, None) + final_input: Input | None = await py_anext(input_for_tracing, None) final_input_supported = True - final_output: Optional[Output] = None + final_output: Output | None = None final_output_supported = True config = ensure_config(config) @@ -2495,11 +2450,11 @@ class Runnable(ABC, Generic[Input, Output]): @beta_decorator.beta(message="This API is in beta and may change in the future.") def as_tool( self, - args_schema: Optional[type[BaseModel]] = None, + args_schema: type[BaseModel] | None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, - arg_types: Optional[dict[str, type]] = None, + name: str | None = None, + description: str | None = None, + arg_types: dict[str, type] | None = None, ) -> BaseTool: """Create a ``BaseTool`` from a ``Runnable``. @@ -2614,7 +2569,7 @@ class Runnable(ABC, Generic[Input, Output]): class RunnableSerializable(Serializable, Runnable[Input, Output]): """Runnable that can be serialized to JSON.""" - name: Optional[str] = None + name: str | None = None model_config = ConfigDict( # Suppress warnings from pydantic protected namespaces @@ -2623,7 +2578,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): ) @override - def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]: + def to_json(self) -> SerializedConstructor | SerializedNotImplemented: """Serialize the ``Runnable`` to JSON. Returns: @@ -2695,7 +2650,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): *, default_key: str = "default", prefix_keys: bool = False, - **kwargs: Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]], + **kwargs: Runnable[Input, Output] | Callable[[], Runnable[Input, Output]], ) -> RunnableSerializable[Input, Output]: """Configure alternatives for ``Runnables`` that can be set at runtime. @@ -2751,7 +2706,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): def _seq_input_schema( - steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig] + steps: list[Runnable[Any, Any]], config: RunnableConfig | None ) -> type[BaseModel]: # Import locally to prevent circular import from langchain_core.runnables.passthrough import ( # noqa: PLC0415 @@ -2781,7 +2736,7 @@ def _seq_input_schema( def _seq_output_schema( - steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig] + steps: list[Runnable[Any, Any]], config: RunnableConfig | None ) -> type[BaseModel]: # Import locally to prevent circular import from langchain_core.runnables.passthrough import ( # noqa: PLC0415 @@ -2928,10 +2883,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def __init__( self, *steps: RunnableLike, - name: Optional[str] = None, - first: Optional[Runnable[Any, Any]] = None, - middle: Optional[list[Runnable[Any, Any]]] = None, - last: Optional[Runnable[Any, Any]] = None, + name: str | None = None, + first: Runnable[Any, Any] | None = None, + middle: list[Runnable[Any, Any]] | None = None, + last: Runnable[Any, Any] | None = None, ) -> None: """Create a new ``RunnableSequence``. @@ -3005,9 +2960,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): return self.last.OutputType @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: """Get the input schema of the ``Runnable``. Args: @@ -3021,7 +2974,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): @override def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> type[BaseModel]: """Get the output schema of the ``Runnable``. @@ -3049,7 +3002,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): ) @override - def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: + def get_graph(self, config: RunnableConfig | None = None) -> Graph: """Get the graph representation of the ``Runnable``. Args: @@ -3092,13 +3045,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]): @override def __or__( self, - other: Union[ - Runnable[Any, Other], - Callable[[Iterator[Any]], Iterator[Other]], - Callable[[AsyncIterator[Any]], AsyncIterator[Other]], - Callable[[Any], Other], - Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], - ], + other: Runnable[Any, Other] + | Callable[[Iterator[Any]], Iterator[Other]] + | Callable[[AsyncIterator[Any]], AsyncIterator[Other]] + | Callable[[Any], Other] + | Mapping[str, Runnable[Any, Other] | Callable[[Any], Other] | Any], ) -> RunnableSerializable[Input, Other]: if isinstance(other, RunnableSequence): return RunnableSequence( @@ -3121,13 +3072,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]): @override def __ror__( self, - other: Union[ - Runnable[Other, Any], - Callable[[Iterator[Other]], Iterator[Any]], - Callable[[AsyncIterator[Other]], AsyncIterator[Any]], - Callable[[Other], Any], - Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], - ], + other: Runnable[Other, Any] + | Callable[[Iterator[Other]], Iterator[Any]] + | Callable[[AsyncIterator[Other]], AsyncIterator[Any]] + | Callable[[Other], Any] + | Mapping[str, Runnable[Other, Any] | Callable[[Other], Any] | Any], ) -> RunnableSerializable[Other, Output]: if isinstance(other, RunnableSequence): return RunnableSequence( @@ -3149,7 +3098,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): @override def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: # setup callbacks and context config = ensure_config(config) @@ -3187,8 +3136,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]): async def ainvoke( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Output: # setup callbacks and context config = ensure_config(config) @@ -3227,10 +3176,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def batch( self, inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> list[Output]: if not inputs: return [] @@ -3257,7 +3206,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]): name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) - for cm, input_, config in zip(callback_managers, inputs, configs) + for cm, input_, config in zip( + callback_managers, inputs, configs, strict=False + ) ] # invoke @@ -3277,7 +3228,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): inputs = step.batch( [ inp - for i, inp in zip(remaining_idxs, inputs) + for i, inp in zip(remaining_idxs, inputs, strict=False) if i not in failed_inputs_map ], [ @@ -3286,7 +3237,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]): config, callbacks=rm.get_child(f"seq:step:{stepidx + 1}"), ) - for i, (rm, config) in enumerate(zip(run_managers, configs)) + for i, (rm, config) in enumerate( + zip(run_managers, configs, strict=False) + ) if i not in failed_inputs_map ], return_exceptions=return_exceptions, @@ -3296,7 +3249,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): failed_inputs_map.update( { i: inp - for i, inp in zip(remaining_idxs, inputs) + for i, inp in zip(remaining_idxs, inputs, strict=False) if isinstance(inp, Exception) } ) @@ -3322,7 +3275,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): patch_config( config, callbacks=rm.get_child(f"seq:step:{i + 1}") ) - for rm, config in zip(run_managers, configs) + for rm, config in zip(run_managers, configs, strict=False) ], return_exceptions=return_exceptions, **(kwargs if i == 0 else {}), @@ -3336,8 +3289,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]): return cast("list[Output]", [e for _ in inputs]) raise else: - first_exception: Optional[Exception] = None - for run_manager, out in zip(run_managers, inputs): + first_exception: Exception | None = None + for run_manager, out in zip(run_managers, inputs, strict=False): if isinstance(out, Exception): first_exception = first_exception or out run_manager.on_chain_error(out) @@ -3351,10 +3304,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]): async def abatch( self, inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> list[Output]: if not inputs: return [] @@ -3382,7 +3335,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]): name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) - for cm, input_, config in zip(callback_managers, inputs, configs) + for cm, input_, config in zip( + callback_managers, inputs, configs, strict=False + ) ) ) @@ -3404,7 +3359,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): inputs = await step.abatch( [ inp - for i, inp in zip(remaining_idxs, inputs) + for i, inp in zip(remaining_idxs, inputs, strict=False) if i not in failed_inputs_map ], [ @@ -3413,7 +3368,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]): config, callbacks=rm.get_child(f"seq:step:{stepidx + 1}"), ) - for i, (rm, config) in enumerate(zip(run_managers, configs)) + for i, (rm, config) in enumerate( + zip(run_managers, configs, strict=False) + ) if i not in failed_inputs_map ], return_exceptions=return_exceptions, @@ -3423,7 +3380,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): failed_inputs_map.update( { i: inp - for i, inp in zip(remaining_idxs, inputs) + for i, inp in zip(remaining_idxs, inputs, strict=False) if isinstance(inp, Exception) } ) @@ -3449,7 +3406,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): patch_config( config, callbacks=rm.get_child(f"seq:step:{i + 1}") ) - for rm, config in zip(run_managers, configs) + for rm, config in zip(run_managers, configs, strict=False) ], return_exceptions=return_exceptions, **(kwargs if i == 0 else {}), @@ -3461,9 +3418,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]): return cast("list[Output]", [e for _ in inputs]) raise else: - first_exception: Optional[Exception] = None + first_exception: Exception | None = None coros: list[Awaitable[None]] = [] - for run_manager, out in zip(run_managers, inputs): + for run_manager, out in zip(run_managers, inputs, strict=False): if isinstance(out, Exception): first_exception = first_exception or out coros.append(run_manager.on_chain_error(out)) @@ -3526,8 +3483,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def transform( self, input: Iterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: yield from self._transform_stream_with_config( input, @@ -3540,8 +3497,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def stream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: yield from self.transform(iter([input]), config, **kwargs) @@ -3549,8 +3506,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]): async def atransform( self, input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: async for chunk in self._atransform_stream_with_config( input, @@ -3564,8 +3521,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]): async def astream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: async def input_aiter() -> AsyncIterator[Input]: yield input @@ -3662,21 +3619,16 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): def __init__( self, - steps__: Optional[ - Mapping[ - str, - Union[ - Runnable[Input, Any], - Callable[[Input], Any], - Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]], - ], - ] - ] = None, - **kwargs: Union[ - Runnable[Input, Any], - Callable[[Input], Any], - Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]], - ], + steps__: Mapping[ + str, + Runnable[Input, Any] + | Callable[[Input], Any] + | Mapping[str, Runnable[Input, Any] | Callable[[Input], Any]], + ] + | None = None, + **kwargs: Runnable[Input, Any] + | Callable[[Input], Any] + | Mapping[str, Runnable[Input, Any] | Callable[[Input], Any]], ) -> None: """Create a ``RunnableParallel``. @@ -3712,9 +3664,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): ) @override - def get_name( - self, suffix: Optional[str] = None, *, name: Optional[str] = None - ) -> str: + def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str: """Get the name of the ``Runnable``. Args: @@ -3739,9 +3689,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): return Any @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: """Get the input schema of the ``Runnable``. Args: @@ -3771,7 +3719,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): @override def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> type[BaseModel]: """Get the output schema of the ``Runnable``. @@ -3799,7 +3747,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): ) @override - def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: + def get_graph(self, config: RunnableConfig | None = None) -> Graph: """Get the graph representation of the ``Runnable``. Args: @@ -3847,7 +3795,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): @override def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> dict[str, Any]: # setup callbacks config = ensure_config(config) @@ -3893,7 +3841,10 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): executor.submit(_invoke_step, step, input, config, key) for key, step in steps.items() ] - output = {key: future.result() for key, future in zip(steps, futures)} + output = { + key: future.result() + for key, future in zip(steps, futures, strict=False) + } # finish the root run except BaseException as e: run_manager.on_chain_error(e) @@ -3906,8 +3857,8 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): async def ainvoke( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> dict[str, Any]: # setup callbacks config = ensure_config(config) @@ -3948,7 +3899,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): for key, step in steps.items() ) ) - output = dict(zip(steps, results)) + output = dict(zip(steps, results, strict=False)) # finish the root run except BaseException as e: await run_manager.on_chain_error(e) @@ -4008,7 +3959,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): def transform( self, input: Iterator[Input], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Iterator[dict[str, Any]]: yield from self._transform_stream_with_config( @@ -4019,8 +3970,8 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): def stream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[dict[str, Any]]: yield from self.transform(iter([input]), config) @@ -4050,7 +4001,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): ] # Wrap in a coroutine to satisfy linter - async def get_next_chunk(generator: AsyncIterator) -> Optional[Output]: + async def get_next_chunk(generator: AsyncIterator) -> Output | None: return await py_anext(generator) # Start the first iteration of each generator @@ -4079,7 +4030,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): async def atransform( self, input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: async for chunk in self._atransform_stream_with_config( @@ -4091,8 +4042,8 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): async def astream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[dict[str, Any]]: async def input_aiter() -> AsyncIterator[Input]: yield input @@ -4199,15 +4150,12 @@ class RunnableGenerator(Runnable[Input, Output]): def __init__( self, - transform: Union[ - Callable[[Iterator[Input]], Iterator[Output]], - Callable[[AsyncIterator[Input]], AsyncIterator[Output]], - ], - atransform: Optional[ - Callable[[AsyncIterator[Input]], AsyncIterator[Output]] - ] = None, + transform: Callable[[Iterator[Input]], Iterator[Output]] + | Callable[[AsyncIterator[Input]], AsyncIterator[Output]], + atransform: Callable[[AsyncIterator[Input]], AsyncIterator[Output]] + | None = None, *, - name: Optional[str] = None, + name: str | None = None, ) -> None: """Initialize a ``RunnableGenerator``. @@ -4256,9 +4204,7 @@ class RunnableGenerator(Runnable[Input, Output]): return Any @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: # Override the default implementation. # For a runnable generator, we need to bring to provide the # module of the underlying function when creating the model. @@ -4299,7 +4245,7 @@ class RunnableGenerator(Runnable[Input, Output]): @override def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> type[BaseModel]: # Override the default implementation. # For a runnable generator, we need to bring to provide the @@ -4344,7 +4290,7 @@ class RunnableGenerator(Runnable[Input, Output]): def transform( self, input: Iterator[Input], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Iterator[Output]: if not hasattr(self, "_transform"): @@ -4361,16 +4307,16 @@ class RunnableGenerator(Runnable[Input, Output]): def stream( self, input: Input, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Iterator[Output]: return self.transform(iter([input]), config, **kwargs) @override def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: - final: Optional[Output] = None + final: Output | None = None for output in self.stream(input, config, **kwargs): final = output if final is None else final + output # type: ignore[operator] return cast("Output", final) @@ -4379,7 +4325,7 @@ class RunnableGenerator(Runnable[Input, Output]): def atransform( self, input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> AsyncIterator[Output]: if not hasattr(self, "_atransform"): @@ -4394,7 +4340,7 @@ class RunnableGenerator(Runnable[Input, Output]): def astream( self, input: Input, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> AsyncIterator[Output]: async def input_aiter() -> AsyncIterator[Input]: @@ -4404,9 +4350,9 @@ class RunnableGenerator(Runnable[Input, Output]): @override async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: - final: Optional[Output] = None + final: Output | None = None async for output in self.astream(input, config, **kwargs): final = output if final is None else final + output # type: ignore[operator] return cast("Output", final) @@ -4462,39 +4408,28 @@ class RunnableLambda(Runnable[Input, Output]): def __init__( self, - func: Union[ - Union[ - Callable[[Input], Iterator[Output]], - Callable[[Input], Runnable[Input, Output]], - Callable[[Input], Output], - Callable[[Input, RunnableConfig], Output], - Callable[[Input, CallbackManagerForChainRun], Output], - Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], - ], - Union[ - Callable[[Input], Awaitable[Output]], - Callable[[Input], AsyncIterator[Output]], - Callable[[Input, RunnableConfig], Awaitable[Output]], - Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], - Callable[ - [Input, AsyncCallbackManagerForChainRun, RunnableConfig], - Awaitable[Output], - ], - ], + func: Callable[[Input], Iterator[Output]] + | Callable[[Input], Runnable[Input, Output]] + | Callable[[Input], Output] + | Callable[[Input, RunnableConfig], Output] + | Callable[[Input, CallbackManagerForChainRun], Output] + | Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output] + | Callable[[Input], Awaitable[Output]] + | Callable[[Input], AsyncIterator[Output]] + | Callable[[Input, RunnableConfig], Awaitable[Output]] + | Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]] + | Callable[ + [Input, AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output] ], - afunc: Optional[ - Union[ - Callable[[Input], Awaitable[Output]], - Callable[[Input], AsyncIterator[Output]], - Callable[[Input, RunnableConfig], Awaitable[Output]], - Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], - Callable[ - [Input, AsyncCallbackManagerForChainRun, RunnableConfig], - Awaitable[Output], - ], - ] - ] = None, - name: Optional[str] = None, + afunc: Callable[[Input], Awaitable[Output]] + | Callable[[Input], AsyncIterator[Output]] + | Callable[[Input, RunnableConfig], Awaitable[Output]] + | Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]] + | Callable[ + [Input, AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output] + ] + | None = None, + name: str | None = None, ) -> None: """Create a ``RunnableLambda`` from a callable, and async callable or both. @@ -4544,7 +4479,7 @@ class RunnableLambda(Runnable[Input, Output]): except AttributeError: pass - self._repr: Optional[str] = None + self._repr: str | None = None @property @override @@ -4561,9 +4496,7 @@ class RunnableLambda(Runnable[Input, Output]): return Any @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: """The pydantic schema for the input to this ``Runnable``. Args: @@ -4632,7 +4565,7 @@ class RunnableLambda(Runnable[Input, Output]): @override def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> type[BaseModel]: # Override the default implementation. # For a runnable lambda, we need to bring to provide the @@ -4753,7 +4686,7 @@ class RunnableLambda(Runnable[Input, Output]): **kwargs: Any, ) -> Output: if inspect.isgeneratorfunction(self.func): - output: Optional[Output] = None + output: Output | None = None for chunk in call_func_with_variable_args( cast("Callable[[Input], Iterator[Output]]", self.func), input_, @@ -4808,7 +4741,7 @@ class RunnableLambda(Runnable[Input, Output]): config: RunnableConfig, **kwargs: Any, ) -> Output: - output: Optional[Output] = None + output: Output | None = None for chunk in call_func_with_variable_args( cast("Callable[[Input], Iterator[Output]]", self.func), value, @@ -4844,7 +4777,7 @@ class RunnableLambda(Runnable[Input, Output]): afunc = f if is_async_generator(afunc): - output: Optional[Output] = None + output: Output | None = None async with aclosing( cast( "AsyncGenerator[Any, Any]", @@ -4894,8 +4827,8 @@ class RunnableLambda(Runnable[Input, Output]): def invoke( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Output: """Invoke this ``Runnable`` synchronously. @@ -4925,8 +4858,8 @@ class RunnableLambda(Runnable[Input, Output]): async def ainvoke( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Output: """Invoke this ``Runnable`` asynchronously. @@ -4970,7 +4903,7 @@ class RunnableLambda(Runnable[Input, Output]): final = ichunk if inspect.isgeneratorfunction(self.func): - output: Optional[Output] = None + output: Output | None = None for chunk in call_func_with_variable_args( self.func, final, config, run_manager, **kwargs ): @@ -5012,8 +4945,8 @@ class RunnableLambda(Runnable[Input, Output]): def transform( self, input: Iterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: if hasattr(self, "func"): yield from self._transform_stream_with_config( @@ -5033,8 +4966,8 @@ class RunnableLambda(Runnable[Input, Output]): def stream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: return self.transform(iter([input]), config, **kwargs) @@ -5088,7 +5021,7 @@ class RunnableLambda(Runnable[Input, Output]): afunc = f if is_async_generator(afunc): - output: Optional[Output] = None + output: Output | None = None async for chunk in cast( "AsyncIterator[Output]", acall_func_with_variable_args( @@ -5141,8 +5074,8 @@ class RunnableLambda(Runnable[Input, Output]): async def atransform( self, input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: async for output in self._atransform_stream_with_config( input, @@ -5156,8 +5089,8 @@ class RunnableLambda(Runnable[Input, Output]): async def astream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: async def input_aiter() -> AsyncIterator[Input]: yield input @@ -5190,9 +5123,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]): return list[self.bound.InputType] # type: ignore[name-defined] @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: return create_model_v2( self.get_name("Input"), root=( @@ -5216,7 +5147,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]): @override def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> type[BaseModel]: schema = self.bound.get_output_schema(config) return create_model_v2( @@ -5238,7 +5169,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]): return self.bound.config_specs @override - def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: + def get_graph(self, config: RunnableConfig | None = None) -> Graph: return self.bound.get_graph(config) @classmethod @@ -5271,7 +5202,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]): @override def invoke( - self, input: list[Input], config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: list[Input], config: RunnableConfig | None = None, **kwargs: Any ) -> list[Output]: return self._call_with_config(self._invoke, input, config, **kwargs) @@ -5289,7 +5220,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]): @override async def ainvoke( - self, input: list[Input], config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: list[Input], config: RunnableConfig | None = None, **kwargs: Any ) -> list[Output]: return await self._acall_with_config(self._ainvoke, input, config, **kwargs) @@ -5297,8 +5228,8 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]): async def astream_events( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[StreamEvent]: def _error_stream_event(message: str) -> StreamEvent: raise NotImplementedError(message) @@ -5340,9 +5271,7 @@ class RunnableEach(RunnableEachBase[Input, Output]): """ @override - def get_name( - self, suffix: Optional[str] = None, *, name: Optional[str] = None - ) -> str: + def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str: name = name or self.name or f"RunnableEach<{self.bound.get_name()}>" return super().get_name(suffix, name=name) @@ -5352,7 +5281,7 @@ class RunnableEach(RunnableEachBase[Input, Output]): @override def with_config( - self, config: Optional[RunnableConfig] = None, **kwargs: Any + self, config: RunnableConfig | None = None, **kwargs: Any ) -> RunnableEach[Input, Output]: return RunnableEach(bound=self.bound.with_config(config, **kwargs)) @@ -5360,15 +5289,15 @@ class RunnableEach(RunnableEachBase[Input, Output]): def with_listeners( self, *, - on_start: Optional[ - Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - ] = None, - on_end: Optional[ - Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - ] = None, - on_error: Optional[ - Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - ] = None, + on_start: Callable[[Run], None] + | Callable[[Run, RunnableConfig], None] + | None = None, + on_end: Callable[[Run], None] + | Callable[[Run, RunnableConfig], None] + | None = None, + on_error: Callable[[Run], None] + | Callable[[Run, RunnableConfig], None] + | None = None, ) -> RunnableEach[Input, Output]: """Bind lifecycle listeners to a ``Runnable``, returning a new ``Runnable``. @@ -5397,9 +5326,9 @@ class RunnableEach(RunnableEachBase[Input, Output]): def with_alisteners( self, *, - on_start: Optional[AsyncListener] = None, - on_end: Optional[AsyncListener] = None, - on_error: Optional[AsyncListener] = None, + on_start: AsyncListener | None = None, + on_end: AsyncListener | None = None, + on_error: AsyncListener | None = None, ) -> RunnableEach[Input, Output]: """Bind async lifecycle listeners to a ``Runnable``. @@ -5459,13 +5388,13 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ """The config factories to bind to the underlying ``Runnable``.""" # Union[Type[Input], BaseModel] + things like list[str] - custom_input_type: Optional[Any] = None + custom_input_type: Any | None = None """Override the input type of the underlying ``Runnable`` with a custom type. The type can be a pydantic model, or a type annotation (e.g., ``list[str]``). """ # Union[Type[Output], BaseModel] + things like list[str] - custom_output_type: Optional[Any] = None + custom_output_type: Any | None = None """Override the output type of the underlying ``Runnable`` with a custom type. The type can be a pydantic model, or a type annotation (e.g., ``list[str]``). @@ -5479,13 +5408,12 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ self, *, bound: Runnable[Input, Output], - kwargs: Optional[Mapping[str, Any]] = None, - config: Optional[RunnableConfig] = None, - config_factories: Optional[ - list[Callable[[RunnableConfig], RunnableConfig]] - ] = None, - custom_input_type: Optional[Union[type[Input], BaseModel]] = None, - custom_output_type: Optional[Union[type[Output], BaseModel]] = None, + kwargs: Mapping[str, Any] | None = None, + config: RunnableConfig | None = None, + config_factories: list[Callable[[RunnableConfig], RunnableConfig]] + | None = None, + custom_input_type: type[Input] | BaseModel | None = None, + custom_output_type: type[Output] | BaseModel | None = None, **other_kwargs: Any, ) -> None: """Create a ``RunnableBinding`` from a ``Runnable`` and kwargs. @@ -5523,9 +5451,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ self.config = config or {} @override - def get_name( - self, suffix: Optional[str] = None, *, name: Optional[str] = None - ) -> str: + def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str: return self.bound.get_name(suffix, name=name) @property @@ -5547,16 +5473,14 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ ) @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: if self.custom_input_type is not None: return super().get_input_schema(config) return self.bound.get_input_schema(merge_configs(self.config, config)) @override def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> type[BaseModel]: if self.custom_output_type is not None: return super().get_output_schema(config) @@ -5568,7 +5492,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ return self.bound.config_specs @override - def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: + def get_graph(self, config: RunnableConfig | None = None) -> Graph: return self.bound.get_graph(self._merge_configs(config)) @classmethod @@ -5587,7 +5511,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ """ return ["langchain", "schema", "runnable"] - def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: + def _merge_configs(self, *configs: RunnableConfig | None) -> RunnableConfig: config = merge_configs(self.config, *configs) return merge_configs(config, *(f(config) for f in self.config_factories)) @@ -5595,8 +5519,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ def invoke( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Output: return self.bound.invoke( input, @@ -5608,8 +5532,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ async def ainvoke( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Output: return await self.bound.ainvoke( input, @@ -5621,10 +5545,10 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ def batch( self, inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> list[Output]: if isinstance(config, list): configs = cast( @@ -5644,10 +5568,10 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ async def abatch( self, inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> list[Output]: if isinstance(config, list): configs = cast( @@ -5667,7 +5591,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ def batch_as_completed( self, inputs: Sequence[Input], - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + config: RunnableConfig | Sequence[RunnableConfig] | None = None, *, return_exceptions: Literal[False] = False, **kwargs: Any, @@ -5677,21 +5601,21 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ def batch_as_completed( self, inputs: Sequence[Input], - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + config: RunnableConfig | Sequence[RunnableConfig] | None = None, *, return_exceptions: Literal[True], **kwargs: Any, - ) -> Iterator[tuple[int, Union[Output, Exception]]]: ... + ) -> Iterator[tuple[int, Output | Exception]]: ... @override def batch_as_completed( self, inputs: Sequence[Input], - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + config: RunnableConfig | Sequence[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> Iterator[tuple[int, Union[Output, Exception]]]: + **kwargs: Any | None, + ) -> Iterator[tuple[int, Output | Exception]]: if isinstance(config, Sequence): configs = cast( "list[RunnableConfig]", @@ -5719,31 +5643,31 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ def abatch_as_completed( self, inputs: Sequence[Input], - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + config: RunnableConfig | Sequence[RunnableConfig] | None = None, *, return_exceptions: Literal[False] = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> AsyncIterator[tuple[int, Output]]: ... @overload def abatch_as_completed( self, inputs: Sequence[Input], - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + config: RunnableConfig | Sequence[RunnableConfig] | None = None, *, return_exceptions: Literal[True], - **kwargs: Optional[Any], - ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: ... + **kwargs: Any | None, + ) -> AsyncIterator[tuple[int, Output | Exception]]: ... @override async def abatch_as_completed( self, inputs: Sequence[Input], - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + config: RunnableConfig | Sequence[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: + **kwargs: Any | None, + ) -> AsyncIterator[tuple[int, Output | Exception]]: if isinstance(config, Sequence): configs = cast( "list[RunnableConfig]", @@ -5772,8 +5696,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ def stream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: yield from self.bound.stream( input, @@ -5785,8 +5709,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ async def astream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: async for item in self.bound.astream( input, @@ -5799,8 +5723,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ async def astream_events( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[StreamEvent]: async for item in self.bound.astream_events( input, self._merge_configs(config), **{**self.kwargs, **kwargs} @@ -5811,7 +5735,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ def transform( self, input: Iterator[Input], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Iterator[Output]: yield from self.bound.transform( @@ -5824,7 +5748,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[ async def atransform( self, input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> AsyncIterator[Output]: async for item in self.bound.atransform( @@ -5911,7 +5835,7 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): # type: ignore[no-re @override def with_config( self, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, # Sadly Unpack is not well supported by mypy so this will have to be untyped **kwargs: Any, ) -> Runnable[Input, Output]: @@ -5928,15 +5852,15 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): # type: ignore[no-re def with_listeners( self, *, - on_start: Optional[ - Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - ] = None, - on_end: Optional[ - Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - ] = None, - on_error: Optional[ - Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - ] = None, + on_start: Callable[[Run], None] + | Callable[[Run, RunnableConfig], None] + | None = None, + on_end: Callable[[Run], None] + | Callable[[Run, RunnableConfig], None] + | None = None, + on_error: Callable[[Run], None] + | Callable[[Run, RunnableConfig], None] + | None = None, ) -> Runnable[Input, Output]: """Bind lifecycle listeners to a ``Runnable``, returning a new ``Runnable``. @@ -5980,8 +5904,8 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): # type: ignore[no-re @override def with_types( self, - input_type: Optional[Union[type[Input], BaseModel]] = None, - output_type: Optional[Union[type[Output], BaseModel]] = None, + input_type: type[Input] | BaseModel | None = None, + output_type: type[Output] | BaseModel | None = None, ) -> Runnable[Input, Output]: return self.__class__( bound=self.bound, @@ -6065,18 +5989,18 @@ class _RunnableCallableAsyncIterator(Protocol[Input, Output]): ) -> AsyncIterator[Output]: ... -RunnableLike = Union[ - Runnable[Input, Output], - Callable[[Input], Output], - Callable[[Input], Awaitable[Output]], - Callable[[Iterator[Input]], Iterator[Output]], - Callable[[AsyncIterator[Input]], AsyncIterator[Output]], - _RunnableCallableSync[Input, Output], - _RunnableCallableAsync[Input, Output], - _RunnableCallableIterator[Input, Output], - _RunnableCallableAsyncIterator[Input, Output], - Mapping[str, Any], -] +RunnableLike = ( + Runnable[Input, Output] + | Callable[[Input], Output] + | Callable[[Input], Awaitable[Output]] + | Callable[[Iterator[Input]], Iterator[Output]] + | Callable[[AsyncIterator[Input]], AsyncIterator[Output]] + | _RunnableCallableSync[Input, Output] + | _RunnableCallableAsync[Input, Output] + | _RunnableCallableIterator[Input, Output] + | _RunnableCallableAsyncIterator[Input, Output] + | Mapping[str, Any] +) def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: @@ -6131,12 +6055,10 @@ def chain( def chain( - func: Union[ - Callable[[Input], Output], - Callable[[Input], Iterator[Output]], - Callable[[Input], Coroutine[Any, Any, Output]], - Callable[[Input], AsyncIterator[Output]], - ], + func: Callable[[Input], Output] + | Callable[[Input], Iterator[Output]] + | Callable[[Input], Coroutine[Any, Any, Output]] + | Callable[[Input], AsyncIterator[Output]], ) -> Runnable[Input, Output]: """Decorate a function to make it a ``Runnable``. diff --git a/libs/core/langchain_core/runnables/branch.py b/libs/core/langchain_core/runnables/branch.py index 52735bef8d0..1adb42a8d48 100644 --- a/libs/core/langchain_core/runnables/branch.py +++ b/libs/core/langchain_core/runnables/branch.py @@ -1,11 +1,15 @@ """Runnable that selects which branch to run based on a condition.""" -from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence +from collections.abc import ( + AsyncIterator, + Awaitable, + Callable, + Iterator, + Mapping, + Sequence, +) from typing import ( Any, - Callable, - Optional, - Union, cast, ) @@ -67,17 +71,13 @@ class RunnableBranch(RunnableSerializable[Input, Output]): def __init__( self, - *branches: Union[ - tuple[ - Union[ - Runnable[Input, bool], - Callable[[Input], bool], - Callable[[Input], Awaitable[bool]], - ], - RunnableLike, - ], - RunnableLike, # To accommodate the default branch - ], + *branches: tuple[ + Runnable[Input, bool] + | Callable[[Input], bool] + | Callable[[Input], Awaitable[bool]], + RunnableLike, + ] + | RunnableLike, ) -> None: """A Runnable that runs one of two branches based on a condition. @@ -154,9 +154,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): return ["langchain", "schema", "runnable"] @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: runnables = ( [self.default] + [r for _, r in self.branches] @@ -187,7 +185,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): @override def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: """First evaluates the condition, then delegate to true or false branch. @@ -246,7 +244,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): @override async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) @@ -296,8 +294,8 @@ class RunnableBranch(RunnableSerializable[Input, Output]): def stream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: """First evaluates the condition, then delegate to true or false branch. @@ -317,7 +315,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) - final_output: Optional[Output] = None + final_output: Output | None = None final_output_supported = True try: @@ -380,8 +378,8 @@ class RunnableBranch(RunnableSerializable[Input, Output]): async def astream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: """First evaluates the condition, then delegate to true or false branch. @@ -401,7 +399,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) - final_output: Optional[Output] = None + final_output: Output | None = None final_output_supported = True try: diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 02707d62e30..a8ffcd7bf12 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -5,7 +5,7 @@ from __future__ import annotations import asyncio import uuid import warnings -from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence +from collections.abc import Awaitable, Callable, Generator, Iterable, Iterator, Sequence from concurrent.futures import Executor, Future, ThreadPoolExecutor from contextlib import contextmanager from contextvars import Context, ContextVar, Token, copy_context @@ -13,11 +13,8 @@ from functools import partial from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, ParamSpec, TypeVar, - Union, cast, ) @@ -42,7 +39,7 @@ if TYPE_CHECKING: else: # Pydantic validates through typed dicts, but # the callbacks need forward refs updated - Callbacks = Optional[Union[list, Any]] + Callbacks = list | Any | None class EmptyDict(TypedDict, total=False): @@ -75,7 +72,7 @@ class RunnableConfig(TypedDict, total=False): Name for the tracer run for this call. Defaults to the name of the class. """ - max_concurrency: Optional[int] + max_concurrency: int | None """ Maximum number of parallel calls to make. If not provided, defaults to ThreadPoolExecutor's default. @@ -94,7 +91,7 @@ class RunnableConfig(TypedDict, total=False): configurable. """ - run_id: Optional[uuid.UUID] + run_id: uuid.UUID | None """ Unique identifier for the tracer run for this call. If not provided, a new UUID will be generated. @@ -130,7 +127,7 @@ var_child_runnable_config: ContextVar[RunnableConfig | None] = ContextVar( # This is imported and used in langgraph, so don't break. def _set_config_context( config: RunnableConfig, -) -> tuple[Token[Optional[RunnableConfig]], Optional[dict[str, Any]]]: +) -> tuple[Token[RunnableConfig | None], dict[str, Any] | None]: """Set the child Runnable config + tracing context. Args: @@ -192,7 +189,7 @@ def set_config_context(config: RunnableConfig) -> Generator[Context, None, None] ) -def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: +def ensure_config(config: RunnableConfig | None = None) -> RunnableConfig: """Ensure that a config is a dict with all keys present. Args: @@ -247,7 +244,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: def get_config_list( - config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]], length: int + config: RunnableConfig | Sequence[RunnableConfig] | None, length: int ) -> list[RunnableConfig]: """Get a list of configs from a single config or a list of configs. @@ -294,13 +291,13 @@ def get_config_list( def patch_config( - config: Optional[RunnableConfig], + config: RunnableConfig | None, *, - callbacks: Optional[BaseCallbackManager] = None, - recursion_limit: Optional[int] = None, - max_concurrency: Optional[int] = None, - run_name: Optional[str] = None, - configurable: Optional[dict[str, Any]] = None, + callbacks: BaseCallbackManager | None = None, + recursion_limit: int | None = None, + max_concurrency: int | None = None, + run_name: str | None = None, + configurable: dict[str, Any] | None = None, ) -> RunnableConfig: """Patch a config with new values. @@ -339,7 +336,7 @@ def patch_config( return config -def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: +def merge_configs(*configs: RunnableConfig | None) -> RunnableConfig: """Merge multiple configs into one. Args: @@ -406,15 +403,13 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: def call_func_with_variable_args( - func: Union[ - Callable[[Input], Output], - Callable[[Input, RunnableConfig], Output], - Callable[[Input, CallbackManagerForChainRun], Output], - Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], - ], + func: Callable[[Input], Output] + | Callable[[Input, RunnableConfig], Output] + | Callable[[Input, CallbackManagerForChainRun], Output] + | Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], input: Input, config: RunnableConfig, - run_manager: Optional[CallbackManagerForChainRun] = None, + run_manager: CallbackManagerForChainRun | None = None, **kwargs: Any, ) -> Output: """Call function that may optionally accept a run_manager and/or config. @@ -440,18 +435,15 @@ def call_func_with_variable_args( def acall_func_with_variable_args( - func: Union[ - Callable[[Input], Awaitable[Output]], - Callable[[Input, RunnableConfig], Awaitable[Output]], - Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], - Callable[ - [Input, AsyncCallbackManagerForChainRun, RunnableConfig], - Awaitable[Output], - ], + func: Callable[[Input], Awaitable[Output]] + | Callable[[Input, RunnableConfig], Awaitable[Output]] + | Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]] + | Callable[ + [Input, AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output] ], input: Input, config: RunnableConfig, - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + run_manager: AsyncCallbackManagerForChainRun | None = None, **kwargs: Any, ) -> Awaitable[Output]: """Async call function that may optionally accept a run_manager and/or config. @@ -571,7 +563,7 @@ class ContextThreadPoolExecutor(ThreadPoolExecutor): @contextmanager def get_executor_for_config( - config: Optional[RunnableConfig], + config: RunnableConfig | None, ) -> Generator[Executor, None, None]: """Get an executor for a config. @@ -589,7 +581,7 @@ def get_executor_for_config( async def run_in_executor( - executor_or_config: Optional[Union[Executor, RunnableConfig]], + executor_or_config: Executor | RunnableConfig | None, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs, diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index 0485963e0e3..6f6cab8fa12 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -7,6 +7,7 @@ import threading from abc import abstractmethod from collections.abc import ( AsyncIterator, + Callable, Iterator, Sequence, ) @@ -14,9 +15,6 @@ from functools import wraps from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, - Union, cast, ) from weakref import WeakValueDictionary @@ -58,7 +56,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): default: RunnableSerializable[Input, Output] """The default Runnable to use.""" - config: Optional[RunnableConfig] = None + config: RunnableConfig | None = None """The configuration to use.""" model_config = ConfigDict( @@ -92,28 +90,26 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): return self.default.OutputType @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: runnable, config = self.prepare(config) return runnable.get_input_schema(config) @override def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> type[BaseModel]: runnable, config = self.prepare(config) return runnable.get_output_schema(config) @override - def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: + def get_graph(self, config: RunnableConfig | None = None) -> Graph: runnable, config = self.prepare(config) return runnable.get_graph(config) @override def with_config( self, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, # Sadly Unpack is not well supported by mypy so this will have to be untyped **kwargs: Any, ) -> Runnable[Input, Output]: @@ -122,7 +118,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): ) def prepare( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> tuple[Runnable[Input, Output], RunnableConfig]: """Prepare the Runnable for invocation. @@ -140,19 +136,19 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): @abstractmethod def _prepare( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> tuple[Runnable[Input, Output], RunnableConfig]: ... @override def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: runnable, config = self.prepare(config) return runnable.invoke(input, config, **kwargs) @override async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: runnable, config = self.prepare(config) return await runnable.ainvoke(input, config, **kwargs) @@ -161,10 +157,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): def batch( self, inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> list[Output]: configs = get_config_list(config, len(inputs)) prepared = [self.prepare(c) for c in configs] @@ -183,7 +179,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): def invoke( prepared: tuple[Runnable[Input, Output], RunnableConfig], input_: Input, - ) -> Union[Output, Exception]: + ) -> Output | Exception: bound, config = prepared if return_exceptions: try: @@ -204,10 +200,10 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): async def abatch( self, inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> list[Output]: configs = get_config_list(config, len(inputs)) prepared = [self.prepare(c) for c in configs] @@ -226,7 +222,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): async def ainvoke( prepared: tuple[Runnable[Input, Output], RunnableConfig], input_: Input, - ) -> Union[Output, Exception]: + ) -> Output | Exception: bound, config = prepared if return_exceptions: try: @@ -243,8 +239,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): def stream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: runnable, config = self.prepare(config) return runnable.stream(input, config, **kwargs) @@ -253,8 +249,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): async def astream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: runnable, config = self.prepare(config) async for chunk in runnable.astream(input, config, **kwargs): @@ -264,8 +260,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): def transform( self, input: Iterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: runnable, config = self.prepare(config) return runnable.transform(input, config, **kwargs) @@ -274,8 +270,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): async def atransform( self, input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: runnable, config = self.prepare(config) async for chunk in runnable.atransform(input, config, **kwargs): @@ -423,7 +419,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]): return self.default.configurable_fields(**{**self.fields, **kwargs}) def _prepare( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> tuple[Runnable[Input, Output], RunnableConfig]: config = ensure_config(config) specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()} @@ -470,9 +466,7 @@ class StrEnum(str, enum.Enum): _enums_for_spec: WeakValueDictionary[ - Union[ - ConfigurableFieldSingleOption, ConfigurableFieldMultiOption, ConfigurableField - ], + ConfigurableFieldSingleOption | ConfigurableFieldMultiOption | ConfigurableField, type[StrEnum], ] = WeakValueDictionary() @@ -542,7 +536,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): alternatives: dict[ str, - Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]], + Runnable[Input, Output] | Callable[[], Runnable[Input, Output]], ] """The alternatives to choose from.""" @@ -616,7 +610,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): ) def _prepare( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> tuple[Runnable[Input, Output], RunnableConfig]: config = ensure_config(config) which = config.get("configurable", {}).get(self.which.id, self.default_key) @@ -679,8 +673,8 @@ def prefix_config_spec( def make_options_spec( - spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption], - description: Optional[str], + spec: ConfigurableFieldSingleOption | ConfigurableFieldMultiOption, + description: str | None, ) -> ConfigurableFieldSpec: """Make options spec. diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index 7adfdf954a3..7bb97745447 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -5,7 +5,7 @@ import inspect import typing from collections.abc import AsyncIterator, Iterator, Sequence from functools import wraps -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, cast from pydantic import BaseModel, ConfigDict from typing_extensions import override @@ -95,7 +95,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): Any exception that is not a subclass of these exceptions will be raised immediately. """ - exception_key: Optional[str] = None + exception_key: str | None = None """If string is specified then handled exceptions will be passed to fallbacks as part of the input under the specified key. If None, exceptions will not be passed to fallbacks. If used, the base Runnable and its fallbacks @@ -116,14 +116,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): return self.runnable.OutputType @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: return self.runnable.get_input_schema(config) @override def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> type[BaseModel]: return self.runnable.get_output_schema(config) @@ -164,7 +162,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): @override def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: if self.exception_key is not None and not isinstance(input, dict): msg = ( @@ -216,8 +214,8 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): async def ainvoke( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Output: if self.exception_key is not None and not isinstance(input, dict): msg = ( @@ -266,10 +264,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): def batch( self, inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> list[Output]: if self.exception_key is not None and not all( isinstance(input_, dict) for input_ in inputs @@ -305,7 +303,9 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) - for cm, input_, config in zip(callback_managers, inputs, configs) + for cm, input_, config in zip( + callback_managers, inputs, configs, strict=False + ) ] to_return: dict[int, Any] = {} @@ -323,7 +323,9 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): return_exceptions=True, **kwargs, ) - for (i, input_), output in zip(sorted(run_again.copy().items()), outputs): + for (i, input_), output in zip( + sorted(run_again.copy().items()), outputs, strict=False + ): if isinstance(output, BaseException) and not isinstance( output, self.exceptions_to_handle ): @@ -358,10 +360,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): async def abatch( self, inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> list[Output]: if self.exception_key is not None and not all( isinstance(input_, dict) for input_ in inputs @@ -398,11 +400,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): name=config.get("run_name") or self.get_name(), run_id=config.pop("run_id", None), ) - for cm, input_, config in zip(callback_managers, inputs, configs) + for cm, input_, config in zip( + callback_managers, inputs, configs, strict=False + ) ) ) - to_return: dict[int, Union[Output, BaseException]] = {} + to_return: dict[int, Output | BaseException] = {} run_again = dict(enumerate(inputs)) handled_exceptions: dict[int, BaseException] = {} first_to_raise = None @@ -418,7 +422,9 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): **kwargs, ) - for (i, input_), output in zip(sorted(run_again.copy().items()), outputs): + for (i, input_), output in zip( + sorted(run_again.copy().items()), outputs, strict=False + ): if isinstance(output, BaseException) and not isinstance( output, self.exceptions_to_handle ): @@ -458,8 +464,8 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): def stream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: if self.exception_key is not None and not isinstance(input, dict): msg = ( @@ -505,7 +511,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): raise first_error yield chunk - output: Optional[Output] = chunk + output: Output | None = chunk try: for chunk in stream: yield chunk @@ -522,8 +528,8 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): async def astream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: if self.exception_key is not None and not isinstance(input, dict): msg = ( @@ -569,7 +575,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): raise first_error yield chunk - output: Optional[Output] = chunk + output: Output | None = chunk try: async for chunk in stream: yield chunk diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index db49a713238..50e2535d83d 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -4,17 +4,15 @@ from __future__ import annotations import inspect from collections import defaultdict +from collections.abc import Callable from dataclasses import dataclass, field from enum import Enum from typing import ( TYPE_CHECKING, Any, - Callable, NamedTuple, - Optional, Protocol, TypedDict, - Union, overload, ) from uuid import UUID, uuid4 @@ -70,14 +68,12 @@ class Edge(NamedTuple): """The source node id.""" target: str """The target node id.""" - data: Optional[Stringifiable] = None + data: Stringifiable | None = None """Optional data associated with the edge. Defaults to None.""" conditional: bool = False """Whether the edge is conditional. Defaults to False.""" - def copy( - self, *, source: Optional[str] = None, target: Optional[str] = None - ) -> Edge: + def copy(self, *, source: str | None = None, target: str | None = None) -> Edge: """Return a copy of the edge with optional new source and target nodes. Args: @@ -102,16 +98,16 @@ class Node(NamedTuple): """The unique identifier of the node.""" name: str """The name of the node.""" - data: Union[type[BaseModel], RunnableType, None] + data: type[BaseModel] | RunnableType | None """The data of the node.""" - metadata: Optional[dict[str, Any]] + metadata: dict[str, Any] | None """Optional metadata for the node. Defaults to None.""" def copy( self, *, - id: Optional[str] = None, - name: Optional[str] = None, + id: str | None = None, + name: str | None = None, ) -> Node: """Return a copy of the node with optional new id and name. @@ -135,7 +131,7 @@ class Branch(NamedTuple): condition: Callable[..., str] """A callable that returns a string representation of the condition.""" - ends: Optional[dict[str, str]] + ends: dict[str, str] | None """Optional dictionary of end node ids for the branches. Defaults to None.""" @@ -182,7 +178,7 @@ class MermaidDrawMethod(Enum): def node_data_str( id: str, - data: Union[type[BaseModel], RunnableType, None], + data: type[BaseModel] | RunnableType | None, ) -> str: """Convert the data of a node to a string. @@ -201,7 +197,7 @@ def node_data_str( def node_data_json( node: Node, *, with_schemas: bool = False -) -> dict[str, Union[str, dict[str, Any]]]: +) -> dict[str, str | dict[str, Any]]: """Convert the data of a node to a JSON-serializable format. Args: @@ -316,10 +312,10 @@ class Graph: def add_node( self, - data: Union[type[BaseModel], RunnableType, None], - id: Optional[str] = None, + data: type[BaseModel] | RunnableType | None, + id: str | None = None, *, - metadata: Optional[dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ) -> Node: """Add a node to the graph and return it. @@ -357,7 +353,7 @@ class Graph: self, source: Node, target: Node, - data: Optional[Stringifiable] = None, + data: Stringifiable | None = None, conditional: bool = False, # noqa: FBT001,FBT002 ) -> Edge: """Add an edge to the graph and return it. @@ -388,7 +384,7 @@ class Graph: def extend( self, graph: Graph, *, prefix: str = "" - ) -> tuple[Optional[Node], Optional[Node]]: + ) -> tuple[Node | None, Node | None]: """Add all nodes and edges from another graph. Note this doesn't check for duplicates, nor does it connect the graphs. @@ -459,7 +455,7 @@ class Graph: ], ) - def first_node(self) -> Optional[Node]: + def first_node(self) -> Node | None: """Find the single node that is not a target of any edge. If there is no such node, or there are multiple, return None. @@ -471,7 +467,7 @@ class Graph: """ return _first_node(self) - def last_node(self) -> Optional[Node]: + def last_node(self) -> Node | None: """Find the single node that is not a source of any edge. If there is no such node, or there are multiple, return None. @@ -531,24 +527,24 @@ class Graph: def draw_png( self, output_file_path: str, - fontname: Optional[str] = None, - labels: Optional[LabelsDict] = None, + fontname: str | None = None, + labels: LabelsDict | None = None, ) -> None: ... @overload def draw_png( self, output_file_path: None, - fontname: Optional[str] = None, - labels: Optional[LabelsDict] = None, + fontname: str | None = None, + labels: LabelsDict | None = None, ) -> bytes: ... def draw_png( self, - output_file_path: Optional[str] = None, - fontname: Optional[str] = None, - labels: Optional[LabelsDict] = None, - ) -> Union[bytes, None]: + output_file_path: str | None = None, + fontname: str | None = None, + labels: LabelsDict | None = None, + ) -> bytes | None: """Draw the graph as a PNG image. Args: @@ -581,9 +577,9 @@ class Graph: *, with_styles: bool = True, curve_style: CurveStyle = CurveStyle.LINEAR, - node_colors: Optional[NodeStyles] = None, + node_colors: NodeStyles | None = None, wrap_label_n_words: int = 9, - frontmatter_config: Optional[dict[str, Any]] = None, + frontmatter_config: dict[str, Any] | None = None, ) -> str: """Draw the graph as a Mermaid syntax string. @@ -636,16 +632,16 @@ class Graph: self, *, curve_style: CurveStyle = CurveStyle.LINEAR, - node_colors: Optional[NodeStyles] = None, + node_colors: NodeStyles | None = None, wrap_label_n_words: int = 9, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, draw_method: MermaidDrawMethod = MermaidDrawMethod.API, background_color: str = "white", padding: int = 10, max_retries: int = 1, retry_delay: float = 1.0, - frontmatter_config: Optional[dict[str, Any]] = None, - base_url: Optional[str] = None, + frontmatter_config: dict[str, Any] | None = None, + base_url: str | None = None, ) -> bytes: """Draw the graph as a PNG image using Mermaid. @@ -711,7 +707,7 @@ class Graph: ) -def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]: +def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Node | None: """Find the single node that is not a target of any edge. Exclude nodes/sources with ids in the exclude list. @@ -727,7 +723,7 @@ def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]: return found[0] if len(found) == 1 else None -def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]: +def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Node | None: """Find the single node that is not a source of any edge. Exclude nodes/targets with ids in the exclude list. diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 8b8b9f58a2f..ac70dee638b 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -10,7 +10,7 @@ import string import time from dataclasses import asdict from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal import yaml @@ -45,13 +45,13 @@ def draw_mermaid( nodes: dict[str, Node], edges: list[Edge], *, - first_node: Optional[str] = None, - last_node: Optional[str] = None, + first_node: str | None = None, + last_node: str | None = None, with_styles: bool = True, curve_style: CurveStyle = CurveStyle.LINEAR, - node_styles: Optional[NodeStyles] = None, + node_styles: NodeStyles | None = None, wrap_label_n_words: int = 9, - frontmatter_config: Optional[dict[str, Any]] = None, + frontmatter_config: dict[str, Any] | None = None, ) -> str: """Draws a Mermaid graph using the provided graph data. @@ -163,7 +163,7 @@ def draw_mermaid( src_parts = edge.source.split(":") tgt_parts = edge.target.split(":") common_prefix = ":".join( - src for src, tgt in zip(src_parts, tgt_parts) if src == tgt + src for src, tgt in zip(src_parts, tgt_parts, strict=False) if src == tgt ) edge_groups.setdefault(common_prefix, []).append(edge) @@ -279,13 +279,13 @@ def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str: def draw_mermaid_png( mermaid_syntax: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, draw_method: MermaidDrawMethod = MermaidDrawMethod.API, - background_color: Optional[str] = "white", + background_color: str | None = "white", padding: int = 10, max_retries: int = 1, retry_delay: float = 1.0, - base_url: Optional[str] = None, + base_url: str | None = None, ) -> bytes: """Draws a Mermaid graph as PNG using provided syntax. @@ -339,8 +339,8 @@ def draw_mermaid_png( async def _render_mermaid_using_pyppeteer( mermaid_syntax: str, - output_file_path: Optional[str] = None, - background_color: Optional[str] = "white", + output_file_path: str | None = None, + background_color: str | None = "white", padding: int = 10, device_scale_factor: int = 3, ) -> bytes: @@ -411,12 +411,12 @@ async def _render_mermaid_using_pyppeteer( def _render_mermaid_using_api( mermaid_syntax: str, *, - output_file_path: Optional[str] = None, - background_color: Optional[str] = "white", - file_type: Optional[Literal["jpeg", "png", "webp"]] = "png", + output_file_path: str | None = None, + background_color: str | None = "white", + file_type: Literal["jpeg", "png", "webp"] | None = "png", max_retries: int = 1, retry_delay: float = 1.0, - base_url: Optional[str] = None, + base_url: str | None = None, ) -> bytes: """Renders Mermaid graph using the Mermaid.INK API.""" # Defaults to using the public mermaid.ink server. diff --git a/libs/core/langchain_core/runnables/graph_png.py b/libs/core/langchain_core/runnables/graph_png.py index 3e335b2004c..c23ae6a9e5e 100644 --- a/libs/core/langchain_core/runnables/graph_png.py +++ b/libs/core/langchain_core/runnables/graph_png.py @@ -1,6 +1,6 @@ """Helper class to draw a state graph into a PNG file.""" -from typing import Any, Optional +from typing import Any from langchain_core.runnables.graph import Graph, LabelsDict @@ -25,7 +25,7 @@ class PngDrawer: """ def __init__( - self, fontname: Optional[str] = None, labels: Optional[LabelsDict] = None + self, fontname: str | None = None, labels: LabelsDict | None = None ) -> None: """Initializes the PNG drawer. @@ -95,7 +95,7 @@ class PngDrawer: viz: Any, source: str, target: str, - label: Optional[str] = None, + label: str | None = None, conditional: bool = False, # noqa: FBT001,FBT002 ) -> None: """Adds an edge to the graph. @@ -116,7 +116,7 @@ class PngDrawer: style="dotted" if conditional else "solid", ) - def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]: + def draw(self, graph: Graph, output_path: str | None = None) -> bytes | None: """Draw the given state graph into a PNG file. Requires `graphviz` and `pygraphviz` to be installed. diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 9cbd25d5b6d..439050d3927 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -3,14 +3,11 @@ from __future__ import annotations import inspect -from collections.abc import Sequence +from collections.abc import Callable, Sequence from types import GenericAlias from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, - Union, ) from pydantic import BaseModel @@ -34,7 +31,7 @@ if TYPE_CHECKING: from langchain_core.tracers.schemas import Run -MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], dict[str, Any]] +MessagesOrDictWithMessages = Sequence["BaseMessage"] | dict[str, Any] GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] @@ -229,13 +226,13 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] """Function that returns a new BaseChatMessageHistory. This function should either take a single positional argument ``session_id`` of type string and return a corresponding chat message history instance""" - input_messages_key: Optional[str] = None + input_messages_key: str | None = None """Must be specified if the base runnable accepts a dict as input. The key in the input dict that contains the messages.""" - output_messages_key: Optional[str] = None + output_messages_key: str | None = None """Must be specified if the base Runnable returns a dict as output. The key in the output dict that contains the messages.""" - history_messages_key: Optional[str] = None + history_messages_key: str | None = None """Must be specified if the base runnable accepts a dict as input and expects a separate key for historical messages.""" history_factory_config: Sequence[ConfigurableFieldSpec] @@ -244,23 +241,17 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] def __init__( self, - runnable: Union[ - Runnable[ - list[BaseMessage], - Union[str, BaseMessage, MessagesOrDictWithMessages], - ], - Runnable[ - dict[str, Any], - Union[str, BaseMessage, MessagesOrDictWithMessages], - ], - LanguageModelLike, - ], + runnable: Runnable[ + list[BaseMessage], str | BaseMessage | MessagesOrDictWithMessages + ] + | Runnable[dict[str, Any], str | BaseMessage | MessagesOrDictWithMessages] + | LanguageModelLike, get_session_history: GetSessionHistoryCallable, *, - input_messages_key: Optional[str] = None, - output_messages_key: Optional[str] = None, - history_messages_key: Optional[str] = None, - history_factory_config: Optional[Sequence[ConfigurableFieldSpec]] = None, + input_messages_key: str | None = None, + output_messages_key: str | None = None, + history_messages_key: str | None = None, + history_factory_config: Sequence[ConfigurableFieldSpec] | None = None, **kwargs: Any, ) -> None: """Initialize RunnableWithMessageHistory. @@ -379,13 +370,11 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] ) @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: fields: dict = {} if self.input_messages_key and self.history_messages_key: fields[self.input_messages_key] = ( - Union[str, BaseMessage, Sequence[BaseMessage]], + str | BaseMessage | Sequence[BaseMessage], ..., ) elif self.input_messages_key: @@ -409,7 +398,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] @override def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> type[BaseModel]: """Get a pydantic model that can be used to validate output to the Runnable. @@ -441,7 +430,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] ) def _get_input_messages( - self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] + self, input_val: str | BaseMessage | Sequence[BaseMessage] | dict ) -> list[BaseMessage]: # If dictionary, try to pluck the single key representing messages if isinstance(input_val, dict): @@ -479,7 +468,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] raise ValueError(msg) def _get_output_messages( - self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] + self, output_val: str | BaseMessage | Sequence[BaseMessage] | dict ) -> list[BaseMessage]: # If dictionary, try to pluck the single key representing messages if isinstance(output_val, dict): @@ -569,7 +558,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] output_messages = self._get_output_messages(output_val) await hist.aadd_messages(input_messages + output_messages) - def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: + def _merge_configs(self, *configs: RunnableConfig | None) -> RunnableConfig: config = super()._merge_configs(*configs) expected_keys = [field_spec.id for field_spec in self.history_factory_config] diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index aef12a5f8f0..d5320a2f7cc 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -5,13 +5,10 @@ from __future__ import annotations import asyncio import inspect import threading -from collections.abc import Awaitable +from collections.abc import Awaitable, Callable from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, - Union, cast, ) @@ -135,18 +132,17 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): ``` """ - input_type: Optional[type[Other]] = None + input_type: type[Other] | None = None - func: Optional[ - Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]] - ] = None + func: Callable[[Other], None] | Callable[[Other, RunnableConfig], None] | None = ( + None + ) - afunc: Optional[ - Union[ - Callable[[Other], Awaitable[None]], - Callable[[Other, RunnableConfig], Awaitable[None]], - ] - ] = None + afunc: ( + Callable[[Other], Awaitable[None]] + | Callable[[Other, RunnableConfig], Awaitable[None]] + | None + ) = None @override def __repr_args__(self) -> Any: @@ -156,23 +152,16 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): def __init__( self, - func: Optional[ - Union[ - Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]], - Union[ - Callable[[Other], Awaitable[None]], - Callable[[Other, RunnableConfig], Awaitable[None]], - ], - ] - ] = None, - afunc: Optional[ - Union[ - Callable[[Other], Awaitable[None]], - Callable[[Other, RunnableConfig], Awaitable[None]], - ] - ] = None, + func: Callable[[Other], None] + | Callable[[Other, RunnableConfig], None] + | Callable[[Other], Awaitable[None]] + | Callable[[Other, RunnableConfig], Awaitable[None]] + | None = None, + afunc: Callable[[Other], Awaitable[None]] + | Callable[[Other, RunnableConfig], Awaitable[None]] + | None = None, *, - input_type: Optional[type[Other]] = None, + input_type: type[Other] | None = None, **kwargs: Any, ) -> None: """Create e RunnablePassthrough. @@ -217,14 +206,9 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): @override def assign( cls, - **kwargs: Union[ - Runnable[dict[str, Any], Any], - Callable[[dict[str, Any]], Any], - Mapping[ - str, - Union[Runnable[dict[str, Any], Any], Callable[[dict[str, Any]], Any]], - ], - ], + **kwargs: Runnable[dict[str, Any], Any] + | Callable[[dict[str, Any]], Any] + | Mapping[str, Runnable[dict[str, Any], Any] | Callable[[dict[str, Any]], Any]], ) -> RunnableAssign: """Merge the Dict input with the output produced by the mapping argument. @@ -240,7 +224,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): @override def invoke( - self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Other, config: RunnableConfig | None = None, **kwargs: Any ) -> Other: if self.func is not None: call_func_with_variable_args( @@ -252,8 +236,8 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): async def ainvoke( self, input: Other, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Other: if self.afunc is not None: await acall_func_with_variable_args( @@ -269,7 +253,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): def transform( self, input: Iterator[Other], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Iterator[Other]: if self.func is None: @@ -300,7 +284,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): async def atransform( self, input: AsyncIterator[Other], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> AsyncIterator[Other]: if self.afunc is None and self.func is None: @@ -343,7 +327,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): def stream( self, input: Other, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Iterator[Other]: return self.transform(iter([input]), config, **kwargs) @@ -352,7 +336,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): async def astream( self, input: Other, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> AsyncIterator[Other]: async def input_aiter() -> AsyncIterator[Other]: @@ -433,9 +417,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): return ["langchain", "schema", "runnable"] @override - def get_name( - self, suffix: Optional[str] = None, *, name: Optional[str] = None - ) -> str: + def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str: name = ( name or self.name @@ -444,9 +426,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): return super().get_name(suffix, name=name) @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: map_input_schema = self.mapper.get_input_schema(config) if not issubclass(map_input_schema, RootModel): # ie. it's a dict @@ -456,7 +436,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): @override def get_output_schema( - self, config: Optional[RunnableConfig] = None + self, config: RunnableConfig | None = None ) -> type[BaseModel]: map_input_schema = self.mapper.get_input_schema(config) map_output_schema = self.mapper.get_output_schema(config) @@ -521,7 +501,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): def invoke( self, input: dict[str, Any], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> dict[str, Any]: return self._call_with_config(self._invoke, input, config, **kwargs) @@ -550,7 +530,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): async def ainvoke( self, input: dict[str, Any], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> dict[str, Any]: return await self._acall_with_config(self._ainvoke, input, config, **kwargs) @@ -605,7 +585,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): def transform( self, input: Iterator[dict[str, Any]], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any | None, ) -> Iterator[dict[str, Any]]: yield from self._transform_stream_with_config( @@ -657,7 +637,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): async def atransform( self, input: AsyncIterator[dict[str, Any]], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: async for chunk in self._atransform_stream_with_config( @@ -669,7 +649,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): def stream( self, input: dict[str, Any], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Iterator[dict[str, Any]]: return self.transform(iter([input]), config, **kwargs) @@ -678,7 +658,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): async def astream( self, input: dict[str, Any], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: async def input_aiter() -> AsyncIterator[dict[str, Any]]: @@ -715,9 +695,9 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): ``` """ - keys: Union[str, list[str]] + keys: str | list[str] - def __init__(self, keys: Union[str, list[str]], **kwargs: Any) -> None: + def __init__(self, keys: str | list[str], **kwargs: Any) -> None: """Create a RunnablePick. Args: @@ -742,9 +722,7 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): return ["langchain", "schema", "runnable"] @override - def get_name( - self, suffix: Optional[str] = None, *, name: Optional[str] = None - ) -> str: + def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str: name = ( name or self.name @@ -775,7 +753,7 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): def invoke( self, input: dict[str, Any], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> dict[str, Any]: return self._call_with_config(self._invoke, input, config, **kwargs) @@ -790,7 +768,7 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): async def ainvoke( self, input: dict[str, Any], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> dict[str, Any]: return await self._acall_with_config(self._ainvoke, input, config, **kwargs) @@ -808,7 +786,7 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): def transform( self, input: Iterator[dict[str, Any]], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Iterator[dict[str, Any]]: yield from self._transform_stream_with_config( @@ -828,7 +806,7 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): async def atransform( self, input: AsyncIterator[dict[str, Any]], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: async for chunk in self._atransform_stream_with_config( @@ -840,7 +818,7 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): def stream( self, input: dict[str, Any], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Iterator[dict[str, Any]]: return self.transform(iter([input]), config, **kwargs) @@ -849,7 +827,7 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): async def astream( self, input: dict[str, Any], - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> AsyncIterator[dict[str, Any]]: async def input_aiter() -> AsyncIterator[dict[str, Any]]: diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index 95fdc4ca07c..7fe626f8a3d 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -3,9 +3,7 @@ from typing import ( TYPE_CHECKING, Any, - Optional, TypeVar, - Union, cast, ) @@ -126,7 +124,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede wait_exponential_jitter: bool = True """Whether to add jitter to the exponential backoff.""" - exponential_jitter_params: Optional[ExponentialJitterParams] = None + exponential_jitter_params: ExponentialJitterParams | None = None """Parameters for ``tenacity.wait_exponential_jitter``. Namely: ``initial``, ``max``, ``exp_base``, and ``jitter`` (all float values). """ @@ -174,7 +172,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede retry_state: RetryCallState, ) -> list[RunnableConfig]: return [ - self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager) + self._patch_config(c, rm, retry_state) + for c, rm in zip(config, run_manager, strict=False) ] def _invoke( @@ -197,7 +196,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede @override def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: return self._call_with_config(self._invoke, input, config, **kwargs) @@ -221,7 +220,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede @override async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: return await self._acall_with_config(self._ainvoke, input, config, **kwargs) @@ -231,7 +230,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede run_manager: list["CallbackManagerForChainRun"], config: list[RunnableConfig], **kwargs: Any, - ) -> list[Union[Output, Exception]]: + ) -> list[Output | Exception]: results_map: dict[int, Output] = {} not_set: list[Output] = [] @@ -280,7 +279,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede if result is not_set: result = cast("list[Output]", [e] * len(inputs)) - outputs: list[Union[Output, Exception]] = [] + outputs: list[Output | Exception] = [] for idx in range(len(inputs)): if idx in results_map: outputs.append(results_map[idx]) @@ -292,7 +291,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede def batch( self, inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, **kwargs: Any, @@ -307,7 +306,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede run_manager: list["AsyncCallbackManagerForChainRun"], config: list[RunnableConfig], **kwargs: Any, - ) -> list[Union[Output, Exception]]: + ) -> list[Output | Exception]: results_map: dict[int, Output] = {} not_set: list[Output] = [] @@ -355,7 +354,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede if result is not_set: result = cast("list[Output]", [e] * len(inputs)) - outputs: list[Union[Output, Exception]] = [] + outputs: list[Output | Exception] = [] for idx in range(len(inputs)): if idx in results_map: outputs.append(results_map[idx]) @@ -367,7 +366,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede async def abatch( self, inputs: list[Input], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, **kwargs: Any, diff --git a/libs/core/langchain_core/runnables/router.py b/libs/core/langchain_core/runnables/router.py index 035e21bb7f9..0bffd159836 100644 --- a/libs/core/langchain_core/runnables/router.py +++ b/libs/core/langchain_core/runnables/router.py @@ -2,13 +2,10 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Callable, Mapping from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, - Union, cast, ) @@ -75,7 +72,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): def __init__( self, - runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]], + runnables: Mapping[str, Runnable[Any, Output] | Callable[[Any], Output]], ) -> None: """Create a RouterRunnable. @@ -108,7 +105,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): @override def invoke( - self, input: RouterInput, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: RouterInput, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: key = input["key"] actual_input = input["input"] @@ -123,8 +120,8 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): async def ainvoke( self, input: RouterInput, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Output: key = input["key"] actual_input = input["input"] @@ -139,10 +136,10 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): def batch( self, inputs: list[RouterInput], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> list[Output]: if not inputs: return [] @@ -155,7 +152,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): def invoke( runnable: Runnable, input_: Input, config: RunnableConfig - ) -> Union[Output, Exception]: + ) -> Output | Exception: if return_exceptions: try: return runnable.invoke(input_, config, **kwargs) @@ -176,10 +173,10 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): async def abatch( self, inputs: list[RouterInput], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> list[Output]: if not inputs: return [] @@ -192,7 +189,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): async def ainvoke( runnable: Runnable, input_: Input, config: RunnableConfig - ) -> Union[Output, Exception]: + ) -> Output | Exception: if return_exceptions: try: return await runnable.ainvoke(input_, config, **kwargs) @@ -212,8 +209,8 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): def stream( self, input: RouterInput, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: key = input["key"] actual_input = input["input"] @@ -228,8 +225,8 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): async def astream( self, input: RouterInput, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: key = input["key"] actual_input = input["input"] diff --git a/libs/core/langchain_core/runnables/schema.py b/libs/core/langchain_core/runnables/schema.py index 07ee905a7ef..828085fd435 100644 --- a/libs/core/langchain_core/runnables/schema.py +++ b/libs/core/langchain_core/runnables/schema.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Union +from typing import TYPE_CHECKING, Any, Literal from typing_extensions import NotRequired, TypedDict @@ -182,4 +182,4 @@ class CustomStreamEvent(BaseStreamEvent): """The data associated with the event. Free form and can be anything.""" -StreamEvent = Union[StandardStreamEvent, CustomStreamEvent] +StreamEvent = StandardStreamEvent | CustomStreamEvent diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 73ed64a89f6..1425239a589 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -6,7 +6,7 @@ import ast import asyncio import inspect import textwrap -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from contextvars import Context from functools import lru_cache from inspect import signature @@ -14,13 +14,10 @@ from itertools import groupby from typing import ( TYPE_CHECKING, Any, - Callable, NamedTuple, - Optional, Protocol, TypeGuard, TypeVar, - Union, ) from typing_extensions import override @@ -58,7 +55,7 @@ async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any: return await coro -async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list: +async def gather_with_concurrency(n: int | None, *coros: Coroutine) -> list: """Gather coroutines with a limit on the number of concurrent coroutines. Args: @@ -344,7 +341,7 @@ class GetLambdaSource(ast.NodeVisitor): def __init__(self) -> None: """Initialize the visitor.""" - self.source: Optional[str] = None + self.source: str | None = None self.count = 0 @override @@ -359,7 +356,7 @@ class GetLambdaSource(ast.NodeVisitor): self.source = ast.unparse(node) -def get_function_first_arg_dict_keys(func: Callable) -> Optional[list[str]]: +def get_function_first_arg_dict_keys(func: Callable) -> list[str] | None: """Get the keys of the first argument of a function if it is a dict. Args: @@ -379,7 +376,7 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[list[str]]: return None -def get_lambda_source(func: Callable) -> Optional[str]: +def get_lambda_source(func: Callable) -> str | None: """Get the source code of a lambda function. Args: @@ -521,7 +518,7 @@ class SupportsAdd(Protocol[_T_contra, _T_co]): Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any]) -def add(addables: Iterable[Addable]) -> Optional[Addable]: +def add(addables: Iterable[Addable]) -> Addable | None: """Add a sequence of addable objects together. Args: @@ -530,13 +527,13 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]: Returns: Optional[Addable]: The result of adding the addable objects. """ - final: Optional[Addable] = None + final: Addable | None = None for chunk in addables: final = chunk if final is None else final + chunk return final -async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]: +async def aadd(addables: AsyncIterable[Addable]) -> Addable | None: """Asynchronously add a sequence of addable objects together. Args: @@ -545,7 +542,7 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]: Returns: Optional[Addable]: The result of adding the addable objects. """ - final: Optional[Addable] = None + final: Addable | None = None async for chunk in addables: final = chunk if final is None else final + chunk return final @@ -556,11 +553,11 @@ class ConfigurableField(NamedTuple): id: str """The unique identifier of the field.""" - name: Optional[str] = None + name: str | None = None """The name of the field. Defaults to None.""" - description: Optional[str] = None + description: str | None = None """The description of the field. Defaults to None.""" - annotation: Optional[Any] = None + annotation: Any | None = None """The annotation of the field. Defaults to None.""" is_shared: bool = False """Whether the field is shared. Defaults to False.""" @@ -579,9 +576,9 @@ class ConfigurableFieldSingleOption(NamedTuple): """The options for the field.""" default: str """The default value for the field.""" - name: Optional[str] = None + name: str | None = None """The name of the field. Defaults to None.""" - description: Optional[str] = None + description: str | None = None """The description of the field. Defaults to None.""" is_shared: bool = False """Whether the field is shared. Defaults to False.""" @@ -600,9 +597,9 @@ class ConfigurableFieldMultiOption(NamedTuple): """The options for the field.""" default: Sequence[str] """The default values for the field.""" - name: Optional[str] = None + name: str | None = None """The name of the field. Defaults to None.""" - description: Optional[str] = None + description: str | None = None """The description of the field. Defaults to None.""" is_shared: bool = False """Whether the field is shared. Defaults to False.""" @@ -612,9 +609,9 @@ class ConfigurableFieldMultiOption(NamedTuple): return hash((self.id, tuple(self.options.keys()), tuple(self.default))) -AnyConfigurableField = Union[ - ConfigurableField, ConfigurableFieldSingleOption, ConfigurableFieldMultiOption -] +AnyConfigurableField = ( + ConfigurableField | ConfigurableFieldSingleOption | ConfigurableFieldMultiOption +) class ConfigurableFieldSpec(NamedTuple): @@ -624,15 +621,15 @@ class ConfigurableFieldSpec(NamedTuple): """The unique identifier of the field.""" annotation: Any """The annotation of the field.""" - name: Optional[str] = None + name: str | None = None """The name of the field. Defaults to None.""" - description: Optional[str] = None + description: str | None = None """The description of the field. Defaults to None.""" default: Any = None """The default value for the field. Defaults to None.""" is_shared: bool = False """Whether the field is shared. Defaults to False.""" - dependencies: Optional[list[str]] = None + dependencies: list[str] | None = None """The dependencies of the field. Defaults to None.""" @@ -672,12 +669,12 @@ class _RootEventFilter: def __init__( self, *, - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, + include_names: Sequence[str] | None = None, + include_types: Sequence[str] | None = None, + include_tags: Sequence[str] | None = None, + exclude_names: Sequence[str] | None = None, + exclude_types: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, ) -> None: """Utility to filter the root event in the astream_events implementation. diff --git a/libs/core/langchain_core/stores.py b/libs/core/langchain_core/stores.py index df9a78f9b48..0e5a25e6e6f 100644 --- a/libs/core/langchain_core/stores.py +++ b/libs/core/langchain_core/stores.py @@ -11,9 +11,7 @@ from collections.abc import AsyncIterator, Iterator, Sequence from typing import ( Any, Generic, - Optional, TypeVar, - Union, ) from typing_extensions import override @@ -82,7 +80,7 @@ class BaseStore(ABC, Generic[K, V]): """ @abstractmethod - def mget(self, keys: Sequence[K]) -> list[Optional[V]]: + def mget(self, keys: Sequence[K]) -> list[V | None]: """Get the values associated with the given keys. Args: @@ -93,7 +91,7 @@ class BaseStore(ABC, Generic[K, V]): If a key is not found, the corresponding value will be None. """ - async def amget(self, keys: Sequence[K]) -> list[Optional[V]]: + async def amget(self, keys: Sequence[K]) -> list[V | None]: """Async get the values associated with the given keys. Args: @@ -138,9 +136,7 @@ class BaseStore(ABC, Generic[K, V]): return await run_in_executor(None, self.mdelete, keys) @abstractmethod - def yield_keys( - self, *, prefix: Optional[str] = None - ) -> Union[Iterator[K], Iterator[str]]: + def yield_keys(self, *, prefix: str | None = None) -> Iterator[K] | Iterator[str]: """Get an iterator over keys that match the given prefix. Args: @@ -153,8 +149,8 @@ class BaseStore(ABC, Generic[K, V]): """ async def ayield_keys( - self, *, prefix: Optional[str] = None - ) -> Union[AsyncIterator[K], AsyncIterator[str]]: + self, *, prefix: str | None = None + ) -> AsyncIterator[K] | AsyncIterator[str]: """Async get an iterator over keys that match the given prefix. Args: @@ -184,7 +180,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]): """Initialize an empty store.""" self.store: dict[str, V] = {} - def mget(self, keys: Sequence[str]) -> list[Optional[V]]: + def mget(self, keys: Sequence[str]) -> list[V | None]: """Get the values associated with the given keys. Args: @@ -196,7 +192,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]): """ return [self.store.get(key) for key in keys] - async def amget(self, keys: Sequence[str]) -> list[Optional[V]]: + async def amget(self, keys: Sequence[str]) -> list[V | None]: """Async get the values associated with the given keys. Args: @@ -235,7 +231,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]): """ self.mdelete(keys) - def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]: + def yield_keys(self, prefix: str | None = None) -> Iterator[str]: """Get an iterator over keys that match the given prefix. Args: @@ -251,7 +247,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]): if key.startswith(prefix): yield key - async def ayield_keys(self, prefix: Optional[str] = None) -> AsyncIterator[str]: + async def ayield_keys(self, prefix: str | None = None) -> AsyncIterator[str]: """Async get an async iterator over keys that match the given prefix. Args: diff --git a/libs/core/langchain_core/structured_query.py b/libs/core/langchain_core/structured_query.py index 61fb8f0e53a..664b85b50e2 100644 --- a/libs/core/langchain_core/structured_query.py +++ b/libs/core/langchain_core/structured_query.py @@ -4,7 +4,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pydantic import BaseModel @@ -15,12 +15,12 @@ if TYPE_CHECKING: class Visitor(ABC): """Defines interface for IR translation using a visitor pattern.""" - allowed_comparators: Optional[Sequence[Comparator]] = None + allowed_comparators: Sequence[Comparator] | None = None """Allowed comparators for the visitor.""" - allowed_operators: Optional[Sequence[Operator]] = None + allowed_operators: Sequence[Operator] | None = None """Allowed operators for the visitor.""" - def _validate_func(self, func: Union[Operator, Comparator]) -> None: + def _validate_func(self, func: Operator | Comparator) -> None: if ( isinstance(func, Operator) and self.allowed_operators is not None @@ -174,16 +174,16 @@ class StructuredQuery(Expr): query: str """Query string.""" - filter: Optional[FilterDirective] + filter: FilterDirective | None """Filtering expression.""" - limit: Optional[int] + limit: int | None """Limit on the number of results.""" def __init__( self, query: str, - filter: Optional[FilterDirective], # noqa: A002 - limit: Optional[int] = None, + filter: FilterDirective | None, # noqa: A002 + limit: int | None = None, **kwargs: Any, ) -> None: """Create a StructuredQuery. diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 733c6ac6a62..d429b59cf29 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -8,16 +8,14 @@ import json import typing import warnings from abc import ABC, abstractmethod +from collections.abc import Callable from inspect import signature from typing import ( TYPE_CHECKING, Annotated, Any, - Callable, Literal, - Optional, TypeVar, - Union, cast, get_args, get_origin, @@ -282,7 +280,7 @@ def create_schema_from_function( model_name: str, func: Callable, *, - filter_args: Optional[Sequence[str]] = None, + filter_args: Sequence[str] | None = None, parse_docstring: bool = False, error_on_invalid_docstring: bool = False, include_injected: bool = True, @@ -387,10 +385,10 @@ class ToolException(Exception): # noqa: N818 """ -ArgsSchema = Union[TypeBaseModel, dict[str, Any]] +ArgsSchema = TypeBaseModel | dict[str, Any] -class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]): +class BaseTool(RunnableSerializable[str | dict | ToolCall, Any]): """Base class for all LangChain tools. This abstract class defines the interface that all LangChain tools must implement. @@ -438,7 +436,7 @@ class ChildTool(BaseTool): You can provide few-shot examples as a part of the description. """ - args_schema: Annotated[Optional[ArgsSchema], SkipValidation()] = Field( + args_schema: Annotated[ArgsSchema | None, SkipValidation()] = Field( default=None, description="The tool schema." ) """Pydantic model class to validate and parse the tool's input arguments. @@ -461,27 +459,25 @@ class ChildTool(BaseTool): callbacks: Callbacks = Field(default=None, exclude=True) """Callbacks to be called during tool execution.""" - tags: Optional[list[str]] = None + tags: list[str] | None = None """Optional list of tags associated with the tool. Defaults to None. These tags will be associated with each call to this tool, and passed as arguments to the handlers defined in `callbacks`. You can use these to eg identify a specific instance of a tool with its use case. """ - metadata: Optional[dict[str, Any]] = None + metadata: dict[str, Any] | None = None """Optional metadata associated with the tool. Defaults to None. This metadata will be associated with each call to this tool, and passed as arguments to the handlers defined in `callbacks`. You can use these to eg identify a specific instance of a tool with its use case. """ - handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], str]]] = ( - False - ) + handle_tool_error: bool | str | Callable[[ToolException], str] | None = False """Handle the content of the ToolException thrown.""" - handle_validation_error: Optional[ - Union[bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]] - ] = False + handle_validation_error: ( + bool | str | Callable[[ValidationError | ValidationErrorV1], str] | None + ) = False """Handle the content of the ValidationError thrown.""" response_format: Literal["content", "content_and_artifact"] = "content" @@ -570,9 +566,7 @@ class ChildTool(BaseTool): # --- Runnable --- @override - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> type[BaseModel]: + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: """The tool's input schema. Args: @@ -590,8 +584,8 @@ class ChildTool(BaseTool): @override def invoke( self, - input: Union[str, dict, ToolCall], - config: Optional[RunnableConfig] = None, + input: str | dict | ToolCall, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Any: tool_input, kwargs = _prep_run_args(input, config, **kwargs) @@ -600,8 +594,8 @@ class ChildTool(BaseTool): @override async def ainvoke( self, - input: Union[str, dict, ToolCall], - config: Optional[RunnableConfig] = None, + input: str | dict | ToolCall, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Any: tool_input, kwargs = _prep_run_args(input, config, **kwargs) @@ -610,8 +604,8 @@ class ChildTool(BaseTool): # --- Tool --- def _parse_input( - self, tool_input: Union[str, dict], tool_call_id: Optional[str] - ) -> Union[str, dict[str, Any]]: + self, tool_input: str | dict, tool_call_id: str | None + ) -> str | dict[str, Any]: """Parse and validate tool input using the args schema. Args: @@ -715,7 +709,7 @@ class ChildTool(BaseTool): return await run_in_executor(None, self._run, *args, **kwargs) def _to_args_and_kwargs( - self, tool_input: Union[str, dict], tool_call_id: Optional[str] + self, tool_input: str | dict, tool_call_id: str | None ) -> tuple[tuple, dict]: """Convert tool input to positional and keyword arguments. @@ -755,18 +749,18 @@ class ChildTool(BaseTool): def run( self, - tool_input: Union[str, dict[str, Any]], - verbose: Optional[bool] = None, # noqa: FBT001 - start_color: Optional[str] = "green", - color: Optional[str] = "green", + tool_input: str | dict[str, Any], + verbose: bool | None = None, # noqa: FBT001 + start_color: str | None = "green", + color: str | None = "green", callbacks: Callbacks = None, *, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - run_name: Optional[str] = None, - run_id: Optional[uuid.UUID] = None, - config: Optional[RunnableConfig] = None, - tool_call_id: Optional[str] = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + run_name: str | None = None, + run_id: uuid.UUID | None = None, + config: RunnableConfig | None = None, + tool_call_id: str | None = None, **kwargs: Any, ) -> Any: """Run the tool. @@ -818,7 +812,7 @@ class ChildTool(BaseTool): content = None artifact = None status = "success" - error_to_raise: Union[Exception, KeyboardInterrupt, None] = None + error_to_raise: Exception | KeyboardInterrupt | None = None try: child_config = patch_config(config, callbacks=run_manager.get_child()) with set_config_context(child_config) as context: @@ -867,18 +861,18 @@ class ChildTool(BaseTool): async def arun( self, - tool_input: Union[str, dict], - verbose: Optional[bool] = None, # noqa: FBT001 - start_color: Optional[str] = "green", - color: Optional[str] = "green", + tool_input: str | dict, + verbose: bool | None = None, # noqa: FBT001 + start_color: str | None = "green", + color: str | None = "green", callbacks: Callbacks = None, *, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - run_name: Optional[str] = None, - run_id: Optional[uuid.UUID] = None, - config: Optional[RunnableConfig] = None, - tool_call_id: Optional[str] = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + run_name: str | None = None, + run_id: uuid.UUID | None = None, + config: RunnableConfig | None = None, + tool_call_id: str | None = None, **kwargs: Any, ) -> Any: """Run the tool asynchronously. @@ -928,7 +922,7 @@ class ChildTool(BaseTool): content = None artifact = None status = "success" - error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None + error_to_raise: Exception | KeyboardInterrupt | None = None try: tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id) child_config = patch_config(config, callbacks=run_manager.get_child()) @@ -993,11 +987,9 @@ def _is_tool_call(x: Any) -> bool: def _handle_validation_error( - e: Union[ValidationError, ValidationErrorV1], + e: ValidationError | ValidationErrorV1, *, - flag: Union[ - Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str] - ], + flag: Literal[True] | str | Callable[[ValidationError | ValidationErrorV1], str], ) -> str: """Handle validation errors based on the configured flag. @@ -1029,7 +1021,7 @@ def _handle_validation_error( def _handle_tool_error( e: ToolException, *, - flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]], + flag: Literal[True] | str | Callable[[ToolException], str] | None, ) -> str: """Handle tool execution errors based on the configured flag. @@ -1059,10 +1051,10 @@ def _handle_tool_error( def _prep_run_args( - value: Union[str, dict, ToolCall], - config: Optional[RunnableConfig], + value: str | dict | ToolCall, + config: RunnableConfig | None, **kwargs: Any, -) -> tuple[Union[str, dict], dict]: +) -> tuple[str | dict, dict]: """Prepare arguments for tool execution. Args: @@ -1075,11 +1067,11 @@ def _prep_run_args( """ config = ensure_config(config) if _is_tool_call(value): - tool_call_id: Optional[str] = cast("ToolCall", value)["id"] - tool_input: Union[str, dict] = cast("ToolCall", value)["args"].copy() + tool_call_id: str | None = cast("ToolCall", value)["id"] + tool_input: str | dict = cast("ToolCall", value)["args"].copy() else: tool_call_id = None - tool_input = cast("Union[str, dict]", value) + tool_input = cast("str | dict", value) return ( tool_input, dict( @@ -1098,10 +1090,10 @@ def _prep_run_args( def _format_output( content: Any, artifact: Any, - tool_call_id: Optional[str], + tool_call_id: str | None, name: str, status: str, -) -> Union[ToolOutputMixin, Any]: +) -> ToolOutputMixin | Any: """Format tool output as a ToolMessage if appropriate. Args: @@ -1176,7 +1168,7 @@ def _stringify(content: Any) -> str: return str(content) -def _get_type_hints(func: Callable) -> Optional[dict[str, type]]: +def _get_type_hints(func: Callable) -> dict[str, type] | None: """Get type hints from a function, handling partial functions. Args: @@ -1193,7 +1185,7 @@ def _get_type_hints(func: Callable) -> Optional[dict[str, type]]: return None -def _get_runnable_config_param(func: Callable) -> Optional[str]: +def _get_runnable_config_param(func: Callable) -> str | None: """Find the parameter name for RunnableConfig in a function. Args: @@ -1247,7 +1239,7 @@ class InjectedToolCallId(InjectedToolArg): def _is_injected_arg_type( - type_: Union[type, TypeVar], injected_type: Optional[type[InjectedToolArg]] = None + type_: type | TypeVar, injected_type: type[InjectedToolArg] | None = None ) -> bool: """Check if a type annotation indicates an injected argument. @@ -1267,8 +1259,8 @@ def _is_injected_arg_type( def get_all_basemodel_annotations( - cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True -) -> dict[str, Union[type, TypeVar]]: + cls: TypeBaseModel | Any, *, default_to_bound: bool = True +) -> dict[str, type | TypeVar]: """Get all annotations from a Pydantic BaseModel and its parents. Args: @@ -1283,7 +1275,7 @@ def get_all_basemodel_annotations( fields = get_fields(cls) alias_map = {field.alias: name for name, field in fields.items() if field.alias} - annotations: dict[str, Union[type, TypeVar]] = {} + annotations: dict[str, type | TypeVar] = {} for name, param in inspect.signature(cls).parameters.items(): # Exclude hidden init args added by pydantic Config. For example if # BaseModel(extra="allow") then "extra_data" will part of init sig. @@ -1324,7 +1316,7 @@ def get_all_basemodel_annotations( # generic_type_vars = (type vars in Baz) # generic_map = {type var in Baz: str} generic_type_vars: tuple = getattr(parent_origin, "__parameters__", ()) - generic_map = dict(zip(generic_type_vars, get_args(parent))) + generic_map = dict(zip(generic_type_vars, get_args(parent), strict=False)) for field in getattr(parent_origin, "__annotations__", {}): annotations[field] = _replace_type_vars( annotations[field], generic_map, default_to_bound=default_to_bound @@ -1337,11 +1329,11 @@ def get_all_basemodel_annotations( def _replace_type_vars( - type_: Union[type, TypeVar], - generic_map: Optional[dict[TypeVar, type]] = None, + type_: type | TypeVar, + generic_map: dict[TypeVar, type] | None = None, *, default_to_bound: bool = True, -) -> Union[type, TypeVar]: +) -> type | TypeVar: """Replace TypeVars in a type annotation with concrete types. Args: diff --git a/libs/core/langchain_core/tools/convert.py b/libs/core/langchain_core/tools/convert.py index ff4f03572f6..53c43acd568 100644 --- a/libs/core/langchain_core/tools/convert.py +++ b/libs/core/langchain_core/tools/convert.py @@ -1,7 +1,8 @@ """Convert functions and runnables to tools.""" import inspect -from typing import Any, Callable, Literal, Optional, Union, get_type_hints, overload +from collections.abc import Callable +from typing import Any, Literal, get_type_hints, overload from pydantic import BaseModel, Field, create_model @@ -15,14 +16,14 @@ from langchain_core.tools.structured import StructuredTool @overload def tool( *, - description: Optional[str] = None, + description: str | None = None, return_direct: bool = False, - args_schema: Optional[ArgsSchema] = None, + args_schema: ArgsSchema | None = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, -) -> Callable[[Union[Callable, Runnable]], BaseTool]: ... +) -> Callable[[Callable | Runnable], BaseTool]: ... @overload @@ -30,9 +31,9 @@ def tool( name_or_callable: str, runnable: Runnable, *, - description: Optional[str] = None, + description: str | None = None, return_direct: bool = False, - args_schema: Optional[ArgsSchema] = None, + args_schema: ArgsSchema | None = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, @@ -44,9 +45,9 @@ def tool( def tool( name_or_callable: Callable, *, - description: Optional[str] = None, + description: str | None = None, return_direct: bool = False, - args_schema: Optional[ArgsSchema] = None, + args_schema: ArgsSchema | None = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, @@ -58,31 +59,28 @@ def tool( def tool( name_or_callable: str, *, - description: Optional[str] = None, + description: str | None = None, return_direct: bool = False, - args_schema: Optional[ArgsSchema] = None, + args_schema: ArgsSchema | None = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, -) -> Callable[[Union[Callable, Runnable]], BaseTool]: ... +) -> Callable[[Callable | Runnable], BaseTool]: ... def tool( - name_or_callable: Optional[Union[str, Callable]] = None, - runnable: Optional[Runnable] = None, + name_or_callable: str | Callable | None = None, + runnable: Runnable | None = None, *args: Any, - description: Optional[str] = None, + description: str | None = None, return_direct: bool = False, - args_schema: Optional[ArgsSchema] = None, + args_schema: ArgsSchema | None = None, infer_schema: bool = True, response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, -) -> Union[ - BaseTool, - Callable[[Union[Callable, Runnable]], BaseTool], -]: +) -> BaseTool | Callable[[Callable | Runnable], BaseTool]: """Make tools out of functions, can be used with or without arguments. Args: @@ -231,7 +229,7 @@ def tool( def _create_tool_factory( tool_name: str, - ) -> Callable[[Union[Callable, Runnable]], BaseTool]: + ) -> Callable[[Callable | Runnable], BaseTool]: """Create a decorator that takes a callable and returns a tool. Args: @@ -241,7 +239,7 @@ def tool( A function that takes a callable or Runnable and returns a tool. """ - def _tool_factory(dec_func: Union[Callable, Runnable]) -> BaseTool: + def _tool_factory(dec_func: Callable | Runnable) -> BaseTool: tool_description = description if isinstance(dec_func, Runnable): runnable = dec_func @@ -251,18 +249,18 @@ def tool( raise ValueError(msg) async def ainvoke_wrapper( - callbacks: Optional[Callbacks] = None, **kwargs: Any + callbacks: Callbacks | None = None, **kwargs: Any ) -> Any: return await runnable.ainvoke(kwargs, {"callbacks": callbacks}) def invoke_wrapper( - callbacks: Optional[Callbacks] = None, **kwargs: Any + callbacks: Callbacks | None = None, **kwargs: Any ) -> Any: return runnable.invoke(kwargs, {"callbacks": callbacks}) coroutine = ainvoke_wrapper func = invoke_wrapper - schema: Optional[ArgsSchema] = runnable.input_schema + schema: ArgsSchema | None = runnable.input_schema tool_description = description or repr(runnable) elif inspect.iscoroutinefunction(dec_func): coroutine = dec_func @@ -352,7 +350,7 @@ def tool( # @tool(parse_docstring=True) # def my_tool(): # pass - def _partial(func: Union[Callable, Runnable]) -> BaseTool: + def _partial(func: Callable | Runnable) -> BaseTool: """Partial function that takes a callable and returns a tool.""" name_ = func.get_name() if isinstance(func, Runnable) else func.__name__ tool_factory = _create_tool_factory(name_) @@ -370,7 +368,7 @@ def _get_description_from_runnable(runnable: Runnable) -> str: def _get_schema_from_runnable_and_arg_types( runnable: Runnable, name: str, - arg_types: Optional[dict[str, type]] = None, + arg_types: dict[str, type] | None = None, ) -> type[BaseModel]: """Infer args_schema for tool.""" if arg_types is None: @@ -389,11 +387,11 @@ def _get_schema_from_runnable_and_arg_types( def convert_runnable_to_tool( runnable: Runnable, - args_schema: Optional[type[BaseModel]] = None, + args_schema: type[BaseModel] | None = None, *, - name: Optional[str] = None, - description: Optional[str] = None, - arg_types: Optional[dict[str, type]] = None, + name: str | None = None, + description: str | None = None, + arg_types: dict[str, type] | None = None, ) -> BaseTool: """Convert a Runnable into a BaseTool. @@ -421,12 +419,10 @@ def convert_runnable_to_tool( description=description, ) - async def ainvoke_wrapper( - callbacks: Optional[Callbacks] = None, **kwargs: Any - ) -> Any: + async def ainvoke_wrapper(callbacks: Callbacks | None = None, **kwargs: Any) -> Any: return await runnable.ainvoke(kwargs, config={"callbacks": callbacks}) - def invoke_wrapper(callbacks: Optional[Callbacks] = None, **kwargs: Any) -> Any: + def invoke_wrapper(callbacks: Callbacks | None = None, **kwargs: Any) -> Any: return runnable.invoke(kwargs, config={"callbacks": callbacks}) if ( diff --git a/libs/core/langchain_core/tools/render.py b/libs/core/langchain_core/tools/render.py index eeb6f3d5c83..576dad93530 100644 --- a/libs/core/langchain_core/tools/render.py +++ b/libs/core/langchain_core/tools/render.py @@ -2,8 +2,8 @@ from __future__ import annotations +from collections.abc import Callable from inspect import signature -from typing import Callable from langchain_core.tools.base import BaseTool diff --git a/libs/core/langchain_core/tools/retriever.py b/libs/core/langchain_core/tools/retriever.py index 002fa5e80d6..e5ff95049fe 100644 --- a/libs/core/langchain_core/tools/retriever.py +++ b/libs/core/langchain_core/tools/retriever.py @@ -3,7 +3,7 @@ from __future__ import annotations from functools import partial -from typing import TYPE_CHECKING, Literal, Optional, Union +from typing import TYPE_CHECKING, Literal from pydantic import BaseModel, Field @@ -34,7 +34,7 @@ def _get_relevant_documents( document_separator: str, callbacks: Callbacks = None, response_format: Literal["content", "content_and_artifact"] = "content", -) -> Union[str, tuple[str, list[Document]]]: +) -> str | tuple[str, list[Document]]: docs = retriever.invoke(query, config={"callbacks": callbacks}) content = document_separator.join( format_document(doc, document_prompt) for doc in docs @@ -52,7 +52,7 @@ async def _aget_relevant_documents( document_separator: str, callbacks: Callbacks = None, response_format: Literal["content", "content_and_artifact"] = "content", -) -> Union[str, tuple[str, list[Document]]]: +) -> str | tuple[str, list[Document]]: docs = await retriever.ainvoke(query, config={"callbacks": callbacks}) content = document_separator.join( [await aformat_document(doc, document_prompt) for doc in docs] @@ -69,7 +69,7 @@ def create_retriever_tool( name: str, description: str, *, - document_prompt: Optional[BasePromptTemplate] = None, + document_prompt: BasePromptTemplate | None = None, document_separator: str = "\n\n", response_format: Literal["content", "content_and_artifact"] = "content", ) -> Tool: diff --git a/libs/core/langchain_core/tools/simple.py b/libs/core/langchain_core/tools/simple.py index b115150e1b7..a1524b099fd 100644 --- a/libs/core/langchain_core/tools/simple.py +++ b/libs/core/langchain_core/tools/simple.py @@ -2,14 +2,11 @@ from __future__ import annotations -from collections.abc import Awaitable +from collections.abc import Awaitable, Callable from inspect import signature from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, - Union, ) from typing_extensions import override @@ -34,9 +31,9 @@ class Tool(BaseTool): """Tool that takes in function or coroutine directly.""" description: str = "" - func: Optional[Callable[..., str]] + func: Callable[..., str] | None """The function to run when the tool is called.""" - coroutine: Optional[Callable[..., Awaitable[str]]] = None + coroutine: Callable[..., Awaitable[str]] | None = None """The asynchronous version of the function.""" # --- Runnable --- @@ -44,8 +41,8 @@ class Tool(BaseTool): @override async def ainvoke( self, - input: Union[str, dict, ToolCall], - config: Optional[RunnableConfig] = None, + input: str | dict | ToolCall, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Any: if not self.coroutine: @@ -70,7 +67,7 @@ class Tool(BaseTool): return {"tool_input": {"type": "string"}} def _to_args_and_kwargs( - self, tool_input: Union[str, dict], tool_call_id: Optional[str] + self, tool_input: str | dict, tool_call_id: str | None ) -> tuple[tuple, dict]: """Convert tool input to pydantic model. @@ -101,7 +98,7 @@ class Tool(BaseTool): self, *args: Any, config: RunnableConfig, - run_manager: Optional[CallbackManagerForToolRun] = None, + run_manager: CallbackManagerForToolRun | None = None, **kwargs: Any, ) -> Any: """Use the tool. @@ -128,7 +125,7 @@ class Tool(BaseTool): self, *args: Any, config: RunnableConfig, - run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + run_manager: AsyncCallbackManagerForToolRun | None = None, **kwargs: Any, ) -> Any: """Use the tool asynchronously. @@ -157,7 +154,7 @@ class Tool(BaseTool): # TODO: this is for backwards compatibility, remove in future def __init__( - self, name: str, func: Optional[Callable], description: str, **kwargs: Any + self, name: str, func: Callable | None, description: str, **kwargs: Any ) -> None: """Initialize tool.""" super().__init__(name=name, func=func, description=description, **kwargs) @@ -165,14 +162,13 @@ class Tool(BaseTool): @classmethod def from_function( cls, - func: Optional[Callable], + func: Callable | None, name: str, # We keep these required to support backwards compatibility description: str, return_direct: bool = False, # noqa: FBT001,FBT002 - args_schema: Optional[ArgsSchema] = None, - coroutine: Optional[ - Callable[..., Awaitable[Any]] - ] = None, # This is last for compatibility, but should be after func + args_schema: ArgsSchema | None = None, + coroutine: Callable[..., Awaitable[Any]] + | None = None, # This is last for compatibility, but should be after func **kwargs: Any, ) -> Tool: """Initialize tool from a function. diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index e512c3742f3..4beeb9f68ea 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -3,16 +3,13 @@ from __future__ import annotations import textwrap -from collections.abc import Awaitable +from collections.abc import Awaitable, Callable from inspect import signature from typing import ( TYPE_CHECKING, Annotated, Any, - Callable, Literal, - Optional, - Union, ) from pydantic import Field, SkipValidation @@ -44,9 +41,9 @@ class StructuredTool(BaseTool): ..., description="The tool schema." ) """The input arguments' schema.""" - func: Optional[Callable[..., Any]] = None + func: Callable[..., Any] | None = None """The function to run when the tool is called.""" - coroutine: Optional[Callable[..., Awaitable[Any]]] = None + coroutine: Callable[..., Awaitable[Any]] | None = None """The asynchronous version of the function.""" # --- Runnable --- @@ -55,8 +52,8 @@ class StructuredTool(BaseTool): @override async def ainvoke( self, - input: Union[str, dict, ToolCall], - config: Optional[RunnableConfig] = None, + input: str | dict | ToolCall, + config: RunnableConfig | None = None, **kwargs: Any, ) -> Any: if not self.coroutine: @@ -71,7 +68,7 @@ class StructuredTool(BaseTool): self, *args: Any, config: RunnableConfig, - run_manager: Optional[CallbackManagerForToolRun] = None, + run_manager: CallbackManagerForToolRun | None = None, **kwargs: Any, ) -> Any: """Use the tool. @@ -98,7 +95,7 @@ class StructuredTool(BaseTool): self, *args: Any, config: RunnableConfig, - run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + run_manager: AsyncCallbackManagerForToolRun | None = None, **kwargs: Any, ) -> Any: """Use the tool asynchronously. @@ -128,12 +125,12 @@ class StructuredTool(BaseTool): @classmethod def from_function( cls, - func: Optional[Callable] = None, - coroutine: Optional[Callable[..., Awaitable[Any]]] = None, - name: Optional[str] = None, - description: Optional[str] = None, + func: Callable | None = None, + coroutine: Callable[..., Awaitable[Any]] | None = None, + name: str | None = None, + description: str | None = None, return_direct: bool = False, # noqa: FBT001,FBT002 - args_schema: Optional[ArgsSchema] = None, + args_schema: ArgsSchema | None = None, infer_schema: bool = True, # noqa: FBT001,FBT002 *, response_format: Literal["content", "content_and_artifact"] = "content", diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index e33f0339af7..bf8188f75c7 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -8,8 +8,6 @@ from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, - Optional, - Union, ) from typing_extensions import override @@ -57,10 +55,10 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): messages: list[list[BaseMessage]], *, run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, **kwargs: Any, ) -> Run: """Start a trace for an LLM run. @@ -98,10 +96,10 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): prompts: list[str], *, run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, **kwargs: Any, ) -> Run: """Start a trace for an LLM run. @@ -138,9 +136,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Run: """Run on new LLM token. Only available when streaming is enabled. @@ -244,11 +242,11 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): inputs: dict[str, Any], *, run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - run_type: Optional[str] = None, - name: Optional[str] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + run_type: str | None = None, + name: str | None = None, **kwargs: Any, ) -> Run: """Start a trace for a chain run. @@ -288,7 +286,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): outputs: dict[str, Any], *, run_id: UUID, - inputs: Optional[dict[str, Any]] = None, + inputs: dict[str, Any] | None = None, **kwargs: Any, ) -> Run: """End a trace for a chain run. @@ -316,7 +314,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): self, error: BaseException, *, - inputs: Optional[dict[str, Any]] = None, + inputs: dict[str, Any] | None = None, run_id: UUID, **kwargs: Any, ) -> Run: @@ -346,11 +344,11 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): input_str: str, *, run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, - inputs: Optional[dict[str, Any]] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, + inputs: dict[str, Any] | None = None, **kwargs: Any, ) -> Run: """Start a trace for a tool run. @@ -436,10 +434,10 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): query: str, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, **kwargs: Any, ) -> Run: """Run when the Retriever starts running. @@ -565,10 +563,10 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): messages: list[list[BaseMessage]], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, **kwargs: Any, ) -> Any: chat_model_run = self._create_chat_model_run( @@ -595,9 +593,9 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): prompts: list[str], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> None: llm_run = self._create_llm_run( @@ -617,9 +615,9 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> None: llm_run = self._llm_run_with_token_event( @@ -649,8 +647,8 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): response: LLMResult, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: llm_run = self._complete_llm_run( @@ -666,8 +664,8 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: llm_run = self._errored_llm_run( @@ -684,11 +682,11 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): inputs: dict[str, Any], *, run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - run_type: Optional[str] = None, - name: Optional[str] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + run_type: str | None = None, + name: str | None = None, **kwargs: Any, ) -> None: chain_run = self._create_chain_run( @@ -711,7 +709,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): outputs: dict[str, Any], *, run_id: UUID, - inputs: Optional[dict[str, Any]] = None, + inputs: dict[str, Any] | None = None, **kwargs: Any, ) -> None: chain_run = self._complete_chain_run( @@ -727,7 +725,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): self, error: BaseException, *, - inputs: Optional[dict[str, Any]] = None, + inputs: dict[str, Any] | None = None, run_id: UUID, **kwargs: Any, ) -> None: @@ -746,11 +744,11 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): input_str: str, *, run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, - inputs: Optional[dict[str, Any]] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, + inputs: dict[str, Any] | None = None, **kwargs: Any, ) -> None: tool_run = self._create_tool_run( @@ -787,8 +785,8 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: tool_run = self._errored_tool_run( @@ -805,10 +803,10 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): query: str, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, **kwargs: Any, ) -> None: retriever_run = self._create_retrieval_run( @@ -832,8 +830,8 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: retrieval_run = self._errored_retrieval_run( @@ -852,8 +850,8 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): documents: Sequence[Document], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: retrieval_run = self._complete_retrieval_run( @@ -882,7 +880,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): self, run: Run, token: str, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], + chunk: GenerationChunk | ChatGenerationChunk | None, ) -> None: """Process new LLM token.""" diff --git a/libs/core/langchain_core/tracers/context.py b/libs/core/langchain_core/tracers/context.py index 5c68ffef5aa..c630026c3a9 100644 --- a/libs/core/langchain_core/tracers/context.py +++ b/libs/core/langchain_core/tracers/context.py @@ -8,8 +8,6 @@ from typing import ( TYPE_CHECKING, Any, Literal, - Optional, - Union, cast, ) from uuid import UUID @@ -30,21 +28,21 @@ if TYPE_CHECKING: # for backwards partial compatibility if this is imported by users but unused tracing_callback_var: Any = None -tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( +tracing_v2_callback_var: ContextVar[LangChainTracer | None] = ContextVar( "tracing_callback_v2", default=None ) -run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVar( +run_collector_var: ContextVar[RunCollectorCallbackHandler | None] = ContextVar( "run_collector", default=None ) @contextmanager def tracing_v2_enabled( - project_name: Optional[str] = None, + project_name: str | None = None, *, - example_id: Optional[Union[str, UUID]] = None, - tags: Optional[list[str]] = None, - client: Optional[LangSmithClient] = None, + example_id: str | UUID | None = None, + tags: list[str] | None = None, + client: LangSmithClient | None = None, ) -> Generator[LangChainTracer, None, None]: """Instruct LangChain to log all runs in context to LangSmith. @@ -107,9 +105,9 @@ def collect_runs() -> Generator[RunCollectorCallbackHandler, None, None]: def _get_trace_callbacks( - project_name: Optional[str] = None, - example_id: Optional[Union[str, UUID]] = None, - callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None, + project_name: str | None = None, + example_id: str | UUID | None = None, + callback_manager: CallbackManager | AsyncCallbackManager | None = None, ) -> Callbacks: if _tracing_v2_is_enabled(): project_name_ = project_name or _get_tracer_project() @@ -133,7 +131,7 @@ def _get_trace_callbacks( return cb -def _tracing_v2_is_enabled() -> Union[bool, Literal["local"]]: +def _tracing_v2_is_enabled() -> bool | Literal["local"]: if tracing_v2_callback_var.get() is not None: return True return ls_utils.tracing_is_enabled() @@ -164,19 +162,19 @@ def _get_tracer_project() -> str: _configure_hooks: list[ tuple[ - ContextVar[Optional[BaseCallbackHandler]], + ContextVar[BaseCallbackHandler | None], bool, - Optional[type[BaseCallbackHandler]], - Optional[str], + type[BaseCallbackHandler] | None, + str | None, ] ] = [] def register_configure_hook( - context_var: ContextVar[Optional[Any]], + context_var: ContextVar[Any | None], inheritable: bool, # noqa: FBT001 - handle_class: Optional[type[BaseCallbackHandler]] = None, - env_var: Optional[str] = None, + handle_class: type[BaseCallbackHandler] | None = None, + env_var: str | None = None, ) -> None: """Register a configure hook. @@ -199,7 +197,7 @@ def register_configure_hook( ( # the typings of ContextVar do not have the generic arg set as covariant # so we have to cast it - cast("ContextVar[Optional[BaseCallbackHandler]]", context_var), + cast("ContextVar[BaseCallbackHandler | None]", context_var), inheritable, handle_class, env_var, diff --git a/libs/core/langchain_core/tracers/core.py b/libs/core/langchain_core/tracers/core.py index 981a58d6ee3..e027d78a324 100644 --- a/libs/core/langchain_core/tracers/core.py +++ b/libs/core/langchain_core/tracers/core.py @@ -10,8 +10,6 @@ from typing import ( TYPE_CHECKING, Any, Literal, - Optional, - Union, cast, ) @@ -81,7 +79,7 @@ class _TracerCore(ABC): """Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed.""" @abstractmethod - def _persist_run(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: + def _persist_run(self, run: Run) -> Coroutine[Any, Any, None] | None: """Persist a run.""" @staticmethod @@ -102,7 +100,7 @@ class _TracerCore(ABC): except: # noqa: E722 return msg - def _start_trace(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # type: ignore[return] + def _start_trace(self, run: Run) -> Coroutine[Any, Any, None] | None: # type: ignore[return] current_dotted_order = run.start_time.strftime("%Y%m%dT%H%M%S%fZ") + str(run.id) if run.parent_run_id: if parent := self.order_map.get(run.parent_run_id): @@ -126,9 +124,7 @@ class _TracerCore(ABC): self.order_map[run.id] = (run.trace_id, run.dotted_order) self.run_map[str(run.id)] = run - def _get_run( - self, run_id: UUID, run_type: Union[str, set[str], None] = None - ) -> Run: + def _get_run(self, run_id: UUID, run_type: str | set[str] | None = None) -> Run: try: run = self.run_map[str(run_id)] except KeyError as exc: @@ -136,7 +132,7 @@ class _TracerCore(ABC): raise TracerException(msg) from exc if isinstance(run_type, str): - run_types: Union[set[str], None] = {run_type} + run_types: set[str] | None = {run_type} else: run_types = run_type if run_types is not None and run.run_type not in run_types: @@ -152,10 +148,10 @@ class _TracerCore(ABC): serialized: dict[str, Any], messages: list[list[BaseMessage]], run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, **kwargs: Any, ) -> Run: """Create a chat model run.""" @@ -196,10 +192,10 @@ class _TracerCore(ABC): serialized: dict[str, Any], prompts: list[str], run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, **kwargs: Any, ) -> Run: """Create a llm run.""" @@ -224,8 +220,8 @@ class _TracerCore(ABC): self, token: str, run_id: UUID, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - parent_run_id: Optional[UUID] = None, # noqa: ARG002 + chunk: GenerationChunk | ChatGenerationChunk | None = None, + parent_run_id: UUID | None = None, # noqa: ARG002 ) -> Run: """Append token event to LLM run and return the run.""" llm_run = self._get_run(run_id, run_type={"llm", "chat_model"}) @@ -291,7 +287,7 @@ class _TracerCore(ABC): return llm_run def _errored_llm_run( - self, error: BaseException, run_id: UUID, response: Optional[LLMResult] = None + self, error: BaseException, run_id: UUID, response: LLMResult | None = None ) -> Run: llm_run = self._get_run(run_id, run_type={"llm", "chat_model"}) llm_run.error = self._get_stacktrace(error) @@ -319,11 +315,11 @@ class _TracerCore(ABC): serialized: dict[str, Any], inputs: dict[str, Any], run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - run_type: Optional[str] = None, - name: Optional[str] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + run_type: str | None = None, + name: str | None = None, **kwargs: Any, ) -> Run: """Create a chain Run.""" @@ -370,7 +366,7 @@ class _TracerCore(ABC): self, outputs: dict[str, Any], run_id: UUID, - inputs: Optional[dict[str, Any]] = None, + inputs: dict[str, Any] | None = None, ) -> Run: """Update a chain run with outputs and end time.""" chain_run = self._get_run(run_id) @@ -389,7 +385,7 @@ class _TracerCore(ABC): def _errored_chain_run( self, error: BaseException, - inputs: Optional[dict[str, Any]], + inputs: dict[str, Any] | None, run_id: UUID, ) -> Run: chain_run = self._get_run(run_id) @@ -405,11 +401,11 @@ class _TracerCore(ABC): serialized: dict[str, Any], input_str: str, run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, - inputs: Optional[dict[str, Any]] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, + inputs: dict[str, Any] | None = None, **kwargs: Any, ) -> Run: """Create a tool run.""" @@ -472,10 +468,10 @@ class _TracerCore(ABC): serialized: dict[str, Any], query: str, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, **kwargs: Any, ) -> Run: """Create a retrieval run.""" @@ -532,7 +528,7 @@ class _TracerCore(ABC): """Return self copied.""" return self - def _end_trace(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _end_trace(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """End a trace for a run. Args: @@ -540,7 +536,7 @@ class _TracerCore(ABC): """ return None - def _on_run_create(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_run_create(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process a run upon creation. Args: @@ -548,7 +544,7 @@ class _TracerCore(ABC): """ return None - def _on_run_update(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_run_update(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process a run upon update. Args: @@ -556,7 +552,7 @@ class _TracerCore(ABC): """ return None - def _on_llm_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_llm_start(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the LLM Run upon start. Args: @@ -568,8 +564,8 @@ class _TracerCore(ABC): self, run: Run, # noqa: ARG002 token: str, # noqa: ARG002 - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], # noqa: ARG002 - ) -> Union[Coroutine[Any, Any, None], None]: + chunk: GenerationChunk | ChatGenerationChunk | None, # noqa: ARG002 + ) -> Coroutine[Any, Any, None] | None: """Process new LLM token. Args: @@ -579,7 +575,7 @@ class _TracerCore(ABC): """ return None - def _on_llm_end(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_llm_end(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the LLM Run. Args: @@ -587,7 +583,7 @@ class _TracerCore(ABC): """ return None - def _on_llm_error(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_llm_error(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the LLM Run upon error. Args: @@ -595,7 +591,7 @@ class _TracerCore(ABC): """ return None - def _on_chain_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_chain_start(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the Chain Run upon start. Args: @@ -603,7 +599,7 @@ class _TracerCore(ABC): """ return None - def _on_chain_end(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_chain_end(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the Chain Run. Args: @@ -611,7 +607,7 @@ class _TracerCore(ABC): """ return None - def _on_chain_error(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_chain_error(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the Chain Run upon error. Args: @@ -619,7 +615,7 @@ class _TracerCore(ABC): """ return None - def _on_tool_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_tool_start(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the Tool Run upon start. Args: @@ -627,7 +623,7 @@ class _TracerCore(ABC): """ return None - def _on_tool_end(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_tool_end(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the Tool Run. Args: @@ -635,7 +631,7 @@ class _TracerCore(ABC): """ return None - def _on_tool_error(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_tool_error(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the Tool Run upon error. Args: @@ -643,7 +639,7 @@ class _TracerCore(ABC): """ return None - def _on_chat_model_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_chat_model_start(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the Chat Model Run upon start. Args: @@ -651,7 +647,7 @@ class _TracerCore(ABC): """ return None - def _on_retriever_start(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_retriever_start(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the Retriever Run upon start. Args: @@ -659,7 +655,7 @@ class _TracerCore(ABC): """ return None - def _on_retriever_end(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_retriever_end(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the Retriever Run. Args: @@ -667,7 +663,7 @@ class _TracerCore(ABC): """ return None - def _on_retriever_error(self, run: Run) -> Union[Coroutine[Any, Any, None], None]: # noqa: ARG002 + def _on_retriever_error(self, run: Run) -> Coroutine[Any, Any, None] | None: # noqa: ARG002 """Process the Retriever Run upon error. Args: diff --git a/libs/core/langchain_core/tracers/evaluation.py b/libs/core/langchain_core/tracers/evaluation.py index a1e9ffb4785..7744c229caf 100644 --- a/libs/core/langchain_core/tracers/evaluation.py +++ b/libs/core/langchain_core/tracers/evaluation.py @@ -6,7 +6,7 @@ import logging import threading import weakref from concurrent.futures import Future, ThreadPoolExecutor, wait -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, cast from uuid import UUID import langsmith @@ -43,19 +43,19 @@ class EvaluatorCallbackHandler(BaseTracer): """ name: str = "evaluator_callback_handler" - example_id: Optional[UUID] = None + example_id: UUID | None = None """The example ID associated with the runs.""" client: langsmith.Client """The LangSmith client instance used for evaluating the runs.""" evaluators: Sequence[langsmith.RunEvaluator] = () """The sequence of run evaluators to be executed.""" - executor: Optional[ThreadPoolExecutor] = None + executor: ThreadPoolExecutor | None = None """The thread pool executor used for running the evaluators.""" futures: weakref.WeakSet[Future] = weakref.WeakSet() """The set of futures representing the running evaluators.""" skip_unfinished: bool = True """Whether to skip runs that are not finished or raised an error.""" - project_name: Optional[str] = None + project_name: str | None = None """The LangSmith project name to be organize eval chain runs under.""" logged_eval_results: dict[tuple[str, str], list[EvaluationResult]] lock: threading.Lock @@ -63,11 +63,11 @@ class EvaluatorCallbackHandler(BaseTracer): def __init__( self, evaluators: Sequence[langsmith.RunEvaluator], - client: Optional[langsmith.Client] = None, - example_id: Optional[Union[UUID, str]] = None, + client: langsmith.Client | None = None, + example_id: UUID | str | None = None, skip_unfinished: bool = True, # noqa: FBT001,FBT002 - project_name: Optional[str] = "evaluators", - max_concurrency: Optional[int] = None, + project_name: str | None = "evaluators", + max_concurrency: int | None = None, **kwargs: Any, ) -> None: """Create an EvaluatorCallbackHandler. @@ -156,7 +156,7 @@ class EvaluatorCallbackHandler(BaseTracer): def _select_eval_results( self, - results: Union[EvaluationResult, EvaluationResults], + results: EvaluationResult | EvaluationResults, ) -> list[EvaluationResult]: if isinstance(results, EvaluationResult): results_ = [results] @@ -172,9 +172,9 @@ class EvaluatorCallbackHandler(BaseTracer): def _log_evaluation_feedback( self, - evaluator_response: Union[EvaluationResult, EvaluationResults], + evaluator_response: EvaluationResult | EvaluationResults, run: Run, - source_run_id: Optional[UUID] = None, + source_run_id: UUID | None = None, ) -> list[EvaluationResult]: results = self._select_eval_results(evaluator_response) for res in results: diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index 80019dc7ef6..064dd8f6663 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -8,10 +8,8 @@ import logging from typing import ( TYPE_CHECKING, Any, - Optional, TypedDict, TypeVar, - Union, cast, ) from uuid import UUID, uuid4 @@ -72,11 +70,11 @@ class RunInfo(TypedDict): """The type of the run.""" inputs: NotRequired[Any] """The inputs to the run.""" - parent_run_id: Optional[UUID] + parent_run_id: UUID | None """The ID of the parent run.""" -def _assign_name(name: Optional[str], serialized: Optional[dict[str, Any]]) -> str: +def _assign_name(name: str | None, serialized: dict[str, Any] | None) -> str: """Assign a name to a run.""" if name is not None: return name @@ -97,12 +95,12 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand def __init__( self, *args: Any, - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, + include_names: Sequence[str] | None = None, + include_types: Sequence[str] | None = None, + include_tags: Sequence[str] | None = None, + exclude_names: Sequence[str] | None = None, + exclude_types: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, **kwargs: Any, ) -> None: """Initialize the tracer.""" @@ -116,7 +114,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand # of a child run, which results in clean up of run_map. # So we keep track of the mapping between children and parent run IDs # in a separate container. This container is GCed when the tracer is GCed. - self.parent_map: dict[UUID, Optional[UUID]] = {} + self.parent_map: dict[UUID, UUID | None] = {} self.is_tapped: dict[UUID, Any] = {} @@ -277,9 +275,9 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand self, run_id: UUID, *, - tags: Optional[list[str]], - metadata: Optional[dict[str, Any]], - parent_run_id: Optional[UUID], + tags: list[str] | None, + metadata: dict[str, Any] | None, + parent_run_id: UUID | None, name_: str, run_type: str, **kwargs: Any, @@ -309,10 +307,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand messages: list[list[BaseMessage]], *, run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, **kwargs: Any, ) -> None: """Start a trace for a chat model run.""" @@ -351,10 +349,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand prompts: list[str], *, run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, **kwargs: Any, ) -> None: """Start a trace for a (non-chat model) LLM run.""" @@ -395,8 +393,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand data: Any, *, run_id: UUID, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Generate a custom astream event.""" @@ -416,9 +414,9 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> None: """Run on new output token. Only available when streaming is enabled. @@ -426,7 +424,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand For both chat models and non-chat models (legacy LLMs). """ run_info = self.run_map.get(run_id) - chunk_: Union[GenerationChunk, BaseMessageChunk] + chunk_: GenerationChunk | BaseMessageChunk if run_info is None: msg = f"Run ID {run_id} not found in run map." @@ -480,8 +478,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand run_info = self.run_map.pop(run_id) inputs_ = run_info.get("inputs") - generations: Union[list[list[GenerationChunk]], list[list[ChatGenerationChunk]]] - output: Union[dict, BaseMessage] = {} + generations: list[list[GenerationChunk]] | list[list[ChatGenerationChunk]] + output: dict | BaseMessage = {} if run_info["run_type"] == "chat_model": generations = cast("list[list[ChatGenerationChunk]]", response.generations) @@ -533,11 +531,11 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand inputs: dict[str, Any], *, run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - run_type: Optional[str] = None, - name: Optional[str] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + run_type: str | None = None, + name: str | None = None, **kwargs: Any, ) -> None: """Start a trace for a chain run.""" @@ -581,7 +579,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand outputs: dict[str, Any], *, run_id: UUID, - inputs: Optional[dict[str, Any]] = None, + inputs: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """End a trace for a chain run.""" @@ -639,11 +637,11 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand input_str: str, *, run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, - inputs: Optional[dict[str, Any]] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, + inputs: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Start a trace for a tool run.""" @@ -680,8 +678,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run when tool errors.""" @@ -735,10 +733,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand query: str, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, **kwargs: Any, ) -> None: """Run when Retriever starts running.""" @@ -807,14 +805,14 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand async def _astream_events_implementation_v1( runnable: Runnable[Input, Output], value: Any, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, + include_names: Sequence[str] | None = None, + include_types: Sequence[str] | None = None, + include_tags: Sequence[str] | None = None, + exclude_names: Sequence[str] | None = None, + exclude_types: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, **kwargs: Any, ) -> AsyncIterator[StandardStreamEvent]: stream = LogStreamCallbackHandler( @@ -983,14 +981,14 @@ async def _astream_events_implementation_v1( async def _astream_events_implementation_v2( runnable: Runnable[Input, Output], value: Any, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, + include_names: Sequence[str] | None = None, + include_types: Sequence[str] | None = None, + include_tags: Sequence[str] | None = None, + exclude_names: Sequence[str] | None = None, + exclude_types: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, **kwargs: Any, ) -> AsyncIterator[StandardStreamEvent]: """Implementation of the astream events API for V2 runnables.""" diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py index 431c4564a85..d1bd4c89b1e 100644 --- a/libs/core/langchain_core/tracers/langchain.py +++ b/libs/core/langchain_core/tracers/langchain.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from uuid import UUID from langsmith import Client, get_tracing_context @@ -30,7 +30,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) _LOGGED = set() -_EXECUTOR: Optional[ThreadPoolExecutor] = None +_EXECUTOR: ThreadPoolExecutor | None = None def log_error_once(method: str, exception: Exception) -> None: @@ -76,10 +76,10 @@ class LangChainTracer(BaseTracer): def __init__( self, - example_id: Optional[Union[UUID, str]] = None, - project_name: Optional[str] = None, - client: Optional[Client] = None, - tags: Optional[list[str]] = None, + example_id: UUID | str | None = None, + project_name: str | None = None, + client: Client | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Initialize the LangChain tracer. @@ -98,7 +98,7 @@ class LangChainTracer(BaseTracer): self.project_name = project_name or ls_utils.get_tracer_project() self.client = client or get_client() self.tags = tags or [] - self.latest_run: Optional[Run] = None + self.latest_run: Run | None = None self.run_has_token_event_map: dict[str, bool] = {} def _start_trace(self, run: Run) -> None: @@ -122,10 +122,10 @@ class LangChainTracer(BaseTracer): messages: list[list[BaseMessage]], *, run_id: UUID, - tags: Optional[list[str]] = None, - parent_run_id: Optional[UUID] = None, - metadata: Optional[dict[str, Any]] = None, - name: Optional[str] = None, + tags: list[str] | None = None, + parent_run_id: UUID | None = None, + metadata: dict[str, Any] | None = None, + name: str | None = None, **kwargs: Any, ) -> Run: """Start a trace for an LLM run. @@ -242,8 +242,8 @@ class LangChainTracer(BaseTracer): self, token: str, run_id: UUID, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, - parent_run_id: Optional[UUID] = None, + chunk: GenerationChunk | ChatGenerationChunk | None = None, + parent_run_id: UUID | None = None, ) -> Run: run_id_str = str(run_id) if run_id_str not in self.run_has_token_event_map: diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index 923f0275683..8822bc92d91 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -12,9 +12,7 @@ from typing import ( TYPE_CHECKING, Any, Literal, - Optional, TypeVar, - Union, overload, ) @@ -59,13 +57,13 @@ class LogEntry(TypedDict): """List of LLM tokens streamed by this run, if applicable.""" streamed_output: list[Any] """List of output chunks streamed by this run, if available.""" - inputs: NotRequired[Optional[Any]] + inputs: NotRequired[Any | None] """Inputs to this run. Not available currently via astream_log.""" - final_output: Optional[Any] + final_output: Any | None """Final output of this run. Only available after the run has finished successfully.""" - end_time: Optional[str] + end_time: str | None """ISO-8601 timestamp of when the run ended. Only available after the run has finished.""" @@ -77,7 +75,7 @@ class RunState(TypedDict): """ID of the run.""" streamed_output: list[Any] """List of output chunks streamed by Runnable.stream()""" - final_output: Optional[Any] + final_output: Any | None """Final output of the run, usually the result of aggregating (`+`) streamed_output. Updated throughout the run when supported by the Runnable.""" @@ -112,7 +110,7 @@ class RunLogPatch: """ self.ops = list(ops) - def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog: + def __add__(self, other: RunLogPatch | Any) -> RunLog: """Combine two ``RunLogPatch`` instances. Args: @@ -160,7 +158,7 @@ class RunLog(RunLogPatch): super().__init__(*ops) self.state = state - def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog: + def __add__(self, other: RunLogPatch | Any) -> RunLog: """Combine two ``RunLog``s. Args: @@ -215,12 +213,12 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): self, *, auto_close: bool = True, - include_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[str]] = None, - include_tags: Optional[Sequence[str]] = None, - exclude_names: Optional[Sequence[str]] = None, - exclude_types: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, + include_names: Sequence[str] | None = None, + include_types: Sequence[str] | None = None, + include_tags: Sequence[str] | None = None, + exclude_names: Sequence[str] | None = None, + exclude_types: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, # Schema format is for internal use only. _schema_format: Literal["original", "streaming_events"] = "streaming_events", ) -> None: @@ -273,7 +271,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): self.receive_stream = memory_stream.get_receive_stream() self._key_map_by_run_id: dict[UUID, str] = {} self._counter_map_by_name: dict[str, int] = defaultdict(int) - self.root_id: Optional[UUID] = None + self.root_id: UUID | None = None def __aiter__(self) -> AsyncIterator[RunLogPatch]: """Iterate over the stream of run logs. @@ -515,7 +513,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): self, run: Run, token: str, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], + chunk: GenerationChunk | ChatGenerationChunk | None, ) -> None: """Process new LLM token.""" index = self._key_map_by_run_id.get(run.id) @@ -541,7 +539,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): def _get_standardized_inputs( run: Run, schema_format: Literal["original", "streaming_events"] -) -> Optional[dict[str, Any]]: +) -> dict[str, Any] | None: """Extract standardized inputs from a run. Standardizes the inputs based on the type of the runnable used. @@ -583,7 +581,7 @@ def _get_standardized_inputs( def _get_standardized_outputs( run: Run, schema_format: Literal["original", "streaming_events", "original+chat"] -) -> Optional[Any]: +) -> Any | None: """Extract standardized output from a run. Standardizes the outputs based on the type of the runnable used. @@ -617,7 +615,7 @@ def _get_standardized_outputs( def _astream_log_implementation( runnable: Runnable[Input, Output], value: Any, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, stream: LogStreamCallbackHandler, diff: Literal[True] = True, @@ -630,7 +628,7 @@ def _astream_log_implementation( def _astream_log_implementation( runnable: Runnable[Input, Output], value: Any, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, stream: LogStreamCallbackHandler, diff: Literal[False], @@ -642,13 +640,13 @@ def _astream_log_implementation( async def _astream_log_implementation( runnable: Runnable[Input, Output], value: Any, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, *, stream: LogStreamCallbackHandler, diff: bool = True, with_streamed_output_list: bool = True, **kwargs: Any, -) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]: +) -> AsyncIterator[RunLogPatch] | AsyncIterator[RunLog]: """Implementation of astream_log for a given runnable. The implementation has been factored out (at least temporarily) as both @@ -693,8 +691,8 @@ async def _astream_log_implementation( # add each chunk to the output stream async def consume_astream() -> None: try: - prev_final_output: Optional[Output] = None - final_output: Optional[Output] = None + prev_final_output: Output | None = None + final_output: Output | None = None async for chunk in runnable.astream(value, config, **kwargs): prev_final_output = final_output diff --git a/libs/core/langchain_core/tracers/root_listeners.py b/libs/core/langchain_core/tracers/root_listeners.py index 5b07553631a..043805c16d4 100644 --- a/libs/core/langchain_core/tracers/root_listeners.py +++ b/libs/core/langchain_core/tracers/root_listeners.py @@ -1,7 +1,7 @@ """Tracers that call listeners.""" -from collections.abc import Awaitable -from typing import TYPE_CHECKING, Callable, Optional, Union +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING from langchain_core.runnables.config import ( RunnableConfig, @@ -14,10 +14,10 @@ from langchain_core.tracers.schemas import Run if TYPE_CHECKING: from uuid import UUID -Listener = Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] -AsyncListener = Union[ - Callable[[Run], Awaitable[None]], Callable[[Run, RunnableConfig], Awaitable[None]] -] +Listener = Callable[[Run], None] | Callable[[Run, RunnableConfig], None] +AsyncListener = ( + Callable[[Run], Awaitable[None]] | Callable[[Run, RunnableConfig], Awaitable[None]] +) class RootListenersTracer(BaseTracer): @@ -30,9 +30,9 @@ class RootListenersTracer(BaseTracer): self, *, config: RunnableConfig, - on_start: Optional[Listener], - on_end: Optional[Listener], - on_error: Optional[Listener], + on_start: Listener | None, + on_end: Listener | None, + on_error: Listener | None, ) -> None: """Initialize the tracer. @@ -48,7 +48,7 @@ class RootListenersTracer(BaseTracer): self._arg_on_start = on_start self._arg_on_end = on_end self._arg_on_error = on_error - self.root_id: Optional[UUID] = None + self.root_id: UUID | None = None def _persist_run(self, run: Run) -> None: # This is a legacy method only called once for an entire run tree @@ -85,9 +85,9 @@ class AsyncRootListenersTracer(AsyncBaseTracer): self, *, config: RunnableConfig, - on_start: Optional[AsyncListener], - on_end: Optional[AsyncListener], - on_error: Optional[AsyncListener], + on_start: AsyncListener | None, + on_end: AsyncListener | None, + on_error: AsyncListener | None, ) -> None: """Initialize the tracer. @@ -103,7 +103,7 @@ class AsyncRootListenersTracer(AsyncBaseTracer): self._arg_on_start = on_start self._arg_on_end = on_end self._arg_on_error = on_error - self.root_id: Optional[UUID] = None + self.root_id: UUID | None = None async def _persist_run(self, run: Run) -> None: # This is a legacy method only called once for an entire run tree diff --git a/libs/core/langchain_core/tracers/run_collector.py b/libs/core/langchain_core/tracers/run_collector.py index afac1ad3267..8c53133642f 100644 --- a/libs/core/langchain_core/tracers/run_collector.py +++ b/libs/core/langchain_core/tracers/run_collector.py @@ -1,6 +1,6 @@ """A tracer that collects all nested runs in a list.""" -from typing import Any, Optional, Union +from typing import Any from uuid import UUID from langchain_core.tracers.base import BaseTracer @@ -15,9 +15,7 @@ class RunCollectorCallbackHandler(BaseTracer): name: str = "run-collector_callback_handler" - def __init__( - self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any - ) -> None: + def __init__(self, example_id: UUID | str | None = None, **kwargs: Any) -> None: """Initialize the RunCollectorCallbackHandler. Args: diff --git a/libs/core/langchain_core/tracers/schemas.py b/libs/core/langchain_core/tracers/schemas.py index cd3dd71fea7..4ec415a5e34 100644 --- a/libs/core/langchain_core/tracers/schemas.py +++ b/libs/core/langchain_core/tracers/schemas.py @@ -4,7 +4,7 @@ from __future__ import annotations import warnings from datetime import datetime, timezone -from typing import Any, Optional +from typing import Any from uuid import UUID from langsmith import RunTree @@ -37,8 +37,8 @@ class TracerSessionV1Base(BaseModelV1): """Base class for TracerSessionV1.""" start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc)) - name: Optional[str] = None - extra: Optional[dict[str, Any]] = None + name: str | None = None + extra: dict[str, Any] | None = None @deprecated("0.1.0", removal="1.0") @@ -72,15 +72,15 @@ class BaseRun(BaseModelV1): """Base class for Run.""" uuid: str - parent_uuid: Optional[str] = None + parent_uuid: str | None = None start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc)) end_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc)) - extra: Optional[dict[str, Any]] = None + extra: dict[str, Any] | None = None execution_order: int child_execution_order: int serialized: dict[str, Any] session_id: int - error: Optional[str] = None + error: str | None = None @deprecated("0.1.0", alternative="Run", removal="1.0") @@ -97,7 +97,7 @@ class ChainRun(BaseRun): """Class for ChainRun.""" inputs: dict[str, Any] - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None child_llm_runs: list[LLMRun] = FieldV1(default_factory=list) child_chain_runs: list[ChainRun] = FieldV1(default_factory=list) child_tool_runs: list[ToolRun] = FieldV1(default_factory=list) @@ -108,7 +108,7 @@ class ToolRun(BaseRun): """Class for ToolRun.""" tool_input: str - output: Optional[str] = None + output: str | None = None action: str child_llm_runs: list[LLMRun] = FieldV1(default_factory=list) child_chain_runs: list[ChainRun] = FieldV1(default_factory=list) diff --git a/libs/core/langchain_core/tracers/stdout.py b/libs/core/langchain_core/tracers/stdout.py index b69a621491d..72e6cee19a6 100644 --- a/libs/core/langchain_core/tracers/stdout.py +++ b/libs/core/langchain_core/tracers/stdout.py @@ -1,7 +1,8 @@ """Tracers that print to the console.""" import json -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.schemas import Run diff --git a/libs/core/langchain_core/utils/_merge.py b/libs/core/langchain_core/utils/_merge.py index 7b8465e8d02..2559c09015a 100644 --- a/libs/core/langchain_core/utils/_merge.py +++ b/libs/core/langchain_core/utils/_merge.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]: @@ -80,7 +80,7 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any] return merged -def merge_lists(left: Optional[list], *others: Optional[list]) -> Optional[list]: +def merge_lists(left: list | None, *others: list | None) -> list | None: """Add many lists, handling None. Args: diff --git a/libs/core/langchain_core/utils/aiter.py b/libs/core/langchain_core/utils/aiter.py index 3a1ae3b1ace..8c7598be92d 100644 --- a/libs/core/langchain_core/utils/aiter.py +++ b/libs/core/langchain_core/utils/aiter.py @@ -11,17 +11,15 @@ from collections.abc import ( AsyncIterable, AsyncIterator, Awaitable, + Callable, Iterator, ) from contextlib import AbstractAsyncContextManager from types import TracebackType from typing import ( Any, - Callable, Generic, - Optional, TypeVar, - Union, cast, overload, ) @@ -36,8 +34,8 @@ _no_default = object() # https://github.com/python/cpython/blob/main/Lib/test/test_asyncgen.py#L54 # before 3.10, the builtin anext() was not available def py_anext( - iterator: AsyncIterator[T], default: Union[T, Any] = _no_default -) -> Awaitable[Union[T, Any, None]]: + iterator: AsyncIterator[T], default: T | Any = _no_default +) -> Awaitable[T | Any | None]: """Pure-Python implementation of anext() for testing purposes. Closely matches the builtin anext() C implementation. @@ -68,7 +66,7 @@ def py_anext( if default is _no_default: return __anext__(iterator) - async def anext_impl() -> Union[T, Any]: + async def anext_impl() -> T | Any: try: # The C code is way more low-level than this, as it implements # all methods of the iterator protocol. In this implementation @@ -90,9 +88,9 @@ class NoLock: async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> bool: """Return False, exception not suppressed.""" return False @@ -197,7 +195,7 @@ class Tee(Generic[T]): iterable: AsyncIterator[T], n: int = 2, *, - lock: Optional[AbstractAsyncContextManager[Any]] = None, + lock: AbstractAsyncContextManager[Any] | None = None, ): """Create a ``tee``. @@ -230,8 +228,8 @@ class Tee(Generic[T]): def __getitem__(self, item: slice) -> tuple[AsyncIterator[T], ...]: ... def __getitem__( - self, item: Union[int, slice] - ) -> Union[AsyncIterator[T], tuple[AsyncIterator[T], ...]]: + self, item: int | slice + ) -> AsyncIterator[T] | tuple[AsyncIterator[T], ...]: """Return the child iterator(s) for the given index or slice.""" return self._children[item] @@ -249,9 +247,9 @@ class Tee(Generic[T]): async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> bool: """Close all child iterators. @@ -292,9 +290,7 @@ class aclosing(AbstractAsyncContextManager): # noqa: N801 """ - def __init__( - self, thing: Union[AsyncGenerator[Any, Any], AsyncIterator[Any]] - ) -> None: + def __init__(self, thing: AsyncGenerator[Any, Any] | AsyncIterator[Any]) -> None: """Create the context manager. Args: @@ -303,15 +299,15 @@ class aclosing(AbstractAsyncContextManager): # noqa: N801 self.thing = thing @override - async def __aenter__(self) -> Union[AsyncGenerator[Any, Any], AsyncIterator[Any]]: + async def __aenter__(self) -> AsyncGenerator[Any, Any] | AsyncIterator[Any]: return self.thing @override async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: if hasattr(self.thing, "aclose"): await self.thing.aclose() diff --git a/libs/core/langchain_core/utils/env.py b/libs/core/langchain_core/utils/env.py index 51fd3d66e21..0256518cce1 100644 --- a/libs/core/langchain_core/utils/env.py +++ b/libs/core/langchain_core/utils/env.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import Any, Optional, Union +from typing import Any def env_var_is_set(env_var: str) -> bool: @@ -25,9 +25,9 @@ def env_var_is_set(env_var: str) -> bool: def get_from_dict_or_env( data: dict[str, Any], - key: Union[str, list[str]], + key: str | list[str], env_key: str, - default: Optional[str] = None, + default: str | None = None, ) -> str: """Get a value from a dictionary or an environment variable. @@ -56,7 +56,7 @@ def get_from_dict_or_env( return get_from_env(key_for_err, env_key, default=default) -def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str: +def get_from_env(key: str, env_key: str, default: str | None = None) -> str: """Get a value from a dictionary or an environment variable. Args: diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index a17ad463b17..47189f4624c 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -8,13 +8,12 @@ import logging import types import typing import uuid +from collections.abc import Callable from typing import ( TYPE_CHECKING, Annotated, Any, - Callable, Literal, - Optional, Union, cast, get_args, @@ -103,8 +102,8 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict: def _convert_json_schema_to_openai_function( schema: dict, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, rm_titles: bool = True, ) -> FunctionDescription: """Converts a Pydantic model to a function description for the OpenAI API. @@ -137,8 +136,8 @@ def _convert_json_schema_to_openai_function( def _convert_pydantic_to_openai_function( model: type, *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, rm_titles: bool = True, ) -> FunctionDescription: """Converts a Pydantic model to a function description for the OpenAI API. @@ -184,8 +183,8 @@ convert_pydantic_to_openai_function = deprecated( def convert_pydantic_to_openai_tool( model: type[BaseModel], *, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, ) -> ToolDescription: """Converts a Pydantic model to a function description for the OpenAI API. @@ -285,7 +284,9 @@ def _convert_any_typed_dicts_to_pydantic( new_arg_type = _convert_any_typed_dicts_to_pydantic( annotated_args[0], depth=depth + 1, visited=visited ) - field_kwargs = dict(zip(("default", "description"), annotated_args[1:])) + field_kwargs = dict( + zip(("default", "description"), annotated_args[1:], strict=False) + ) if (field_desc := field_kwargs.get("description")) and not isinstance( field_desc, str ): @@ -393,9 +394,9 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription: def convert_to_openai_function( - function: Union[dict[str, Any], type, Callable, BaseTool], + function: dict[str, Any] | type | Callable | BaseTool, *, - strict: Optional[bool] = None, + strict: bool | None = None, ) -> dict[str, Any]: """Convert a raw function/class to an OpenAI function. @@ -520,9 +521,9 @@ _WellKnownOpenAITools = ( def convert_to_openai_tool( - tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool], + tool: dict[str, Any] | type[BaseModel] | Callable | BaseTool, *, - strict: Optional[bool] = None, + strict: bool | None = None, ) -> dict[str, Any]: """Convert a tool-like object to an OpenAI tool schema. @@ -592,9 +593,9 @@ def convert_to_openai_tool( def convert_to_json_schema( - schema: Union[dict[str, Any], type[BaseModel], Callable, BaseTool], + schema: dict[str, Any] | type[BaseModel] | Callable | BaseTool, *, - strict: Optional[bool] = None, + strict: bool | None = None, ) -> dict[str, Any]: """Convert a schema representation to a JSON schema. @@ -637,9 +638,9 @@ def convert_to_json_schema( def tool_example_to_messages( input: str, tool_calls: list[BaseModel], - tool_outputs: Optional[list[str]] = None, + tool_outputs: list[str] | None = None, *, - ai_response: Optional[str] = None, + ai_response: str | None = None, ) -> list[BaseMessage]: """Convert an example into a list of messages that can be fed into an LLM. @@ -731,7 +732,7 @@ def tool_example_to_messages( tool_outputs = tool_outputs or ["You have correctly called this tool."] * len( openai_tool_calls ) - for output, tool_call_dict in zip(tool_outputs, openai_tool_calls): + for output, tool_call_dict in zip(tool_outputs, openai_tool_calls, strict=False): messages.append(ToolMessage(content=output, tool_call_id=tool_call_dict["id"])) if ai_response: @@ -740,7 +741,7 @@ def tool_example_to_messages( def _parse_google_docstring( - docstring: Optional[str], + docstring: str | None, args: list[str], *, error_on_invalid_docstring: bool = False, diff --git a/libs/core/langchain_core/utils/html.py b/libs/core/langchain_core/utils/html.py index b0a9880c5d5..a0054e613b3 100644 --- a/libs/core/langchain_core/utils/html.py +++ b/libs/core/langchain_core/utils/html.py @@ -3,7 +3,6 @@ import logging import re from collections.abc import Sequence -from typing import Optional, Union from urllib.parse import urljoin, urlparse logger = logging.getLogger(__name__) @@ -35,7 +34,7 @@ DEFAULT_LINK_REGEX = ( def find_all_links( - raw_html: str, *, pattern: Union[str, re.Pattern, None] = None + raw_html: str, *, pattern: str | re.Pattern | None = None ) -> list[str]: """Extract all links from a raw HTML string. @@ -54,8 +53,8 @@ def extract_sub_links( raw_html: str, url: str, *, - base_url: Optional[str] = None, - pattern: Union[str, re.Pattern, None] = None, + base_url: str | None = None, + pattern: str | re.Pattern | None = None, prevent_outside: bool = True, exclude_prefixes: Sequence[str] = (), continue_on_failure: bool = False, diff --git a/libs/core/langchain_core/utils/input.py b/libs/core/langchain_core/utils/input.py index afa3bf758c7..b5f9fea712c 100644 --- a/libs/core/langchain_core/utils/input.py +++ b/libs/core/langchain_core/utils/input.py @@ -1,6 +1,6 @@ """Handle chained inputs.""" -from typing import Optional, TextIO +from typing import TextIO _TEXT_COLOR_MAPPING = { "blue": "36;1", @@ -12,7 +12,7 @@ _TEXT_COLOR_MAPPING = { def get_color_mapping( - items: list[str], excluded_colors: Optional[list] = None + items: list[str], excluded_colors: list | None = None ) -> dict[str, str]: """Get mapping for items to a support color. @@ -56,7 +56,7 @@ def get_bolded_text(text: str) -> str: def print_text( - text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None + text: str, color: str | None = None, end: str = "", file: TextIO | None = None ) -> None: """Print text with highlighting and no end characters. diff --git a/libs/core/langchain_core/utils/iter.py b/libs/core/langchain_core/utils/iter.py index 9dd5614e67c..3ed9d13e375 100644 --- a/libs/core/langchain_core/utils/iter.py +++ b/libs/core/langchain_core/utils/iter.py @@ -9,9 +9,7 @@ from typing import ( Any, Generic, Literal, - Optional, TypeVar, - Union, overload, ) @@ -26,9 +24,9 @@ class NoLock: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> Literal[False]: """Return False (exception not suppressed).""" return False @@ -133,7 +131,7 @@ class Tee(Generic[T]): iterable: Iterator[T], n: int = 2, *, - lock: Optional[AbstractContextManager[Any]] = None, + lock: AbstractContextManager[Any] | None = None, ): """Create a ``tee``. @@ -165,9 +163,7 @@ class Tee(Generic[T]): @overload def __getitem__(self, item: slice) -> tuple[Iterator[T], ...]: ... - def __getitem__( - self, item: Union[int, slice] - ) -> Union[Iterator[T], tuple[Iterator[T], ...]]: + def __getitem__(self, item: int | slice) -> Iterator[T] | tuple[Iterator[T], ...]: """Return the child iterator(s) at the given index or slice.""" return self._children[item] @@ -185,9 +181,9 @@ class Tee(Generic[T]): def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> Literal[False]: """Close all child iterators. @@ -207,7 +203,7 @@ class Tee(Generic[T]): safetee = Tee -def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[list[T]]: +def batch_iterate(size: int | None, iterable: Iterable[T]) -> Iterator[list[T]]: """Utility batching function. Args: diff --git a/libs/core/langchain_core/utils/json.py b/libs/core/langchain_core/utils/json.py index 2703142884a..c7e28f360df 100644 --- a/libs/core/langchain_core/utils/json.py +++ b/libs/core/langchain_core/utils/json.py @@ -4,7 +4,8 @@ from __future__ import annotations import json import re -from typing import Any, Callable, Union +from collections.abc import Callable +from typing import Any from langchain_core.exceptions import OutputParserException @@ -19,7 +20,7 @@ def _replace_new_line(match: re.Match[str]) -> str: return match.group(1) + value + match.group(3) -def _custom_parser(multiline_string: Union[str, bytes, bytearray]) -> str: +def _custom_parser(multiline_string: str | bytes | bytearray) -> str: r"""Custom parser for multiline strings. The LLM response for `action_input` may be a multiline diff --git a/libs/core/langchain_core/utils/json_schema.py b/libs/core/langchain_core/utils/json_schema.py index 1baf93610b6..939ac06eb01 100644 --- a/libs/core/langchain_core/utils/json_schema.py +++ b/libs/core/langchain_core/utils/json_schema.py @@ -3,13 +3,13 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from collections.abc import Sequence -def _retrieve_ref(path: str, schema: dict) -> Union[list, dict]: +def _retrieve_ref(path: str, schema: dict) -> list | dict: components = path.split("/") if components[0] != "#": msg = ( @@ -17,7 +17,7 @@ def _retrieve_ref(path: str, schema: dict) -> Union[list, dict]: "with #." ) raise ValueError(msg) - out: Union[list, dict] = schema + out: list | dict = schema for component in components[1:]: if component in out: if isinstance(out, list): @@ -67,7 +67,7 @@ def _process_dict_properties( def _dereference_refs_helper( obj: Any, full_schema: dict[str, Any], - processed_refs: Optional[set[str]], + processed_refs: set[str] | None, skip_keys: Sequence[str], shallow_refs: bool, # noqa: FBT001 ) -> Any: @@ -167,8 +167,8 @@ def _dereference_refs_helper( def dereference_refs( schema_obj: dict, *, - full_schema: Optional[dict] = None, - skip_keys: Optional[Sequence[str]] = None, + full_schema: dict | None = None, + skip_keys: Sequence[str] | None = None, ) -> dict: """Resolve and inline JSON Schema $ref references in a schema object. diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index 76ad959da06..cfe0a3e82eb 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -12,8 +12,6 @@ from typing import ( TYPE_CHECKING, Any, Literal, - Optional, - Union, cast, ) @@ -23,7 +21,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -Scopes: TypeAlias = list[Union[Literal[False, 0], Mapping[str, Any]]] +Scopes: TypeAlias = list[Literal[False, 0] | Mapping[str, Any]] # Globals @@ -433,13 +431,13 @@ EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({}) def render( - template: Union[str, list[tuple[str, str]]] = "", + template: str | list[tuple[str, str]] = "", data: Mapping[str, Any] = EMPTY_DICT, partials_dict: Mapping[str, str] = EMPTY_DICT, padding: str = "", def_ldel: str = "{{", def_rdel: str = "}}", - scopes: Optional[Scopes] = None, + scopes: Scopes | None = None, warn: bool = False, # noqa: FBT001,FBT002 keep: bool = False, # noqa: FBT001,FBT002 ) -> str: diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index 64a5b284806..792b3596f5c 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -5,16 +5,14 @@ from __future__ import annotations import inspect import textwrap import warnings +from collections.abc import Callable from contextlib import nullcontext from functools import lru_cache, wraps from types import GenericAlias from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, TypeVar, - Union, cast, overload, ) @@ -205,8 +203,8 @@ def _create_subset_model_v1( model: type[BaseModelV1], field_names: list, *, - descriptions: Optional[dict] = None, - fn_description: Optional[str] = None, + descriptions: dict | None = None, + fn_description: str | None = None, ) -> type[BaseModel]: """Create a pydantic model with only a subset of model's fields.""" fields = {} @@ -218,7 +216,7 @@ def _create_subset_model_v1( # this isn't perfect but should work for most functions field.outer_type_ if field.required and not field.allow_none - else Optional[field.outer_type_] + else field.outer_type_ | None ) if descriptions and field_name in descriptions: field.field_info.description = descriptions[field_name] @@ -234,8 +232,8 @@ def _create_subset_model_v2( model: type[BaseModel], field_names: list[str], *, - descriptions: Optional[dict] = None, - fn_description: Optional[str] = None, + descriptions: dict | None = None, + fn_description: str | None = None, ) -> type[BaseModel]: """Create a pydantic model with a subset of the model fields.""" descriptions_ = descriptions or {} @@ -276,8 +274,8 @@ def _create_subset_model( model: TypeBaseModel, field_names: list[str], *, - descriptions: Optional[dict] = None, - fn_description: Optional[str] = None, + descriptions: dict | None = None, + fn_description: str | None = None, ) -> type[BaseModel]: """Create subset model using the same pydantic version as the input model. @@ -318,8 +316,8 @@ def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ... def get_fields( - model: Union[type[Union[BaseModel, BaseModelV1]], BaseModel, BaseModelV1], -) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]: + model: type[BaseModel | BaseModelV1] | BaseModel | BaseModelV1, +) -> dict[str, FieldInfoV2] | dict[str, ModelField]: """Return the field names of a Pydantic model. Args: @@ -348,7 +346,7 @@ NO_DEFAULT = object() def _create_root_model( name: str, type_: Any, - module_name: Optional[str] = None, + module_name: str | None = None, default_: object = NO_DEFAULT, ) -> type[BaseModel]: """Create a base class.""" @@ -413,7 +411,7 @@ def _create_root_model_cached( model_name: str, type_: Any, *, - module_name: Optional[str] = None, + module_name: str | None = None, default_: object = NO_DEFAULT, ) -> type[BaseModel]: return _create_root_model( @@ -436,7 +434,7 @@ def _create_model_cached( def create_model( model_name: str, - module_name: Optional[str] = None, + module_name: str | None = None, /, **field_definitions: Any, ) -> type[BaseModel]: @@ -509,9 +507,9 @@ def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any def create_model_v2( model_name: str, *, - module_name: Optional[str] = None, - field_definitions: Optional[dict[str, Any]] = None, - root: Optional[Any] = None, + module_name: str | None = None, + field_definitions: dict[str, Any] | None = None, + root: Any | None = None, ) -> type[BaseModel]: """Create a pydantic model with the given field definitions. diff --git a/libs/core/langchain_core/utils/usage.py b/libs/core/langchain_core/utils/usage.py index 95ee2ca063e..b60b173cdf4 100644 --- a/libs/core/langchain_core/utils/usage.py +++ b/libs/core/langchain_core/utils/usage.py @@ -1,6 +1,6 @@ """Usage utilities.""" -from typing import Callable +from collections.abc import Callable def _dict_int_op( diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 7b1bd670baa..06c16ca4762 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -6,9 +6,9 @@ import functools import importlib import os import warnings -from collections.abc import Iterator, Sequence +from collections.abc import Callable, Iterator, Sequence from importlib.metadata import version -from typing import Any, Callable, Optional, Union, overload +from typing import Any, overload from uuid import uuid4 from packaging.version import parse @@ -91,7 +91,7 @@ def mock_now(dt_value: datetime.datetime) -> Iterator[type]: @classmethod @override - def now(cls, tz: Union[datetime.tzinfo, None] = None) -> "MockDateTime": + def now(cls, tz: datetime.tzinfo | None = None) -> "MockDateTime": # Create a copy of dt_value. return MockDateTime( dt_value.year, @@ -113,7 +113,7 @@ def mock_now(dt_value: datetime.datetime) -> Iterator[type]: def guard_import( - module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None + module_name: str, *, pip_name: str | None = None, package: str | None = None ) -> Any: """Dynamically import a module. @@ -146,10 +146,10 @@ def guard_import( def check_package_version( package: str, - lt_version: Optional[str] = None, - lte_version: Optional[str] = None, - gt_version: Optional[str] = None, - gte_version: Optional[str] = None, + lt_version: str | None = None, + lte_version: str | None = None, + gt_version: str | None = None, + gte_version: str | None = None, ) -> None: """Check the version of a package. @@ -306,7 +306,7 @@ def build_extra_kwargs( return extra_kwargs -def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr: +def convert_to_secret_str(value: SecretStr | str) -> SecretStr: """Convert a string to a SecretStr if needed. Args: @@ -345,29 +345,29 @@ def from_env(key: str, /, *, error_message: str) -> Callable[[], str]: ... @overload def from_env( - key: Union[str, Sequence[str]], /, *, default: str, error_message: Optional[str] + key: str | Sequence[str], /, *, default: str, error_message: str | None ) -> Callable[[], str]: ... @overload def from_env( - key: str, /, *, default: None, error_message: Optional[str] -) -> Callable[[], Optional[str]]: ... + key: str, /, *, default: None, error_message: str | None +) -> Callable[[], str | None]: ... @overload def from_env( - key: Union[str, Sequence[str]], /, *, default: None -) -> Callable[[], Optional[str]]: ... + key: str | Sequence[str], /, *, default: None +) -> Callable[[], str | None]: ... def from_env( - key: Union[str, Sequence[str]], + key: str | Sequence[str], /, *, - default: Union[str, _NoDefaultType, None] = _NoDefault, - error_message: Optional[str] = None, -) -> Union[Callable[[], str], Callable[[], Optional[str]]]: + default: str | _NoDefaultType | None = _NoDefault, + error_message: str | None = None, +) -> Callable[[], str] | Callable[[], str | None]: """Create a factory method that gets a value from an environment variable. Args: @@ -384,7 +384,7 @@ def from_env( factory method that will look up the value from the environment. """ - def get_from_env_fn() -> Optional[str]: + def get_from_env_fn() -> str | None: """Get a value from an environment variable. Raises: @@ -416,7 +416,7 @@ def from_env( @overload -def secret_from_env(key: Union[str, Sequence[str]], /) -> Callable[[], SecretStr]: ... +def secret_from_env(key: str | Sequence[str], /) -> Callable[[], SecretStr]: ... @overload @@ -425,8 +425,8 @@ def secret_from_env(key: str, /, *, default: str) -> Callable[[], SecretStr]: .. @overload def secret_from_env( - key: Union[str, Sequence[str]], /, *, default: None -) -> Callable[[], Optional[SecretStr]]: ... + key: str | Sequence[str], /, *, default: None +) -> Callable[[], SecretStr | None]: ... @overload @@ -434,12 +434,12 @@ def secret_from_env(key: str, /, *, error_message: str) -> Callable[[], SecretSt def secret_from_env( - key: Union[str, Sequence[str]], + key: str | Sequence[str], /, *, - default: Union[str, _NoDefaultType, None] = _NoDefault, - error_message: Optional[str] = None, -) -> Union[Callable[[], Optional[SecretStr]], Callable[[], SecretStr]]: + default: str | _NoDefaultType | None = _NoDefault, + error_message: str | None = None, +) -> Callable[[], SecretStr | None] | Callable[[], SecretStr]: """Secret from env. Args: @@ -453,7 +453,7 @@ def secret_from_env( factory method that will look up the secret from the environment. """ - def get_secret_from_env() -> Optional[SecretStr]: + def get_secret_from_env() -> SecretStr | None: """Get a value from an environment variable. Raises: @@ -498,7 +498,7 @@ Used for: """ -def ensure_id(id_val: Optional[str]) -> str: +def ensure_id(id_val: str | None) -> str: """Ensure the ID is a valid string, generating a new UUID if not provided. Auto-generated UUIDs are prefixed by ``'lc_'`` to indicate they are diff --git a/libs/core/langchain_core/vectorstores/base.py b/libs/core/langchain_core/vectorstores/base.py index 4f3608fd95a..fc8e430b287 100644 --- a/libs/core/langchain_core/vectorstores/base.py +++ b/libs/core/langchain_core/vectorstores/base.py @@ -25,13 +25,12 @@ import logging import math import warnings from abc import ABC, abstractmethod +from collections.abc import Callable from itertools import cycle from typing import ( TYPE_CHECKING, Any, - Callable, ClassVar, - Optional, TypeVar, ) @@ -62,9 +61,9 @@ class VectorStore(ABC): def add_texts( self, texts: Iterable[str], - metadatas: Optional[list[dict]] = None, + metadatas: list[dict] | None = None, *, - ids: Optional[list[str]] = None, + ids: list[str] | None = None, **kwargs: Any, ) -> list[str]: """Run more texts through the embeddings and add to the vectorstore. @@ -98,10 +97,10 @@ class VectorStore(ABC): ) raise ValueError(msg) metadatas_ = iter(metadatas) if metadatas else cycle([{}]) - ids_: Iterator[Optional[str]] = iter(ids) if ids else cycle([None]) + ids_: Iterator[str | None] = iter(ids) if ids else cycle([None]) docs = [ Document(id=id_, page_content=text, metadata=metadata_) - for text, metadata_, id_ in zip(texts, metadatas_, ids_) + for text, metadata_, id_ in zip(texts, metadatas_, ids_, strict=False) ] if ids is not None: # For backward compatibility @@ -112,7 +111,7 @@ class VectorStore(ABC): raise NotImplementedError(msg) @property - def embeddings(self) -> Optional[Embeddings]: + def embeddings(self) -> Embeddings | None: """Access the query embedding object if available.""" logger.debug( "The embeddings property has not been implemented for %s", @@ -120,7 +119,7 @@ class VectorStore(ABC): ) return None - def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]: + def delete(self, ids: list[str] | None = None, **kwargs: Any) -> bool | None: """Delete by vector ID or other criteria. Args: @@ -188,9 +187,7 @@ class VectorStore(ABC): """ return await run_in_executor(None, self.get_by_ids, ids) - async def adelete( - self, ids: Optional[list[str]] = None, **kwargs: Any - ) -> Optional[bool]: + async def adelete(self, ids: list[str] | None = None, **kwargs: Any) -> bool | None: """Async delete by vector ID or other criteria. Args: @@ -206,9 +203,9 @@ class VectorStore(ABC): async def aadd_texts( self, texts: Iterable[str], - metadatas: Optional[list[dict]] = None, + metadatas: list[dict] | None = None, *, - ids: Optional[list[str]] = None, + ids: list[str] | None = None, **kwargs: Any, ) -> list[str]: """Async run more texts through the embeddings and add to the vectorstore. @@ -244,11 +241,11 @@ class VectorStore(ABC): ) raise ValueError(msg) metadatas_ = iter(metadatas) if metadatas else cycle([{}]) - ids_: Iterator[Optional[str]] = iter(ids) if ids else cycle([None]) + ids_: Iterator[str | None] = iter(ids) if ids else cycle([None]) docs = [ Document(id=id_, page_content=text, metadata=metadata_) - for text, metadata_, id_ in zip(texts, metadatas_, ids_) + for text, metadata_, id_ in zip(texts, metadatas_, ids_, strict=False) ] return await self.aadd_documents(docs, **kwargs) return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs) @@ -872,9 +869,9 @@ class VectorStore(ABC): cls: type[VST], texts: list[str], embedding: Embeddings, - metadatas: Optional[list[dict]] = None, + metadatas: list[dict] | None = None, *, - ids: Optional[list[str]] = None, + ids: list[str] | None = None, **kwargs: Any, ) -> VST: """Return VectorStore initialized from texts and embeddings. @@ -896,9 +893,9 @@ class VectorStore(ABC): cls, texts: list[str], embedding: Embeddings, - metadatas: Optional[list[dict]] = None, + metadatas: list[dict] | None = None, *, - ids: Optional[list[str]] = None, + ids: list[str] | None = None, **kwargs: Any, ) -> Self: """Async return VectorStore initialized from texts and embeddings. diff --git a/libs/core/langchain_core/vectorstores/in_memory.py b/libs/core/langchain_core/vectorstores/in_memory.py index 768b92c448a..a85232ed729 100644 --- a/libs/core/langchain_core/vectorstores/in_memory.py +++ b/libs/core/langchain_core/vectorstores/in_memory.py @@ -4,12 +4,11 @@ from __future__ import annotations import json import uuid +from collections.abc import Callable from pathlib import Path from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, ) from typing_extensions import override @@ -178,20 +177,20 @@ class InMemoryVectorStore(VectorStore): return self.embedding @override - def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None: + def delete(self, ids: Sequence[str] | None = None, **kwargs: Any) -> None: if ids: for _id in ids: self.store.pop(_id, None) @override - async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None: + async def adelete(self, ids: Sequence[str] | None = None, **kwargs: Any) -> None: self.delete(ids) @override def add_documents( self, documents: list[Document], - ids: Optional[list[str]] = None, + ids: list[str] | None = None, **kwargs: Any, ) -> list[str]: texts = [doc.page_content for doc in documents] @@ -204,13 +203,13 @@ class InMemoryVectorStore(VectorStore): ) raise ValueError(msg) - id_iterator: Iterator[Optional[str]] = ( + id_iterator: Iterator[str | None] = ( iter(ids) if ids else iter(doc.id for doc in documents) ) ids_ = [] - for doc, vector in zip(documents, vectors): + for doc, vector in zip(documents, vectors, strict=False): doc_id = next(id_iterator) doc_id_ = doc_id or str(uuid.uuid4()) ids_.append(doc_id_) @@ -225,7 +224,7 @@ class InMemoryVectorStore(VectorStore): @override async def aadd_documents( - self, documents: list[Document], ids: Optional[list[str]] = None, **kwargs: Any + self, documents: list[Document], ids: list[str] | None = None, **kwargs: Any ) -> list[str]: texts = [doc.page_content for doc in documents] vectors = await self.embedding.aembed_documents(texts) @@ -237,12 +236,12 @@ class InMemoryVectorStore(VectorStore): ) raise ValueError(msg) - id_iterator: Iterator[Optional[str]] = ( + id_iterator: Iterator[str | None] = ( iter(ids) if ids else iter(doc.id for doc in documents) ) ids_: list[str] = [] - for doc, vector in zip(documents, vectors): + for doc, vector in zip(documents, vectors, strict=False): doc_id = next(id_iterator) doc_id_ = doc_id or str(uuid.uuid4()) ids_.append(doc_id_) @@ -295,7 +294,7 @@ class InMemoryVectorStore(VectorStore): self, embedding: list[float], k: int = 4, - filter: Optional[Callable[[Document], bool]] = None, # noqa: A002 + filter: Callable[[Document], bool] | None = None, # noqa: A002 ) -> list[tuple[Document, float, list[float]]]: # get all docs with fixed order in list docs = list(self.store.values()) @@ -338,7 +337,7 @@ class InMemoryVectorStore(VectorStore): self, embedding: list[float], k: int = 4, - filter: Optional[Callable[[Document], bool]] = None, # noqa: A002 + filter: Callable[[Document], bool] | None = None, # noqa: A002 **_kwargs: Any, ) -> list[tuple[Document, float]]: """Search for the most similar documents to the given embedding. @@ -426,7 +425,7 @@ class InMemoryVectorStore(VectorStore): fetch_k: int = 20, lambda_mult: float = 0.5, *, - filter: Optional[Callable[[Document], bool]] = None, + filter: Callable[[Document], bool] | None = None, **kwargs: Any, ) -> list[Document]: prefetch_hits = self._similarity_search_with_score_by_vector( @@ -492,7 +491,7 @@ class InMemoryVectorStore(VectorStore): cls, texts: list[str], embedding: Embeddings, - metadatas: Optional[list[dict]] = None, + metadatas: list[dict] | None = None, **kwargs: Any, ) -> InMemoryVectorStore: store = cls( @@ -507,7 +506,7 @@ class InMemoryVectorStore(VectorStore): cls, texts: list[str], embedding: Embeddings, - metadatas: Optional[list[dict]] = None, + metadatas: list[dict] | None = None, **kwargs: Any, ) -> InMemoryVectorStore: store = cls( diff --git a/libs/core/langchain_core/vectorstores/utils.py b/libs/core/langchain_core/vectorstores/utils.py index 1f68269915a..ca46e638223 100644 --- a/libs/core/langchain_core/vectorstores/utils.py +++ b/libs/core/langchain_core/vectorstores/utils.py @@ -8,7 +8,7 @@ from __future__ import annotations import logging import warnings -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING try: import numpy as np @@ -25,7 +25,7 @@ except ImportError: _HAS_SIMSIMD = False if TYPE_CHECKING: - Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray] + Matrix = list[list[float]] | list[np.ndarray] | np.ndarray logger = logging.getLogger(__name__) diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 23777f9291b..6c63ff00bc5 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -73,9 +73,6 @@ disallow_any_generics = false warn_return_any = false -[tool.ruff] -target-version = "py39" - [tool.ruff.format] docstring-code-format = true diff --git a/libs/core/tests/benchmarks/test_async_callbacks.py b/libs/core/tests/benchmarks/test_async_callbacks.py index 5cb58f0210e..9495f8fd2ab 100644 --- a/libs/core/tests/benchmarks/test_async_callbacks.py +++ b/libs/core/tests/benchmarks/test_async_callbacks.py @@ -1,6 +1,6 @@ import asyncio from itertools import cycle -from typing import Any, Optional, Union +from typing import Any from uuid import UUID import pytest @@ -21,9 +21,9 @@ class MyCustomAsyncHandler(AsyncCallbackHandler): messages: list[list[BaseMessage]], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: # Do nothing @@ -35,10 +35,10 @@ class MyCustomAsyncHandler(AsyncCallbackHandler): self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: await asyncio.sleep(0) diff --git a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py index 5ae0d316e6c..e2cbaa9c723 100644 --- a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py +++ b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py @@ -7,7 +7,7 @@ via snapshot testing (e.g., see unit tests for runnables). import contextvars from contextlib import asynccontextmanager -from typing import Any, Optional +from typing import Any from uuid import UUID from typing_extensions import override @@ -104,7 +104,7 @@ async def test_inline_handlers_share_parent_context_multiple() -> None: prompts: list[str], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> None: if self.name == "StateModifier": diff --git a/libs/core/tests/unit_tests/callbacks/test_dispatch_custom_event.py b/libs/core/tests/unit_tests/callbacks/test_dispatch_custom_event.py index 86e0aec31ff..7a6fa0f1bf4 100644 --- a/libs/core/tests/unit_tests/callbacks/test_dispatch_custom_event.py +++ b/libs/core/tests/unit_tests/callbacks/test_dispatch_custom_event.py @@ -1,6 +1,6 @@ import sys import uuid -from typing import Any, Optional +from typing import Any from uuid import UUID import pytest @@ -24,8 +24,8 @@ class AsyncCustomCallbackHandler(AsyncCallbackHandler): data: Any, *, run_id: UUID, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> None: assert kwargs == {} @@ -128,8 +128,8 @@ def test_sync_callback_manager() -> None: data: Any, *, run_id: UUID, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> None: assert kwargs == {} diff --git a/libs/core/tests/unit_tests/example_selectors/test_base.py b/libs/core/tests/unit_tests/example_selectors/test_base.py index c49e7745f07..0e647b69ef2 100644 --- a/libs/core/tests/unit_tests/example_selectors/test_base.py +++ b/libs/core/tests/unit_tests/example_selectors/test_base.py @@ -1,5 +1,3 @@ -from typing import Optional - from typing_extensions import override from langchain_core.example_selectors import BaseExampleSelector @@ -7,7 +5,7 @@ from langchain_core.example_selectors import BaseExampleSelector class DummyExampleSelector(BaseExampleSelector): def __init__(self) -> None: - self.example: Optional[dict[str, str]] = None + self.example: dict[str, str] | None = None def add_example(self, example: dict[str, str]) -> None: self.example = example diff --git a/libs/core/tests/unit_tests/example_selectors/test_similarity.py b/libs/core/tests/unit_tests/example_selectors/test_similarity.py index 4d501790c78..b07870121c4 100644 --- a/libs/core/tests/unit_tests/example_selectors/test_similarity.py +++ b/libs/core/tests/unit_tests/example_selectors/test_similarity.py @@ -1,5 +1,5 @@ from collections.abc import Iterable -from typing import Any, Optional +from typing import Any from typing_extensions import override @@ -13,21 +13,21 @@ from langchain_core.vectorstores import VectorStore class DummyVectorStore(VectorStore): - def __init__(self, init_arg: Optional[str] = None): + def __init__(self, init_arg: str | None = None): self.texts: list[str] = [] self.metadatas: list[dict] = [] - self._embeddings: Optional[Embeddings] = None + self._embeddings: Embeddings | None = None self.init_arg = init_arg @property - def embeddings(self) -> Optional[Embeddings]: + def embeddings(self) -> Embeddings | None: return self._embeddings @override def add_texts( self, texts: Iterable[str], - metadatas: Optional[list[dict]] = None, + metadatas: list[dict] | None = None, **kwargs: Any, ) -> list[str]: self.texts.extend(texts) @@ -66,7 +66,7 @@ class DummyVectorStore(VectorStore): cls, texts: list[str], embedding: Embeddings, - metadatas: Optional[list[dict]] = None, + metadatas: list[dict] | None = None, **kwargs: Any, ) -> "DummyVectorStore": store = DummyVectorStore(**kwargs) diff --git a/libs/core/tests/unit_tests/fake/callbacks.py b/libs/core/tests/unit_tests/fake/callbacks.py index b8ec1778b42..9d48d9b1d52 100644 --- a/libs/core/tests/unit_tests/fake/callbacks.py +++ b/libs/core/tests/unit_tests/fake/callbacks.py @@ -1,7 +1,7 @@ """A fake callback handler for testing purposes.""" from itertools import chain -from typing import Any, Optional, Union +from typing import Any from uuid import UUID from pydantic import BaseModel @@ -26,7 +26,7 @@ class BaseFakeCallbackHandler(BaseModel): ignore_chat_model_: bool = False # to allow for similar callback handlers that are not technically equal - fake_id: Union[str, None] = None + fake_id: str | None = None # add finer-grained counters for easier debugging of failing tests chain_starts: int = 0 @@ -288,7 +288,7 @@ class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): messages: list[list[BaseMessage]], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> Any: assert all(isinstance(m, BaseMessage) for m in chain(*messages)) diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index 0e1944534d0..bf5629a12c5 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -2,7 +2,7 @@ import time from itertools import cycle -from typing import Any, Optional, Union, cast +from typing import Any, cast from uuid import UUID from typing_extensions import override @@ -170,9 +170,9 @@ async def test_callback_handlers() -> None: messages: list[list[BaseMessage]], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, - metadata: Optional[dict[str, Any]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: # Do nothing @@ -184,10 +184,10 @@ async def test_callback_handlers() -> None: self, token: str, *, - chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: self.store.append(token) diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index b80b5ce932a..9820d4ba2a8 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -3,7 +3,7 @@ import uuid import warnings from collections.abc import AsyncIterator, Iterator -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal import pytest from typing_extensions import override @@ -46,7 +46,7 @@ if TYPE_CHECKING: def _content_blocks_equal_ignore_id( - actual: Union[str, list[Any]], expected: Union[str, list[Any]] + actual: str | list[Any], expected: str | list[Any] ) -> bool: """Compare content blocks, ignoring auto-generated `id` fields. @@ -63,7 +63,7 @@ def _content_blocks_equal_ignore_id( if len(actual) != len(expected): return False - for actual_block, expected_block in zip(actual, expected): + for actual_block, expected_block in zip(actual, expected, strict=False): actual_without_id = ( {k: v for k, v in actual_block.items() if k != "id"} if isinstance(actual_block, dict) and "id" in actual_block @@ -184,8 +184,8 @@ async def test_astream_fallback_to_ainvoke() -> None: def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: """Top Level call.""" @@ -217,8 +217,8 @@ async def test_astream_implementation_fallback_to_stream() -> None: def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: """Top Level call.""" @@ -228,8 +228,8 @@ async def test_astream_implementation_fallback_to_stream() -> None: def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: """Stream the output of the model.""" @@ -269,8 +269,8 @@ async def test_astream_implementation_uses_astream() -> None: def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: """Top Level call.""" @@ -280,8 +280,8 @@ async def test_astream_implementation_uses_astream() -> None: async def _astream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, # type: ignore[override] + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, # type: ignore[override] **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: """Stream the output of the model.""" @@ -350,8 +350,8 @@ class NoStreamingModel(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: return ChatResult(generations=[ChatGeneration(message=AIMessage("invoke"))]) @@ -366,8 +366,8 @@ class StreamingModel(NoStreamingModel): def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: yield ChatGenerationChunk(message=AIMessageChunk(content="stream")) @@ -376,7 +376,7 @@ class StreamingModel(NoStreamingModel): @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) def test_disable_streaming( *, - disable_streaming: Union[bool, Literal["tool_calling"]], + disable_streaming: bool | Literal["tool_calling"], ) -> None: model = StreamingModel(disable_streaming=disable_streaming) assert model.invoke([]).content == "invoke" @@ -401,7 +401,7 @@ def test_disable_streaming( @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) async def test_disable_streaming_async( *, - disable_streaming: Union[bool, Literal["tool_calling"]], + disable_streaming: bool | Literal["tool_calling"], ) -> None: model = StreamingModel(disable_streaming=disable_streaming) assert (await model.ainvoke([])).content == "invoke" @@ -428,7 +428,7 @@ async def test_disable_streaming_async( @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) def test_disable_streaming_no_streaming_model( *, - disable_streaming: Union[bool, Literal["tool_calling"]], + disable_streaming: bool | Literal["tool_calling"], ) -> None: model = NoStreamingModel(disable_streaming=disable_streaming) assert model.invoke([]).content == "invoke" @@ -443,7 +443,7 @@ def test_disable_streaming_no_streaming_model( @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) async def test_disable_streaming_no_streaming_model_async( *, - disable_streaming: Union[bool, Literal["tool_calling"]], + disable_streaming: bool | Literal["tool_calling"], ) -> None: model = NoStreamingModel(disable_streaming=disable_streaming) assert (await model.ainvoke([])).content == "invoke" @@ -923,7 +923,7 @@ def test_output_version_stream(monkeypatch: Any) -> None: # v1 llm = GenericFakeChatModel(messages=iter(messages), output_version="v1") - full_v1: Optional[BaseMessageChunk] = None + full_v1: BaseMessageChunk | None = None for chunk in llm.stream("hello"): assert isinstance(chunk, AIMessageChunk) assert isinstance(chunk.content, list) @@ -969,7 +969,7 @@ async def test_output_version_astream(monkeypatch: Any) -> None: # v1 llm = GenericFakeChatModel(messages=iter(messages), output_version="v1") - full_v1: Optional[BaseMessageChunk] = None + full_v1: BaseMessageChunk | None = None async for chunk in llm.astream("hello"): assert isinstance(chunk, AIMessageChunk) assert isinstance(chunk.content, list) @@ -1009,8 +1009,8 @@ def test_get_ls_params() -> None: def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: raise NotImplementedError @@ -1019,8 +1019,8 @@ def test_get_ls_params() -> None: def _stream( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: raise NotImplementedError diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py index b8816fa1898..b246da593fa 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py @@ -1,6 +1,6 @@ """Module tests interaction of chat model with caching abstraction..""" -from typing import Any, Optional +from typing import Any import pytest from typing_extensions import override @@ -25,7 +25,7 @@ class InMemoryCache(BaseCache): """Initialize with empty cache.""" self._cache: dict[tuple[str, str], RETURN_VAL_TYPE] = {} - def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + def lookup(self, prompt: str, llm_string: str) -> RETURN_VAL_TYPE | None: """Look up based on prompt and llm_string.""" return self._cache.get((prompt, llm_string), None) diff --git a/libs/core/tests/unit_tests/language_models/llms/test_base.py b/libs/core/tests/unit_tests/language_models/llms/test_base.py index 8cf260b9878..e5547a617a7 100644 --- a/libs/core/tests/unit_tests/language_models/llms/test_base.py +++ b/libs/core/tests/unit_tests/language_models/llms/test_base.py @@ -1,5 +1,5 @@ from collections.abc import AsyncIterator, Iterator -from typing import Any, Optional +from typing import Any import pytest from typing_extensions import override @@ -111,8 +111,8 @@ async def test_error_callback() -> None: def _call( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: raise FailingLLMError @@ -142,8 +142,8 @@ async def test_astream_fallback_to_ainvoke() -> None: def _generate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: generations = [Generation(text="hello")] @@ -168,8 +168,8 @@ async def test_astream_implementation_fallback_to_stream() -> None: def _generate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: """Top Level call.""" @@ -179,8 +179,8 @@ async def test_astream_implementation_fallback_to_stream() -> None: def _stream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: """Stream the output of the model.""" @@ -206,8 +206,8 @@ async def test_astream_implementation_uses_astream() -> None: def _generate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: """Top Level call.""" @@ -217,8 +217,8 @@ async def test_astream_implementation_uses_astream() -> None: async def _astream( self, prompt: str, - stop: Optional[list[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncIterator[GenerationChunk]: """Stream the output of the model.""" @@ -244,8 +244,8 @@ def test_get_ls_params() -> None: def _generate( self, prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> LLMResult: raise NotImplementedError diff --git a/libs/core/tests/unit_tests/language_models/llms/test_cache.py b/libs/core/tests/unit_tests/language_models/llms/test_cache.py index e8d489c055d..a0bd8a34b33 100644 --- a/libs/core/tests/unit_tests/language_models/llms/test_cache.py +++ b/libs/core/tests/unit_tests/language_models/llms/test_cache.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from typing_extensions import override @@ -14,7 +14,7 @@ class InMemoryCache(BaseCache): """Initialize with empty cache.""" self._cache: dict[tuple[str, str], RETURN_VAL_TYPE] = {} - def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + def lookup(self, prompt: str, llm_string: str) -> RETURN_VAL_TYPE | None: """Look up based on prompt and llm_string.""" return self._cache.get((prompt, llm_string), None) @@ -67,7 +67,7 @@ class InMemoryCacheBad(BaseCache): """Initialize with empty cache.""" self._cache: dict[tuple[str, str], RETURN_VAL_TYPE] = {} - def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + def lookup(self, prompt: str, llm_string: str) -> RETURN_VAL_TYPE | None: """Look up based on prompt and llm_string.""" msg = "This code should not be triggered" raise NotImplementedError(msg) diff --git a/libs/core/tests/unit_tests/messages/block_translators/test_anthropic.py b/libs/core/tests/unit_tests/messages/block_translators/test_anthropic.py index 6f7e0c94fe4..bae653b1b4e 100644 --- a/libs/core/tests/unit_tests/messages/block_translators/test_anthropic.py +++ b/libs/core/tests/unit_tests/messages/block_translators/test_anthropic.py @@ -1,5 +1,3 @@ -from typing import Optional - from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage from langchain_core.messages import content as types @@ -293,10 +291,10 @@ def test_convert_to_v1_from_anthropic_chunk() -> None: "type": "tool_call_chunk", }, ] - for chunk, expected in zip(chunks, expected_contents): + for chunk, expected in zip(chunks, expected_contents, strict=False): assert chunk.content_blocks == [expected] - full: Optional[AIMessageChunk] = None + full: AIMessageChunk | None = None for chunk in chunks: full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) diff --git a/libs/core/tests/unit_tests/messages/block_translators/test_bedrock.py b/libs/core/tests/unit_tests/messages/block_translators/test_bedrock.py index 6278abdee5f..7e34ee3b471 100644 --- a/libs/core/tests/unit_tests/messages/block_translators/test_bedrock.py +++ b/libs/core/tests/unit_tests/messages/block_translators/test_bedrock.py @@ -1,5 +1,3 @@ -from typing import Optional - from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage from langchain_core.messages import content as types @@ -254,10 +252,10 @@ def test_convert_to_v1_from_bedrock_chunk() -> None: "type": "tool_call_chunk", }, ] - for chunk, expected in zip(chunks, expected_contents): + for chunk, expected in zip(chunks, expected_contents, strict=False): assert chunk.content_blocks == [expected] - full: Optional[AIMessageChunk] = None + full: AIMessageChunk | None = None for chunk in chunks: full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) diff --git a/libs/core/tests/unit_tests/messages/block_translators/test_bedrock_converse.py b/libs/core/tests/unit_tests/messages/block_translators/test_bedrock_converse.py index 31977ca74b8..e1dc5d6b554 100644 --- a/libs/core/tests/unit_tests/messages/block_translators/test_bedrock_converse.py +++ b/libs/core/tests/unit_tests/messages/block_translators/test_bedrock_converse.py @@ -1,5 +1,3 @@ -from typing import Optional - from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage from langchain_core.messages import content as types @@ -250,10 +248,10 @@ def test_convert_to_v1_from_converse_chunk() -> None: "type": "tool_call_chunk", }, ] - for chunk, expected in zip(chunks, expected_contents): + for chunk, expected in zip(chunks, expected_contents, strict=False): assert chunk.content_blocks == [expected] - full: Optional[AIMessageChunk] = None + full: AIMessageChunk | None = None for chunk in chunks: full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) diff --git a/libs/core/tests/unit_tests/messages/block_translators/test_openai.py b/libs/core/tests/unit_tests/messages/block_translators/test_openai.py index 2a5579506b1..7d575b57944 100644 --- a/libs/core/tests/unit_tests/messages/block_translators/test_openai.py +++ b/libs/core/tests/unit_tests/messages/block_translators/test_openai.py @@ -1,5 +1,3 @@ -from typing import Optional - import pytest from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage @@ -218,10 +216,10 @@ def test_convert_to_v1_from_responses_chunk() -> None: response_metadata={"model_provider": "openai"}, ), ] - for chunk, expected in zip(chunks, expected_chunks): + for chunk, expected in zip(chunks, expected_chunks, strict=False): assert chunk.content_blocks == expected.content_blocks - full: Optional[AIMessageChunk] = None + full: AIMessageChunk | None = None for chunk in chunks: full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) diff --git a/libs/core/tests/unit_tests/messages/test_ai.py b/libs/core/tests/unit_tests/messages/test_ai.py index d0a57a57192..d18a1e9b0a3 100644 --- a/libs/core/tests/unit_tests/messages/test_ai.py +++ b/libs/core/tests/unit_tests/messages/test_ai.py @@ -1,4 +1,4 @@ -from typing import Union, cast +from typing import cast import pytest @@ -357,9 +357,9 @@ def test_content_blocks() -> None: assert chunk.content_blocks == chunk.tool_calls # test v1 content - chunk_1.content = cast("Union[str, list[Union[str, dict]]]", chunk_1.content_blocks) + chunk_1.content = cast("str | list[str | dict]", chunk_1.content_blocks) chunk_1.response_metadata["output_version"] = "v1" - chunk_2.content = cast("Union[str, list[Union[str, dict]]]", chunk_2.content_blocks) + chunk_2.content = cast("str | list[str | dict]", chunk_2.content_blocks) chunk = chunk_1 + chunk_2 + chunk_3 assert chunk.content == [ @@ -378,12 +378,8 @@ def test_content_blocks() -> None: standard_content_2: list[types.ContentBlock] = [ {"type": "non_standard", "index": 0, "value": {"foo": "baz"}} ] - chunk_1 = AIMessageChunk( - content=cast("Union[str, list[Union[str, dict]]]", standard_content_1) - ) - chunk_2 = AIMessageChunk( - content=cast("Union[str, list[Union[str, dict]]]", standard_content_2) - ) + chunk_1 = AIMessageChunk(content=cast("str | list[str | dict]", standard_content_1)) + chunk_2 = AIMessageChunk(content=cast("str | list[str | dict]", standard_content_2)) merged_chunk = chunk_1 + chunk_2 assert merged_chunk.content == [ {"type": "non_standard", "index": 0, "value": {"foo": "bar baz"}}, @@ -470,12 +466,8 @@ def test_content_blocks() -> None: } ] standard_content_2 = [{"type": "non_standard", "value": {"foo": "bar"}, "index": 0}] - chunk_1 = AIMessageChunk( - content=cast("Union[str, list[Union[str, dict]]]", standard_content_1) - ) - chunk_2 = AIMessageChunk( - content=cast("Union[str, list[Union[str, dict]]]", standard_content_2) - ) + chunk_1 = AIMessageChunk(content=cast("str | list[str | dict]", standard_content_1)) + chunk_2 = AIMessageChunk(content=cast("str | list[str | dict]", standard_content_2)) merged_chunk = chunk_1 + chunk_2 assert merged_chunk.content == [ { diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 32151132b3d..6876b23f5b8 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -1,8 +1,8 @@ import base64 import json import re -from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Sequence +from typing import Any import pytest from typing_extensions import override @@ -665,9 +665,7 @@ class FakeTokenCountingModel(FakeChatModel): def get_num_tokens_from_messages( self, messages: list[BaseMessage], - tools: Optional[ - Sequence[Union[dict[str, Any], type, Callable, BaseTool]] - ] = None, + tools: Sequence[dict[str, Any] | type | Callable | BaseTool] | None = None, ) -> int: return dummy_token_counter(messages) diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py index edcf72e4f0f..7af420b326e 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -884,5 +884,5 @@ def test_max_tokens_error(caplog: Any) -> None: _ = parser.invoke(message) assert any( "`max_tokens` stop reason" in msg and record.levelname == "ERROR" - for record, msg in zip(caplog.records, caplog.messages) + for record, msg in zip(caplog.records, caplog.messages, strict=False) ) diff --git a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py index 07acf94d7c2..9bf3bd19489 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -1,7 +1,7 @@ """Test PydanticOutputParser.""" from enum import Enum -from typing import Literal, Optional, Union +from typing import Literal import pydantic import pytest @@ -30,7 +30,7 @@ class ForecastV1(V1BaseModel): @pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1]) def test_pydantic_parser_chaining( - pydantic_object: Union[type[ForecastV2], type[ForecastV1]], + pydantic_object: type[ForecastV2] | type[ForecastV1], ) -> None: prompt = PromptTemplate( template="""{{ @@ -109,9 +109,7 @@ class Actions(Enum): class TestModel(BaseModel): action: Actions = Field(description="Action to be performed") action_input: str = Field(description="Input to be used in the action") - additional_fields: Optional[str] = Field( - description="Additional fields", default=None - ) + additional_fields: str | None = Field(description="Additional fields", default=None) for_new_lines: str = Field(description="To be used to test newlines") diff --git a/libs/core/tests/unit_tests/outputs/test_chat_generation.py b/libs/core/tests/unit_tests/outputs/test_chat_generation.py index c409a76f0de..5cf7b0db9bd 100644 --- a/libs/core/tests/unit_tests/outputs/test_chat_generation.py +++ b/libs/core/tests/unit_tests/outputs/test_chat_generation.py @@ -1,5 +1,3 @@ -from typing import Union - import pytest from langchain_core.messages import AIMessage @@ -19,14 +17,14 @@ from langchain_core.outputs import ChatGeneration ], ], ) -def test_msg_with_text(content: Union[str, list]) -> None: +def test_msg_with_text(content: str | list) -> None: expected = "foo" actual = ChatGeneration(message=AIMessage(content=content)).text assert actual == expected @pytest.mark.parametrize("content", [[], [{"tool_use": {}, "type": "tool_use"}]]) -def test_msg_no_text(content: Union[str, list]) -> None: +def test_msg_no_text(content: str | list) -> None: expected = "" actual = ChatGeneration(message=AIMessage(content=content)).text assert actual == expected diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 1b0926d8743..a5a09c00f01 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -1,7 +1,7 @@ import re import warnings from pathlib import Path -from typing import Any, Union, cast +from typing import Any, cast import pytest from packaging import version @@ -481,7 +481,7 @@ def test_chat_valid_infer_variables() -> None: ], ) def test_convert_to_message( - args: Any, expected: Union[BaseMessage, BaseMessagePromptTemplate] + args: Any, expected: BaseMessage | BaseMessagePromptTemplate ) -> None: """Test convert to message.""" assert _convert_to_message_template(args) == expected diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index 64195302129..0540fbd997f 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -2,7 +2,7 @@ import re from tempfile import NamedTemporaryFile -from typing import Any, Literal, Union +from typing import Any, Literal from unittest import mock import pytest @@ -627,7 +627,7 @@ async def test_prompt_ainvoke_with_metadata() -> None: def test_prompt_falsy_vars( template_format: PromptTemplateFormat, value: Any, - expected: Union[str, dict[str, str]], + expected: str | dict[str, str], ) -> None: # each line is value, f-string, mustache if template_format == "f-string": diff --git a/libs/core/tests/unit_tests/prompts/test_structured.py b/libs/core/tests/unit_tests/prompts/test_structured.py index 9758f1d1d1e..44c6a215ba9 100644 --- a/libs/core/tests/unit_tests/prompts/test_structured.py +++ b/libs/core/tests/unit_tests/prompts/test_structured.py @@ -1,6 +1,6 @@ from functools import partial from inspect import isclass -from typing import Any, Union, cast +from typing import Any, cast import pytest from pydantic import BaseModel @@ -17,8 +17,8 @@ from langchain_core.utils.pydantic import is_basemodel_subclass def _fake_runnable( - _: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any -) -> Union[BaseModel, dict]: + _: Any, *, schema: dict | type[BaseModel], value: Any = 42, **_kwargs: Any +) -> BaseModel | dict: if isclass(schema) and is_basemodel_subclass(schema): return schema(name="yo", value=value) params = cast("dict", schema)["parameters"] @@ -30,7 +30,7 @@ class FakeStructuredChatModel(FakeListChatModel): @override def with_structured_output( - self, schema: Union[dict, type[BaseModel]], **kwargs: Any + self, schema: dict | type[BaseModel], **kwargs: Any ) -> Runnable: return RunnableLambda(partial(_fake_runnable, schema=schema, **kwargs)) diff --git a/libs/core/tests/unit_tests/pydantic_utils.py b/libs/core/tests/unit_tests/pydantic_utils.py index 8e7a9a078d7..2c8036a129d 100644 --- a/libs/core/tests/unit_tests/pydantic_utils.py +++ b/libs/core/tests/unit_tests/pydantic_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any from pydantic import BaseModel @@ -105,7 +105,7 @@ def _remove_additionalproperties_from_untyped_dicts(schema: dict) -> dict[str, A """ def _remove_dict_additional_props( - obj: Union[dict[str, Any], list[Any]], *, inside_properties: bool = False + obj: dict[str, Any] | list[Any], *, inside_properties: bool = False ) -> None: if isinstance(obj, dict): if ( diff --git a/libs/core/tests/unit_tests/runnables/test_configurable.py b/libs/core/tests/unit_tests/runnables/test_configurable.py index 86bf14adf16..f397e249d9a 100644 --- a/libs/core/tests/unit_tests/runnables/test_configurable.py +++ b/libs/core/tests/unit_tests/runnables/test_configurable.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import pytest from pydantic import ConfigDict, Field, model_validator @@ -34,7 +34,7 @@ class MyRunnable(RunnableSerializable[str, str]): @override def invoke( - self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: str, config: RunnableConfig | None = None, **kwargs: Any ) -> Any: return input + self._my_hidden_property @@ -43,14 +43,14 @@ class MyRunnable(RunnableSerializable[str, str]): def my_custom_function_w_config( self, - config: Optional[RunnableConfig] = None, # noqa: ARG002 + config: RunnableConfig | None = None, # noqa: ARG002 ) -> str: return self.my_property def my_custom_function_w_kw_config( self, *, - config: Optional[RunnableConfig] = None, # noqa: ARG002 + config: RunnableConfig | None = None, # noqa: ARG002 ) -> str: return self.my_property @@ -60,7 +60,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]): @override def invoke( - self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: str, config: RunnableConfig | None = None, **kwargs: Any ) -> Any: return input + self.my_other_property diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py index 31efb3babb7..0a6b82540e2 100644 --- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -1,9 +1,6 @@ -from collections.abc import AsyncIterator, Iterator, Sequence +from collections.abc import AsyncIterator, Callable, Iterator, Sequence from typing import ( Any, - Callable, - Optional, - Union, ) import pytest @@ -118,7 +115,7 @@ def _runnable(inputs: dict) -> str: def _assert_potential_error(actual: list, expected: list) -> None: - for x, y in zip(actual, expected): + for x, y in zip(actual, expected, strict=False): if isinstance(x, Exception): assert isinstance(y, type(x)) else: @@ -328,8 +325,8 @@ class FakeStructuredOutputModel(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: """Top Level call.""" @@ -338,15 +335,15 @@ class FakeStructuredOutputModel(BaseChatModel): @override def bind_tools( self, - tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], + tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool], **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessage]: return self.bind(tools=tools) @override def with_structured_output( - self, schema: Union[dict, type[BaseModel]], **kwargs: Any - ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: + self, schema: dict | type[BaseModel], **kwargs: Any + ) -> Runnable[LanguageModelInput, dict | BaseModel]: return RunnableLambda(lambda _: {"foo": self.foo}) @property @@ -361,8 +358,8 @@ class FakeModel(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: """Top Level call.""" @@ -371,7 +368,7 @@ class FakeModel(BaseChatModel): @override def bind_tools( self, - tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], + tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool], **kwargs: Any, ) -> Runnable[LanguageModelInput, AIMessage]: return self.bind(tools=tools) diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 52ac05a397f..44774c716c4 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from unittest.mock import MagicMock, patch from packaging import version @@ -484,7 +484,7 @@ def test_runnable_get_graph_with_invalid_input_type() -> None: def invoke( self, input: int, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> int: return input @@ -509,7 +509,7 @@ def test_runnable_get_graph_with_invalid_output_type() -> None: def invoke( self, input: int, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> int: return input diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index daba3e55749..9ede918fa2d 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -1,6 +1,6 @@ import re -from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Sequence +from typing import Any import pytest from packaging import version @@ -39,7 +39,7 @@ def test_interfaces() -> None: def _get_get_session_history( *, - store: Optional[dict[str, Any]] = None, + store: dict[str, Any] | None = None, ) -> Callable[..., InMemoryChatMessageHistory]: chat_history_store = store if store is not None else {} @@ -262,8 +262,8 @@ class LengthChatModel(BaseChatModel): def _generate( self, messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: """Top Level call.""" @@ -432,7 +432,7 @@ async def test_output_dict_async() -> None: def test_get_input_schema_input_dict() -> None: class RunnableWithChatHistoryInput(BaseModel): - input: Union[str, BaseMessage, Sequence[BaseMessage]] + input: str | BaseMessage | Sequence[BaseMessage] runnable = RunnableLambda( lambda params: { @@ -781,15 +781,15 @@ class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]): def with_listeners( self, *, - on_start: Optional[ - Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - ] = None, - on_end: Optional[ - Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - ] = None, - on_error: Optional[ - Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] - ] = None, + on_start: Callable[[Run], None] + | Callable[[Run, RunnableConfig], None] + | None = None, + on_end: Callable[[Run], None] + | Callable[[Run, RunnableConfig], None] + | None = None, + on_error: Callable[[Run], None] + | Callable[[Run, RunnableConfig], None] + | None = None, ) -> Runnable[Input, Output]: def create_tracer(config: RunnableConfig) -> RunnableConfig: tracer = RootListenersTracer( @@ -811,9 +811,9 @@ class _RunnableLambdaWithRaiseError(RunnableLambda[Input, Output]): def with_alisteners( self, *, - on_start: Optional[AsyncListener] = None, - on_end: Optional[AsyncListener] = None, - on_error: Optional[AsyncListener] = None, + on_start: AsyncListener | None = None, + on_end: AsyncListener | None = None, + on_error: AsyncListener | None = None, ) -> Runnable[Input, Output]: def create_tracer(config: RunnableConfig) -> RunnableConfig: tracer = AsyncRootListenersTracer( diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 93767f22c11..63cc478f404 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -4,10 +4,10 @@ import sys import time import uuid import warnings -from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence +from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence from functools import partial from operator import itemgetter -from typing import Any, Callable, Optional, Union, cast +from typing import Any, cast from uuid import UUID import pytest @@ -170,7 +170,7 @@ class FakeTracer(BaseTracer): return result @property - def run_ids(self) -> list[Optional[uuid.UUID]]: + def run_ids(self) -> list[uuid.UUID | None]: runs = self.flattened_runs() uuids_map = {v: k for k, v in self.uuids_map.items()} return [uuids_map.get(r.id) for r in runs] @@ -181,7 +181,7 @@ class FakeRunnable(Runnable[str, int]): def invoke( self, input: str, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> int: return len(input) @@ -194,7 +194,7 @@ class FakeRunnableSerializable(RunnableSerializable[str, int]): def invoke( self, input: str, - config: Optional[RunnableConfig] = None, + config: RunnableConfig | None = None, **kwargs: Any, ) -> int: return len(input) @@ -3839,7 +3839,7 @@ def test_each(snapshot: SnapshotAssertion) -> None: def test_recursive_lambda() -> None: - def _simple_recursion(x: int) -> Union[int, Runnable]: + def _simple_recursion(x: int) -> int | Runnable: if x < 10: return RunnableLambda(lambda *_: _simple_recursion(x + 1)) return x @@ -3852,7 +3852,7 @@ def test_recursive_lambda() -> None: def test_retrying(mocker: MockerFixture) -> None: - def _lambda(x: int) -> Union[int, Runnable]: + def _lambda(x: int) -> int | Runnable: if x == 1: msg = "x is 1" raise ValueError(msg) @@ -3969,7 +3969,7 @@ async def test_async_retry_batch_preserves_order() -> None: async def test_async_retrying(mocker: MockerFixture) -> None: - def _lambda(x: int) -> Union[int, Runnable]: + def _lambda(x: int) -> int | Runnable: if x == 1: msg = "x is 1" raise ValueError(msg) @@ -4170,7 +4170,7 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None: @override def invoke( - self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Any, config: RunnableConfig | None = None, **kwargs: Any ) -> Any: raise NotImplementedError @@ -4194,7 +4194,7 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None: def batch( self, inputs: list[str], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, **kwargs: Any, @@ -4311,7 +4311,7 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None: @override def invoke( - self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Any, config: RunnableConfig | None = None, **kwargs: Any ) -> Any: raise NotImplementedError @@ -4335,7 +4335,7 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None: async def abatch( self, inputs: list[str], - config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, + config: RunnableConfig | list[RunnableConfig] | None = None, *, return_exceptions: bool = False, **kwargs: Any, @@ -5431,7 +5431,7 @@ async def test_astream_log_deep_copies() -> None: chain = RunnableLambda(add_one) chunks = [] - final_output: Optional[RunLogPatch] = None + final_output: RunLogPatch | None = None async for chunk in chain.astream_log(1): chunks.append(chunk) final_output = chunk if final_output is None else final_output + chunk @@ -5497,7 +5497,7 @@ def test_default_transform_with_dicts() -> None: class CustomRunnable(RunnableSerializable[Input, Output]): @override def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: return cast("Output", input) @@ -5519,7 +5519,7 @@ async def test_default_atransform_with_dicts() -> None: class CustomRunnable(RunnableSerializable[Input, Output]): @override def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: return cast("Output", input) @@ -5631,8 +5631,8 @@ def test_closing_iterator_doesnt_raise_error() -> None: error: BaseException, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[list[str]] = None, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, **kwargs: Any, ) -> None: """Run when chain errors.""" @@ -5645,7 +5645,7 @@ def test_closing_iterator_doesnt_raise_error() -> None: outputs: dict[str, Any], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ) -> None: nonlocal on_chain_end_triggered diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index 2dc16821f2b..8efd5aea9a6 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -62,7 +62,7 @@ async def _collect_events(events: AsyncIterator[StreamEvent]) -> list[StreamEven def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -> None: """Assert that the events are equal.""" assert len(events) == len(expected) - for i, (event, expected_event) in enumerate(zip(events, expected)): + for i, (event, expected_event) in enumerate(zip(events, expected, strict=False)): # we want to allow a superset of metadata on each event_with_edited_metadata = { k: ( diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 3660c5d6923..de6c4030cac 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -4,13 +4,11 @@ import asyncio import inspect import sys import uuid -from collections.abc import AsyncIterator, Iterable, Iterator, Sequence +from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence from functools import partial from itertools import cycle from typing import ( Any, - Callable, - Optional, cast, ) @@ -362,7 +360,7 @@ async def test_event_stream_with_triple_lambda() -> None: async def test_event_stream_exception() -> None: - def step(name: str, err: Optional[str], val: str) -> str: + def step(name: str, err: str | None, val: str) -> str: if err: raise ValueError(err) return val + name[-1] @@ -2135,7 +2133,7 @@ class StreamingRunnable(Runnable[Input, Output]): @override def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any ) -> Output: """Invoke the runnable.""" msg = "Server side error" @@ -2145,8 +2143,8 @@ class StreamingRunnable(Runnable[Input, Output]): def stream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> Iterator[Output]: raise NotImplementedError @@ -2154,8 +2152,8 @@ class StreamingRunnable(Runnable[Input, Output]): async def astream( self, input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], + config: RunnableConfig | None = None, + **kwargs: Any | None, ) -> AsyncIterator[Output]: config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index 3af1a16c0a3..e5638b3e8ed 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -3,9 +3,9 @@ from __future__ import annotations import json import sys import uuid -from collections.abc import AsyncGenerator, Coroutine, Generator +from collections.abc import AsyncGenerator, Callable, Coroutine, Generator from inspect import isasyncgenfunction -from typing import Any, Callable, Literal, Optional +from typing import Any, Literal from unittest.mock import MagicMock, patch import pytest @@ -39,8 +39,8 @@ def _get_posts(client: Client) -> list: def _create_tracer_with_mocked_client( - project_name: Optional[str] = None, - tags: Optional[list[str]] = None, + project_name: str | None = None, + tags: list[str] | None = None, ) -> LangChainTracer: mock_session = MagicMock() mock_client_ = Client( @@ -225,7 +225,7 @@ async def test_config_traceable_async_handoff() -> None: @pytest.mark.parametrize("enabled", [None, True, False]) @pytest.mark.parametrize("env", ["", "true"]) def test_tracing_enable_disable( - mock_get_client: MagicMock, *, enabled: Optional[bool], env: str + mock_get_client: MagicMock, *, enabled: bool | None, env: str ) -> None: mock_session = MagicMock() mock_client_ = Client( diff --git a/libs/core/tests/unit_tests/runnables/test_utils.py b/libs/core/tests/unit_tests/runnables/test_utils.py index 5733901c820..7eec2c52a7a 100644 --- a/libs/core/tests/unit_tests/runnables/test_utils.py +++ b/libs/core/tests/unit_tests/runnables/test_utils.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import pytest diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 5d743ca0133..ff8183e5277 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -1,5 +1,5 @@ import uuid -from typing import Optional, Union, get_args +from typing import get_args import pytest @@ -639,7 +639,7 @@ def test_tool_calls_merge() -> None: {"content": ""}, ] - final: Optional[BaseMessageChunk] = None + final: BaseMessageChunk | None = None for chunk in chunks: msg = AIMessageChunk(**chunk) @@ -1015,9 +1015,7 @@ def test_tool_message_str() -> None: ), ], ) -def test_merge_content( - first: Union[list, str], others: list, expected: Union[list, str] -) -> None: +def test_merge_content(first: list | str, others: list, expected: list | str) -> None: actual = merge_content(first, *others) assert actual == expected diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 90574109204..bf98a980eee 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -5,18 +5,16 @@ import json import sys import textwrap import threading +from collections.abc import Callable from datetime import datetime from enum import Enum from functools import partial from typing import ( Annotated, Any, - Callable, Generic, Literal, - Optional, TypeVar, - Union, cast, ) @@ -104,7 +102,7 @@ class _MockSchema(BaseModel): arg1: int arg2: bool - arg3: Optional[dict] = None + arg3: dict | None = None class _MockSchemaV1(BaseModelV1): @@ -112,7 +110,7 @@ class _MockSchemaV1(BaseModelV1): arg1: int arg2: bool - arg3: Optional[dict] = None + arg3: dict | None = None class _MockStructuredTool(BaseTool): @@ -121,10 +119,10 @@ class _MockStructuredTool(BaseTool): description: str = "A Structured Tool" @override - def _run(self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + def _run(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str: return f"{arg1} {arg2} {arg3}" - async def _arun(self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + async def _arun(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str: raise NotImplementedError @@ -149,13 +147,11 @@ def test_misannotated_base_tool_raises_error() -> None: description: str = "A Structured Tool" @override - def _run( - self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None - ) -> str: + def _run(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str: return f"{arg1} {arg2} {arg3}" async def _arun( - self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None + self, *, arg1: int, arg2: bool, arg3: dict | None = None ) -> str: raise NotImplementedError @@ -169,11 +165,11 @@ def test_forward_ref_annotated_base_tool_accepted() -> None: description: str = "A Structured Tool" @override - def _run(self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + def _run(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str: return f"{arg1} {arg2} {arg3}" async def _arun( - self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None + self, *, arg1: int, arg2: bool, arg3: dict | None = None ) -> str: raise NotImplementedError @@ -187,11 +183,11 @@ def test_subclass_annotated_base_tool_accepted() -> None: description: str = "A Structured Tool" @override - def _run(self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + def _run(self, *, arg1: int, arg2: bool, arg3: dict | None = None) -> str: return f"{arg1} {arg2} {arg3}" async def _arun( - self, *, arg1: int, arg2: bool, arg3: Optional[dict] = None + self, *, arg1: int, arg2: bool, arg3: dict | None = None ) -> str: raise NotImplementedError @@ -204,14 +200,14 @@ def test_decorator_with_specified_schema() -> None: """Test that manually specified schemata are passed through to the tool.""" @tool(args_schema=_MockSchema) - def tool_func(*, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + def tool_func(*, arg1: int, arg2: bool, arg3: dict | None = None) -> str: return f"{arg1} {arg2} {arg3}" assert isinstance(tool_func, BaseTool) assert tool_func.args_schema == _MockSchema @tool(args_schema=cast("ArgsSchema", _MockSchemaV1)) - def tool_func_v1(*, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: + def tool_func_v1(*, arg1: int, arg2: bool, arg3: dict | None = None) -> str: return f"{arg1} {arg2} {arg3}" assert isinstance(tool_func_v1, BaseTool) @@ -223,7 +219,7 @@ def test_decorated_function_schema_equivalent() -> None: @tool def structured_tool_input( - *, arg1: int, arg2: bool, arg3: Optional[dict] = None + *, arg1: int, arg2: bool, arg3: dict | None = None ) -> str: """Return the arguments directly.""" return f"{arg1} {arg2} {arg3}" @@ -246,7 +242,7 @@ def test_args_kwargs_filtered() -> None: def _run( self, some_arg: str, - run_manager: Optional[CallbackManagerForToolRun] = None, + run_manager: CallbackManagerForToolRun | None = None, **kwargs: Any, ) -> str: return "foo" @@ -254,7 +250,7 @@ def test_args_kwargs_filtered() -> None: async def _arun( self, some_arg: str, - run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + run_manager: AsyncCallbackManagerForToolRun | None = None, **kwargs: Any, ) -> str: raise NotImplementedError @@ -270,7 +266,7 @@ def test_args_kwargs_filtered() -> None: def _run( self, *args: Any, - run_manager: Optional[CallbackManagerForToolRun] = None, + run_manager: CallbackManagerForToolRun | None = None, **kwargs: Any, ) -> str: return "foo" @@ -278,7 +274,7 @@ def test_args_kwargs_filtered() -> None: async def _arun( self, *args: Any, - run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + run_manager: AsyncCallbackManagerForToolRun | None = None, **kwargs: Any, ) -> str: raise NotImplementedError @@ -292,7 +288,7 @@ def test_structured_args_decorator_no_infer_schema() -> None: @tool(infer_schema=False) def structured_tool_input( - arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None + arg1: int, arg2: float | datetime, opt_arg: dict | None = None ) -> str: """Return the arguments directly.""" return f"{arg1}, {arg2}, {opt_arg}" @@ -553,7 +549,7 @@ def test_empty_args_decorator() -> None: def test_tool_from_function_with_run_manager() -> None: """Test run of tool when using run_manager.""" - def foo(bar: str, callbacks: Optional[CallbackManagerForToolRun] = None) -> str: # noqa: D417 + def foo(bar: str, callbacks: CallbackManagerForToolRun | None = None) -> str: # noqa: D417 """Docstring. Args: @@ -573,7 +569,7 @@ def test_structured_tool_from_function_with_run_manager() -> None: """Test args and schema of structured tool when using callbacks.""" def foo( # noqa: D417 - bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None + bar: int, baz: str, callbacks: CallbackManagerForToolRun | None = None ) -> str: """Docstring. @@ -883,7 +879,7 @@ def test_validation_error_handling_callable() -> None: """Test that validation errors are handled correctly.""" expected = "foo bar" - def handling(e: Union[ValidationError, ValidationErrorV1]) -> str: + def handling(e: ValidationError | ValidationErrorV1) -> str: return expected tool_ = _MockStructuredTool(handle_validation_error=handling) @@ -901,9 +897,7 @@ def test_validation_error_handling_callable() -> None: ) def test_validation_error_handling_non_validation_error( *, - handler: Union[ - bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str] - ], + handler: bool | str | Callable[[ValidationError | ValidationErrorV1], str], ) -> None: """Test that validation errors are handled correctly.""" @@ -913,9 +907,9 @@ def test_validation_error_handling_non_validation_error( def _parse_input( self, - tool_input: Union[str, dict], - tool_call_id: Optional[str], - ) -> Union[str, dict[str, Any]]: + tool_input: str | dict, + tool_call_id: str | None, + ) -> str | dict[str, Any]: raise NotImplementedError @override @@ -951,7 +945,7 @@ async def test_async_validation_error_handling_callable() -> None: """Test that validation errors are handled correctly.""" expected = "foo bar" - def handling(e: Union[ValidationError, ValidationErrorV1]) -> str: + def handling(e: ValidationError | ValidationErrorV1) -> str: return expected tool_ = _MockStructuredTool(handle_validation_error=handling) @@ -969,9 +963,7 @@ async def test_async_validation_error_handling_callable() -> None: ) async def test_async_validation_error_handling_non_validation_error( *, - handler: Union[ - bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str] - ], + handler: bool | str | Callable[[ValidationError | ValidationErrorV1], str], ) -> None: """Test that validation errors are handled correctly.""" @@ -981,9 +973,9 @@ async def test_async_validation_error_handling_non_validation_error( def _parse_input( self, - tool_input: Union[str, dict], - tool_call_id: Optional[str], - ) -> Union[str, dict[str, Any]]: + tool_input: str | dict, + tool_call_id: str | None, + ) -> str | dict[str, Any]: raise NotImplementedError @override @@ -1001,9 +993,9 @@ async def test_async_validation_error_handling_non_validation_error( def test_optional_subset_model_rewrite() -> None: class MyModel(BaseModel): - a: Optional[str] = None + a: str | None = None b: str - c: Optional[list[Optional[str]]] = None + c: list[str | None] | None = None model2 = _create_subset_model("model2", MyModel, ["a", "b", "c"]) @@ -1028,9 +1020,9 @@ def test_optional_subset_model_rewrite() -> None: ({"bar": "bar", "baz": None}, {"bar": "bar", "baz": None, "buzz": "buzz"}), ], ) -def test_tool_invoke_optional_args(inputs: dict, expected: Optional[dict]) -> None: +def test_tool_invoke_optional_args(inputs: dict, expected: dict | None) -> None: @tool - def foo(bar: str, baz: Optional[int] = 3, buzz: Optional[str] = "buzz") -> dict: + def foo(bar: str, baz: int | None = 3, buzz: str | None = "buzz") -> dict: """The foo.""" return { "bar": bar, @@ -1217,7 +1209,7 @@ def test_tool_arg_descriptions() -> None: # Test parsing with run_manager does not raise error def foo3( # noqa: D417 - bar: str, baz: int, run_manager: Optional[CallbackManagerForToolRun] = None + bar: str, baz: int, run_manager: CallbackManagerForToolRun | None = None ) -> str: """The foo. @@ -1242,7 +1234,7 @@ def test_tool_arg_descriptions() -> None: args_schema = _schema(as_tool.args_schema) assert args_schema["description"] == expected["description"] - def foo5(run_manager: Optional[CallbackManagerForToolRun] = None) -> str: + def foo5(run_manager: CallbackManagerForToolRun | None = None) -> str: """The foo.""" return "bar" @@ -1448,14 +1440,14 @@ class _MockStructuredToolWithRawOutput(BaseTool): self, arg1: int, arg2: bool, - arg3: Optional[dict] = None, + arg3: dict | None = None, ) -> tuple[str, dict]: return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} @tool("structured_api", response_format="content_and_artifact") def _mock_structured_tool_with_artifact( - *, arg1: int, arg2: bool, arg3: Optional[dict] = None + *, arg1: int, arg2: bool, arg3: dict | None = None ) -> tuple[str, dict]: """A Structured Tool.""" return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} @@ -2124,16 +2116,16 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None: actual = get_all_basemodel_annotations(ModelA[int]) assert actual == expected - D = TypeVar("D", bound=Union[str, int]) + D = TypeVar("D", bound=str | int) class ModelD(ModelC, Generic[D]): - d: Optional[D] + d: D | None expected = { "a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict, - "d": Union[str, int, None], + "d": str | int | None, } actual = get_all_basemodel_annotations(ModelD) assert actual == expected @@ -2142,7 +2134,7 @@ def test__get_all_basemodel_annotations_v2(*, use_v1_namespace: bool) -> None: "a": str, "b": Annotated[ModelA[dict[str, Any]], "foo"], "c": dict, - "d": Union[int, None], + "d": int | None, } actual = get_all_basemodel_annotations(ModelD[int]) assert actual == expected @@ -2430,7 +2422,7 @@ def test_tool_mutate_input() -> None: def _run( self, x: str, - run_manager: Optional[CallbackManagerForToolRun] = None, + run_manager: CallbackManagerForToolRun | None = None, ) -> str: return "hi" diff --git a/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py b/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py index 1b243c03816..67d83d210e9 100644 --- a/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py +++ b/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py @@ -35,7 +35,9 @@ class FakeAsyncTracer(AsyncBaseTracer): def _compare_run_with_error(run: Any, expected_run: Any) -> None: if run.child_runs: assert len(expected_run.child_runs) == len(run.child_runs) - for received, expected in zip(run.child_runs, expected_run.child_runs): + for received, expected in zip( + run.child_runs, expected_run.child_runs, strict=False + ): _compare_run_with_error(received, expected) received = run.dict(exclude={"child_runs"}) received_err = received.pop("error") diff --git a/libs/core/tests/unit_tests/tracers/test_base_tracer.py b/libs/core/tests/unit_tests/tracers/test_base_tracer.py index aaa34a662f2..119b2907773 100644 --- a/libs/core/tests/unit_tests/tracers/test_base_tracer.py +++ b/libs/core/tests/unit_tests/tracers/test_base_tracer.py @@ -40,7 +40,9 @@ class FakeTracer(BaseTracer): def _compare_run_with_error(run: Any, expected_run: Any) -> None: if run.child_runs: assert len(expected_run.child_runs) == len(run.child_runs) - for received, expected in zip(run.child_runs, expected_run.child_runs): + for received, expected in zip( + run.child_runs, expected_run.child_runs, strict=False + ): _compare_run_with_error(received, expected) received = run.dict(exclude={"child_runs"}) received_err = received.pop("error") diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 0710c9a5a27..60747910589 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -1,19 +1,16 @@ import typing -from collections.abc import Iterable, Mapping, MutableMapping, Sequence +from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence from typing import Annotated as ExtensionsAnnotated from typing import ( Any, - Callable, Literal, - Optional, - Union, + TypeAlias, ) from typing import TypedDict as TypingTypedDict import pytest from pydantic import BaseModel as BaseModelV2Maybe # pydantic: ignore from pydantic import Field as FieldV2Maybe # pydantic: ignore -from typing_extensions import TypeAlias from typing_extensions import TypedDict as ExtensionsTypedDict try: @@ -510,7 +507,7 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None: class NestedC(BaseModel): baz: bool - def my_function(my_arg: Union[NestedA, NestedB, NestedC]) -> None: + def my_function(my_arg: NestedA | NestedB | NestedC) -> None: """Dummy function.""" expected = { @@ -685,9 +682,9 @@ def test_convert_to_openai_function_no_description_no_params(func: dict) -> None def test_function_optional_param() -> None: @tool def func5( - a: Optional[str], + a: str | None, b: str, - c: Optional[list[Optional[str]]], + c: list[str | None] | None, ) -> None: """A test function.""" @@ -820,12 +817,12 @@ def test__convert_typed_dict_to_openai_function( """ arg1: str - arg2: Union[int, str, bool] - arg3: Optional[list[SubTool]] + arg2: int | str | bool + arg3: list[SubTool] | None arg4: annotated[Literal["bar", "baz"], ..., "this does foo"] # noqa: F722 - arg5: annotated[Optional[float], None] + arg5: annotated[float | None, None] arg6: annotated[ - Optional[Sequence[Mapping[str, tuple[Iterable[Any], SubTool]]]], [] + Sequence[Mapping[str, tuple[Iterable[Any], SubTool]]] | None, [] ] arg7: annotated[list[SubTool], ...] arg8: annotated[tuple[SubTool], ...] @@ -1052,7 +1049,7 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None: def test_convert_union_type() -> None: @tool - def magic_function(value: int | str) -> str: # noqa: ARG001,FA102 + def magic_function(value: int | str) -> str: # noqa: ARG001 """Compute a magic function.""" return "" @@ -1125,7 +1122,7 @@ def test_convert_to_json_schema( def test_convert_to_openai_function_nested_strict_2() -> None: - def my_function(arg1: dict, arg2: Union[dict, None]) -> None: + def my_function(arg1: dict, arg2: dict | None) -> None: """Dummy function.""" expected: dict = { diff --git a/libs/core/tests/unit_tests/utils/test_pydantic.py b/libs/core/tests/unit_tests/utils/test_pydantic.py index 3afeca68c73..0793c752108 100644 --- a/libs/core/tests/unit_tests/utils/test_pydantic.py +++ b/libs/core/tests/unit_tests/utils/test_pydantic.py @@ -1,7 +1,7 @@ """Test for some custom pydantic decorators.""" import warnings -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field from pydantic.v1 import BaseModel as BaseModelV1 @@ -36,7 +36,7 @@ def test_pre_init_decorator() -> None: def test_pre_init_decorator_with_more_defaults() -> None: class Foo(BaseModel): a: int = 1 - b: Optional[int] = None + b: int | None = None c: int = Field(default=2) d: int = Field(default_factory=lambda: 3) diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 1e6a88b6375..e01103688a8 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -1,8 +1,9 @@ import os import re +from collections.abc import Callable from contextlib import AbstractContextManager, nullcontext from copy import deepcopy -from typing import Any, Callable, Optional, Union +from typing import Any from unittest.mock import patch import pytest @@ -34,9 +35,9 @@ from langchain_core.utils.utils import secret_from_env ) def test_check_package_version( package: str, - check_kwargs: dict[str, Optional[str]], + check_kwargs: dict[str, str | None], actual_version: str, - expected: Optional[tuple[type[Exception], str]], + expected: tuple[type[Exception], str] | None, ) -> None: with patch("langchain_core.utils.utils.version", return_value=actual_version): if expected is None: @@ -116,7 +117,7 @@ def test_check_package_version( ], ) def test_merge_dicts( - left: dict, right: dict, expected: Union[dict, AbstractContextManager] + left: dict, right: dict, expected: dict | AbstractContextManager ) -> None: err = expected if isinstance(expected, AbstractContextManager) else nullcontext() @@ -144,7 +145,7 @@ def test_merge_dicts( ) @pytest.mark.xfail(reason="Refactors to make in 0.3") def test_merge_dicts_0_3( - left: dict, right: dict, expected: Union[dict, AbstractContextManager] + left: dict, right: dict, expected: dict | AbstractContextManager ) -> None: err = expected if isinstance(expected, AbstractContextManager) else nullcontext() @@ -168,7 +169,7 @@ def test_merge_dicts_0_3( ], ) def test_guard_import( - module_name: str, pip_name: Optional[str], package: Optional[str], expected: Any + module_name: str, pip_name: str | None, package: str | None, expected: Any ) -> None: if package is None and pip_name is None: ret = guard_import(module_name) @@ -201,8 +202,8 @@ def test_guard_import( ) def test_guard_import_failure( module_name: str, - pip_name: Optional[str], - package: Optional[str], + pip_name: str | None, + package: str | None, expected_pip_name: str, ) -> None: with pytest.raises( @@ -273,7 +274,7 @@ def test_secret_from_env_with_env_variable(monkeypatch: pytest.MonkeyPatch) -> N monkeypatch.setenv("TEST_KEY", "secret_value") # Get the function - get_secret: Callable[[], Optional[SecretStr]] = secret_from_env("TEST_KEY") + get_secret: Callable[[], SecretStr | None] = secret_from_env("TEST_KEY") # Assert that it returns the correct value assert get_secret() == SecretStr("secret_value") @@ -297,7 +298,7 @@ def test_secret_from_env_with_none_default(monkeypatch: pytest.MonkeyPatch) -> N monkeypatch.delenv("TEST_KEY", raising=False) # Get the function with a default value of None - get_secret: Callable[[], Optional[SecretStr]] = secret_from_env( + get_secret: Callable[[], SecretStr | None] = secret_from_env( "TEST_KEY", default=None ) @@ -350,14 +351,14 @@ def test_using_secret_from_env_as_default_factory( assert Foo().secret.get_secret_value() == "secret_value" class Bar(BaseModel): - secret: Optional[SecretStr] = Field( + secret: SecretStr | None = Field( default_factory=secret_from_env("TEST_KEY_2", default=None) ) assert Bar().secret is None class Buzz(BaseModel): - secret: Optional[SecretStr] = Field( + secret: SecretStr | None = Field( default_factory=secret_from_env("TEST_KEY_2", default="hello") ) @@ -365,7 +366,7 @@ def test_using_secret_from_env_as_default_factory( assert Buzz().secret.get_secret_value() == "hello" # type: ignore[union-attr] class OhMy(BaseModel): - secret: Optional[SecretStr] = Field( + secret: SecretStr | None = Field( default_factory=secret_from_env("FOOFOOFOOBAR") ) diff --git a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py index 75fed8fea84..e3904b7ad69 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py +++ b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py @@ -7,7 +7,7 @@ the relevant methods. from __future__ import annotations import uuid -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import pytest from typing_extensions import override @@ -30,8 +30,8 @@ class CustomAddTextsVectorstore(VectorStore): def add_texts( self, texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[list[str]] = None, + metadatas: list[dict] | None = None, + ids: list[str] | None = None, **kwargs: Any, ) -> list[str]: if not isinstance(texts, list): @@ -42,7 +42,7 @@ class CustomAddTextsVectorstore(VectorStore): metadatas_ = metadatas or [{} for _ in texts] - for text, metadata in zip(texts, metadatas_ or []): + for text, metadata in zip(texts, metadatas_ or [], strict=False): next_id = next(ids_iter, None) id_ = next_id or str(uuid.uuid4()) self.store[id_] = Document(page_content=text, metadata=metadata, id=id_) @@ -58,7 +58,7 @@ class CustomAddTextsVectorstore(VectorStore): cls, texts: list[str], embedding: Embeddings, - metadatas: Optional[list[dict]] = None, + metadatas: list[dict] | None = None, **kwargs: Any, ) -> CustomAddTextsVectorstore: vectorstore = CustomAddTextsVectorstore() @@ -82,7 +82,7 @@ class CustomAddDocumentsVectorstore(VectorStore): self, documents: list[Document], *, - ids: Optional[list[str]] = None, + ids: list[str] | None = None, **kwargs: Any, ) -> list[str]: ids_ = [] @@ -104,7 +104,7 @@ class CustomAddDocumentsVectorstore(VectorStore): cls, texts: list[str], embedding: Embeddings, - metadatas: Optional[list[dict]] = None, + metadatas: list[dict] | None = None, **kwargs: Any, ) -> CustomAddDocumentsVectorstore: vectorstore = CustomAddDocumentsVectorstore()