diff --git a/libs/core/langchain_core/agents.py b/libs/core/langchain_core/agents.py index 032d54372ef..3e76591f229 100644 --- a/libs/core/langchain_core/agents.py +++ b/libs/core/langchain_core/agents.py @@ -25,7 +25,7 @@ The schemas for the agents themselves are defined in langchain.agents.agent. from __future__ import annotations import json -from typing import Any, List, Literal, Sequence, Union +from typing import Any, Literal, Sequence, Union from langchain_core.load.serializable import Serializable from langchain_core.messages import ( @@ -71,7 +71,7 @@ class AgentAction(Serializable): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "agent"].""" return ["langchain", "schema", "agent"] @@ -145,7 +145,7 @@ class AgentFinish(Serializable): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "agent"] diff --git a/libs/core/langchain_core/caches.py b/libs/core/langchain_core/caches.py index f42142a577c..236d2e4875e 100644 --- a/libs/core/langchain_core/caches.py +++ b/libs/core/langchain_core/caches.py @@ -23,7 +23,7 @@ Cache directly competes with Memory. See documentation for Pros and Cons. from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence from langchain_core.outputs import Generation from langchain_core.runnables import run_in_executor @@ -157,7 +157,7 @@ class InMemoryCache(BaseCache): Raises: ValueError: If maxsize is less than or equal to 0. """ - self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {} + self._cache: dict[tuple[str, str], RETURN_VAL_TYPE] = {} if maxsize is not None and maxsize <= 0: raise ValueError("maxsize must be greater than 0") self._maxsize = maxsize diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index 819b3d07316..7137ca287df 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union +from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union from uuid import UUID from tenacity import RetryCallState @@ -118,7 +118,7 @@ class ChainManagerMixin: def on_chain_end( self, - outputs: Dict[str, Any], + outputs: dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -222,13 +222,13 @@ class CallbackManagerMixin: def on_llm_start( self, - serialized: Dict[str, Any], - prompts: List[str], + serialized: dict[str, Any], + prompts: list[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run when LLM starts running. @@ -249,13 +249,13 @@ class CallbackManagerMixin: def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run when a chat model starts running. @@ -280,13 +280,13 @@ class CallbackManagerMixin: def on_retriever_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run when the Retriever starts running. @@ -303,13 +303,13 @@ class CallbackManagerMixin: def on_chain_start( self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], + serialized: dict[str, Any], + inputs: dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run when a chain starts running. @@ -326,14 +326,14 @@ class CallbackManagerMixin: def on_tool_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - inputs: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run when the tool starts running. @@ -393,8 +393,8 @@ class RunManagerMixin: data: Any, *, run_id: UUID, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Override to define a handler for a custom event. @@ -470,13 +470,13 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_llm_start( self, - serialized: Dict[str, Any], - prompts: List[str], + serialized: dict[str, Any], + prompts: list[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: """Run when LLM starts running. @@ -497,13 +497,13 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run when a chat model starts running. @@ -533,7 +533,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run on new LLM token. Only available when streaming is enabled. @@ -554,7 +554,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run when LLM ends running. @@ -573,7 +573,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run when LLM errors. @@ -590,13 +590,13 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_chain_start( self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], + serialized: dict[str, Any], + inputs: dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: """Run when a chain starts running. @@ -613,11 +613,11 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_chain_end( self, - outputs: Dict[str, Any], + outputs: dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run when a chain ends running. @@ -636,7 +636,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run when chain errors. @@ -651,14 +651,14 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_tool_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - inputs: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: """Run when the tool starts running. @@ -680,7 +680,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run when the tool ends running. @@ -699,7 +699,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run when tool errors. @@ -718,7 +718,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run on an arbitrary text. @@ -754,7 +754,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run on agent action. @@ -773,7 +773,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run on the agent end. @@ -788,13 +788,13 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_retriever_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: """Run on the retriever start. @@ -815,7 +815,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run on the retriever end. @@ -833,7 +833,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Run on retriever error. @@ -852,8 +852,8 @@ class AsyncCallbackHandler(BaseCallbackHandler): data: Any, *, run_id: UUID, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: """Override to define a handler for a custom event. @@ -880,14 +880,14 @@ class BaseCallbackManager(CallbackManagerMixin): def __init__( self, - handlers: List[BaseCallbackHandler], - inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, + handlers: list[BaseCallbackHandler], + inheritable_handlers: Optional[list[BaseCallbackHandler]] = None, parent_run_id: Optional[UUID] = None, *, - tags: Optional[List[str]] = None, - inheritable_tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + inheritable_tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + inheritable_metadata: Optional[dict[str, Any]] = None, ) -> None: """Initialize callback manager. @@ -901,8 +901,8 @@ class BaseCallbackManager(CallbackManagerMixin): Default is None. metadata (Optional[Dict[str, Any]]): The metadata. Default is None. """ - self.handlers: List[BaseCallbackHandler] = handlers - self.inheritable_handlers: List[BaseCallbackHandler] = ( + self.handlers: list[BaseCallbackHandler] = handlers + self.inheritable_handlers: list[BaseCallbackHandler] = ( inheritable_handlers or [] ) self.parent_run_id: Optional[UUID] = parent_run_id @@ -1002,7 +1002,7 @@ class BaseCallbackManager(CallbackManagerMixin): self.inheritable_handlers.remove(handler) def set_handlers( - self, handlers: List[BaseCallbackHandler], inherit: bool = True + self, handlers: list[BaseCallbackHandler], inherit: bool = True ) -> None: """Set handlers as the only handlers on the callback manager. @@ -1024,7 +1024,7 @@ class BaseCallbackManager(CallbackManagerMixin): """ self.set_handlers([handler], inherit=inherit) - def add_tags(self, tags: List[str], inherit: bool = True) -> None: + def add_tags(self, tags: list[str], inherit: bool = True) -> None: """Add tags to the callback manager. Args: @@ -1038,7 +1038,7 @@ class BaseCallbackManager(CallbackManagerMixin): if inherit: self.inheritable_tags.extend(tags) - def remove_tags(self, tags: List[str]) -> None: + def remove_tags(self, tags: list[str]) -> None: """Remove tags from the callback manager. Args: @@ -1048,7 +1048,7 @@ class BaseCallbackManager(CallbackManagerMixin): self.tags.remove(tag) self.inheritable_tags.remove(tag) - def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None: + def add_metadata(self, metadata: dict[str, Any], inherit: bool = True) -> None: """Add metadata to the callback manager. Args: @@ -1059,7 +1059,7 @@ class BaseCallbackManager(CallbackManagerMixin): if inherit: self.inheritable_metadata.update(metadata) - def remove_metadata(self, keys: List[str]) -> None: + def remove_metadata(self, keys: list[str]) -> None: """Remove metadata from the callback manager. Args: diff --git a/libs/core/langchain_core/callbacks/file.py b/libs/core/langchain_core/callbacks/file.py index c33bd4a441c..f3f695ccfa7 100644 --- a/libs/core/langchain_core/callbacks/file.py +++ b/libs/core/langchain_core/callbacks/file.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, Optional, TextIO, cast +from typing import Any, Optional, TextIO, cast from langchain_core.agents import AgentAction, AgentFinish from langchain_core.callbacks import BaseCallbackHandler @@ -35,7 +35,7 @@ class FileCallbackHandler(BaseCallbackHandler): self.file.close() def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any ) -> None: """Print out that we are entering a chain. @@ -51,7 +51,7 @@ class FileCallbackHandler(BaseCallbackHandler): file=self.file, ) - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: """Print out that we finished a chain. Args: diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index ce38bc53b82..6155987dcd3 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -14,9 +14,7 @@ from typing import ( AsyncGenerator, Callable, Coroutine, - Dict, Generator, - List, Optional, Sequence, Type, @@ -64,12 +62,12 @@ def trace_as_chain_group( group_name: str, callback_manager: Optional[CallbackManager] = None, *, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, project_name: Optional[str] = None, example_id: Optional[Union[str, UUID]] = None, run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, ) -> Generator[CallbackManagerForChainGroup, None, None]: """Get a callback manager for a chain group in a context manager. Useful for grouping different calls together as a single run even if @@ -144,12 +142,12 @@ async def atrace_as_chain_group( group_name: str, callback_manager: Optional[AsyncCallbackManager] = None, *, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, project_name: Optional[str] = None, example_id: Optional[Union[str, UUID]] = None, run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]: """Get an async callback manager for a chain group in a context manager. Useful for grouping different async calls together as a single run even if @@ -240,7 +238,7 @@ def shielded(func: Func) -> Func: def handle_event( - handlers: List[BaseCallbackHandler], + handlers: list[BaseCallbackHandler], event_name: str, ignore_condition_name: Optional[str], *args: Any, @@ -258,10 +256,10 @@ def handle_event( *args: The arguments to pass to the event handler. **kwargs: The keyword arguments to pass to the event handler """ - coros: List[Coroutine[Any, Any, Any]] = [] + coros: list[Coroutine[Any, Any, Any]] = [] try: - message_strings: Optional[List[str]] = None + message_strings: Optional[list[str]] = None for handler in handlers: try: if ignore_condition_name is None or not getattr( @@ -318,7 +316,7 @@ def handle_event( _run_coros(coros) -def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None: +def _run_coros(coros: list[Coroutine[Any, Any, Any]]) -> None: if hasattr(asyncio, "Runner"): # Python 3.11+ # Run the coroutines in a new event loop, taking care to @@ -399,7 +397,7 @@ async def _ahandle_event_for_handler( async def ahandle_event( - handlers: List[BaseCallbackHandler], + handlers: list[BaseCallbackHandler], event_name: str, ignore_condition_name: Optional[str], *args: Any, @@ -446,13 +444,13 @@ class BaseRunManager(RunManagerMixin): self, *, run_id: UUID, - handlers: List[BaseCallbackHandler], - inheritable_handlers: List[BaseCallbackHandler], + handlers: list[BaseCallbackHandler], + inheritable_handlers: list[BaseCallbackHandler], parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - inheritable_tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + inheritable_tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + inheritable_metadata: Optional[dict[str, Any]] = None, ) -> None: """Initialize the run manager. @@ -481,7 +479,7 @@ class BaseRunManager(RunManagerMixin): self.inheritable_metadata = inheritable_metadata or {} @classmethod - def get_noop_manager(cls: Type[BRM]) -> BRM: + def get_noop_manager(cls: type[BRM]) -> BRM: """Return a manager that doesn't perform any operations. Returns: @@ -824,7 +822,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): """Callback manager for chain run.""" - def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: + def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None: """Run when chain ends running. Args: @@ -929,7 +927,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): @shielded async def on_chain_end( - self, outputs: Union[Dict[str, Any], Any], **kwargs: Any + self, outputs: Union[dict[str, Any], Any], **kwargs: Any ) -> None: """Run when a chain ends running. @@ -1248,11 +1246,11 @@ class CallbackManager(BaseCallbackManager): def on_llm_start( self, - serialized: Dict[str, Any], - prompts: List[str], + serialized: dict[str, Any], + prompts: list[str], run_id: Optional[UUID] = None, **kwargs: Any, - ) -> List[CallbackManagerForLLMRun]: + ) -> list[CallbackManagerForLLMRun]: """Run when LLM starts running. Args: @@ -1299,11 +1297,11 @@ class CallbackManager(BaseCallbackManager): def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], run_id: Optional[UUID] = None, **kwargs: Any, - ) -> List[CallbackManagerForLLMRun]: + ) -> list[CallbackManagerForLLMRun]: """Run when LLM starts running. Args: @@ -1354,8 +1352,8 @@ class CallbackManager(BaseCallbackManager): def on_chain_start( self, - serialized: Optional[Dict[str, Any]], - inputs: Union[Dict[str, Any], Any], + serialized: Optional[dict[str, Any]], + inputs: Union[dict[str, Any], Any], run_id: Optional[UUID] = None, **kwargs: Any, ) -> CallbackManagerForChainRun: @@ -1398,11 +1396,11 @@ class CallbackManager(BaseCallbackManager): def on_tool_start( self, - serialized: Optional[Dict[str, Any]], + serialized: Optional[dict[str, Any]], input_str: str, run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> CallbackManagerForToolRun: """Run when tool starts running. @@ -1453,7 +1451,7 @@ class CallbackManager(BaseCallbackManager): def on_retriever_start( self, - serialized: Optional[Dict[str, Any]], + serialized: Optional[dict[str, Any]], query: str, run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None, @@ -1541,10 +1539,10 @@ class CallbackManager(BaseCallbackManager): inheritable_callbacks: Callbacks = None, local_callbacks: Callbacks = None, verbose: bool = False, - inheritable_tags: Optional[List[str]] = None, - local_tags: Optional[List[str]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - local_metadata: Optional[Dict[str, Any]] = None, + inheritable_tags: Optional[list[str]] = None, + local_tags: Optional[list[str]] = None, + inheritable_metadata: Optional[dict[str, Any]] = None, + local_metadata: Optional[dict[str, Any]] = None, ) -> CallbackManager: """Configure the callback manager. @@ -1583,8 +1581,8 @@ class CallbackManagerForChainGroup(CallbackManager): def __init__( self, - handlers: List[BaseCallbackHandler], - inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, + handlers: list[BaseCallbackHandler], + inheritable_handlers: Optional[list[BaseCallbackHandler]] = None, parent_run_id: Optional[UUID] = None, *, parent_run_manager: CallbackManagerForChainRun, @@ -1681,7 +1679,7 @@ class CallbackManagerForChainGroup(CallbackManager): manager.add_handler(handler, inherit=True) return manager - def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: + def on_chain_end(self, outputs: Union[dict[str, Any], Any], **kwargs: Any) -> None: """Run when traced chain group ends. Args: @@ -1716,11 +1714,11 @@ class AsyncCallbackManager(BaseCallbackManager): async def on_llm_start( self, - serialized: Dict[str, Any], - prompts: List[str], + serialized: dict[str, Any], + prompts: list[str], run_id: Optional[UUID] = None, **kwargs: Any, - ) -> List[AsyncCallbackManagerForLLMRun]: + ) -> list[AsyncCallbackManagerForLLMRun]: """Run when LLM starts running. Args: @@ -1779,11 +1777,11 @@ class AsyncCallbackManager(BaseCallbackManager): async def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], run_id: Optional[UUID] = None, **kwargs: Any, - ) -> List[AsyncCallbackManagerForLLMRun]: + ) -> list[AsyncCallbackManagerForLLMRun]: """Async run when LLM starts running. Args: @@ -1840,8 +1838,8 @@ class AsyncCallbackManager(BaseCallbackManager): async def on_chain_start( self, - serialized: Optional[Dict[str, Any]], - inputs: Union[Dict[str, Any], Any], + serialized: Optional[dict[str, Any]], + inputs: Union[dict[str, Any], Any], run_id: Optional[UUID] = None, **kwargs: Any, ) -> AsyncCallbackManagerForChainRun: @@ -1886,7 +1884,7 @@ class AsyncCallbackManager(BaseCallbackManager): async def on_tool_start( self, - serialized: Optional[Dict[str, Any]], + serialized: Optional[dict[str, Any]], input_str: str, run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None, @@ -1975,7 +1973,7 @@ class AsyncCallbackManager(BaseCallbackManager): async def on_retriever_start( self, - serialized: Optional[Dict[str, Any]], + serialized: Optional[dict[str, Any]], query: str, run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None, @@ -2027,10 +2025,10 @@ class AsyncCallbackManager(BaseCallbackManager): inheritable_callbacks: Callbacks = None, local_callbacks: Callbacks = None, verbose: bool = False, - inheritable_tags: Optional[List[str]] = None, - local_tags: Optional[List[str]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - local_metadata: Optional[Dict[str, Any]] = None, + inheritable_tags: Optional[list[str]] = None, + local_tags: Optional[list[str]] = None, + inheritable_metadata: Optional[dict[str, Any]] = None, + local_metadata: Optional[dict[str, Any]] = None, ) -> AsyncCallbackManager: """Configure the async callback manager. @@ -2069,8 +2067,8 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): def __init__( self, - handlers: List[BaseCallbackHandler], - inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, + handlers: list[BaseCallbackHandler], + inheritable_handlers: Optional[list[BaseCallbackHandler]] = None, parent_run_id: Optional[UUID] = None, *, parent_run_manager: AsyncCallbackManagerForChainRun, @@ -2169,7 +2167,7 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): return manager async def on_chain_end( - self, outputs: Union[Dict[str, Any], Any], **kwargs: Any + self, outputs: Union[dict[str, Any], Any], **kwargs: Any ) -> None: """Run when traced chain group ends. @@ -2202,14 +2200,14 @@ H = TypeVar("H", bound=BaseCallbackHandler, covariant=True) def _configure( - callback_manager_cls: Type[T], + callback_manager_cls: type[T], inheritable_callbacks: Callbacks = None, local_callbacks: Callbacks = None, verbose: bool = False, - inheritable_tags: Optional[List[str]] = None, - local_tags: Optional[List[str]] = None, - inheritable_metadata: Optional[Dict[str, Any]] = None, - local_metadata: Optional[Dict[str, Any]] = None, + inheritable_tags: Optional[list[str]] = None, + local_tags: Optional[list[str]] = None, + inheritable_metadata: Optional[dict[str, Any]] = None, + local_metadata: Optional[dict[str, Any]] = None, ) -> T: """Configure the callback manager. diff --git a/libs/core/langchain_core/callbacks/stdout.py b/libs/core/langchain_core/callbacks/stdout.py index 011dc83fcb9..bcdb5317ff9 100644 --- a/libs/core/langchain_core/callbacks/stdout.py +++ b/libs/core/langchain_core/callbacks/stdout.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Optional from langchain_core.callbacks.base import BaseCallbackHandler from langchain_core.utils import print_text @@ -23,7 +23,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): self.color = color def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any ) -> None: """Print out that we are entering a chain. @@ -35,7 +35,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): class_name = serialized.get("name", serialized.get("id", [""])[-1]) print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201 - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: """Print out that we finished a chain. Args: diff --git a/libs/core/langchain_core/callbacks/streaming_stdout.py b/libs/core/langchain_core/callbacks/streaming_stdout.py index 06973035a90..4ae0bcb419a 100644 --- a/libs/core/langchain_core/callbacks/streaming_stdout.py +++ b/libs/core/langchain_core/callbacks/streaming_stdout.py @@ -3,7 +3,7 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Any from langchain_core.callbacks.base import BaseCallbackHandler @@ -17,7 +17,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): """Callback handler for streaming. Only works with LLMs that support streaming.""" def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any ) -> None: """Run when LLM starts running. @@ -29,8 +29,8 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], **kwargs: Any, ) -> None: """Run when LLM starts running. @@ -68,7 +68,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): """ def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any ) -> None: """Run when a chain starts running. @@ -78,7 +78,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): **kwargs (Any): Additional keyword arguments. """ - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: """Run when a chain ends running. Args: @@ -95,7 +95,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): """ def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + self, serialized: dict[str, Any], input_str: str, **kwargs: Any ) -> None: """Run when the tool starts running. diff --git a/libs/core/langchain_core/chat_history.py b/libs/core/langchain_core/chat_history.py index dd5a8afaa88..77138f3d8fb 100644 --- a/libs/core/langchain_core/chat_history.py +++ b/libs/core/langchain_core/chat_history.py @@ -18,7 +18,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Sequence, Union +from typing import Sequence, Union from pydantic import BaseModel, Field @@ -87,7 +87,7 @@ class BaseChatMessageHistory(ABC): f.write("[]") """ - messages: List[BaseMessage] + messages: list[BaseMessage] """A property or attribute that returns a list of messages. In general, getting the messages may involve IO to the underlying @@ -95,7 +95,7 @@ class BaseChatMessageHistory(ABC): latency. """ - async def aget_messages(self) -> List[BaseMessage]: + async def aget_messages(self) -> list[BaseMessage]: """Async version of getting messages. Can over-ride this method to provide an efficient async implementation. @@ -204,10 +204,10 @@ class InMemoryChatMessageHistory(BaseChatMessageHistory, BaseModel): Stores messages in a memory list. """ - messages: List[BaseMessage] = Field(default_factory=list) + messages: list[BaseMessage] = Field(default_factory=list) """A list of messages stored in memory.""" - async def aget_messages(self) -> List[BaseMessage]: + async def aget_messages(self) -> list[BaseMessage]: """Async version of getting messages. Can over-ride this method to provide an efficient async implementation. diff --git a/libs/core/langchain_core/document_loaders/base.py b/libs/core/langchain_core/document_loaders/base.py index c793285159e..87955ee5840 100644 --- a/libs/core/langchain_core/document_loaders/base.py +++ b/libs/core/langchain_core/document_loaders/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional +from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional from langchain_core.documents import Document from langchain_core.runnables import run_in_executor @@ -25,17 +25,17 @@ class BaseLoader(ABC): # noqa: B024 # Sub-classes should not implement this method directly. Instead, they # should implement the lazy load method. - def load(self) -> List[Document]: + def load(self) -> list[Document]: """Load data into Document objects.""" return list(self.lazy_load()) - async def aload(self) -> List[Document]: + async def aload(self) -> list[Document]: """Load data into Document objects.""" return [document async for document in self.alazy_load()] def load_and_split( self, text_splitter: Optional[TextSplitter] = None - ) -> List[Document]: + ) -> list[Document]: """Load Documents and split into chunks. Chunks are returned as Documents. Do not override this method. It should be considered to be deprecated! @@ -108,7 +108,7 @@ class BaseBlobParser(ABC): Generator of documents """ - def parse(self, blob: Blob) -> List[Document]: + def parse(self, blob: Blob) -> list[Document]: """Eagerly parse the blob into a document or documents. This is a convenience method for interactive development environment. diff --git a/libs/core/langchain_core/documents/base.py b/libs/core/langchain_core/documents/base.py index 1586322b1b1..0cb357d8bb4 100644 --- a/libs/core/langchain_core/documents/base.py +++ b/libs/core/langchain_core/documents/base.py @@ -4,7 +4,7 @@ import contextlib import mimetypes from io import BufferedReader, BytesIO from pathlib import PurePath -from typing import Any, Dict, Generator, List, Literal, Optional, Union, cast +from typing import Any, Generator, Literal, Optional, Union, cast from pydantic import ConfigDict, Field, field_validator, model_validator @@ -138,7 +138,7 @@ class Blob(BaseMedia): @model_validator(mode="before") @classmethod - def check_blob_is_valid(cls, values: Dict[str, Any]) -> Any: + def check_blob_is_valid(cls, values: dict[str, Any]) -> Any: """Verify that either data or path is provided.""" if "data" not in values and "path" not in values: raise ValueError("Either data or path must be provided") @@ -285,7 +285,7 @@ class Document(BaseMedia): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "document"] diff --git a/libs/core/langchain_core/example_selectors/semantic_similarity.py b/libs/core/langchain_core/example_selectors/semantic_similarity.py index 77c7186088d..b27122ec36d 100644 --- a/libs/core/langchain_core/example_selectors/semantic_similarity.py +++ b/libs/core/langchain_core/example_selectors/semantic_similarity.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Optional from pydantic import BaseModel, ConfigDict @@ -15,7 +15,7 @@ if TYPE_CHECKING: from langchain_core.embeddings import Embeddings -def sorted_values(values: Dict[str, str]) -> List[Any]: +def sorted_values(values: dict[str, str]) -> list[Any]: """Return a list of values in dict sorted by key. Args: @@ -35,12 +35,12 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC): """VectorStore that contains information about examples.""" k: int = 4 """Number of examples to select.""" - example_keys: Optional[List[str]] = None + example_keys: Optional[list[str]] = None """Optional keys to filter examples to.""" - input_keys: Optional[List[str]] = None + input_keys: Optional[list[str]] = None """Optional keys to filter input to. If provided, the search is based on the input variables instead of all variables.""" - vectorstore_kwargs: Optional[Dict[str, Any]] = None + vectorstore_kwargs: Optional[dict[str, Any]] = None """Extra arguments passed to similarity_search function of the vectorstore.""" model_config = ConfigDict( @@ -50,14 +50,14 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC): @staticmethod def _example_to_text( - example: Dict[str, str], input_keys: Optional[List[str]] + example: dict[str, str], input_keys: Optional[list[str]] ) -> str: if input_keys: return " ".join(sorted_values({key: example[key] for key in input_keys})) else: return " ".join(sorted_values(example)) - def _documents_to_examples(self, documents: List[Document]) -> List[dict]: + def _documents_to_examples(self, documents: list[Document]) -> list[dict]: # Get the examples from the metadata. # This assumes that examples are stored in metadata. examples = [dict(e.metadata) for e in documents] @@ -66,7 +66,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC): examples = [{k: eg[k] for k in self.example_keys} for eg in examples] return examples - def add_example(self, example: Dict[str, str]) -> str: + def add_example(self, example: dict[str, str]) -> str: """Add a new example to vectorstore. Args: @@ -81,7 +81,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC): ) return ids[0] - async def aadd_example(self, example: Dict[str, str]) -> str: + async def aadd_example(self, example: dict[str, str]) -> str: """Async add new example to vectorstore. Args: @@ -100,7 +100,7 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC): class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector): """Select examples based on semantic similarity.""" - def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + def select_examples(self, input_variables: dict[str, str]) -> list[dict]: """Select examples based on semantic similarity. Args: @@ -118,7 +118,7 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector): ) return self._documents_to_examples(example_docs) - async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]: + async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]: """Asynchronously select examples based on semantic similarity. Args: @@ -139,13 +139,13 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector): @classmethod def from_examples( cls, - examples: List[dict], + examples: list[dict], embeddings: Embeddings, - vectorstore_cls: Type[VectorStore], + vectorstore_cls: type[VectorStore], k: int = 4, - input_keys: Optional[List[str]] = None, + input_keys: Optional[list[str]] = None, *, - example_keys: Optional[List[str]] = None, + example_keys: Optional[list[str]] = None, vectorstore_kwargs: Optional[dict] = None, **vectorstore_cls_kwargs: Any, ) -> SemanticSimilarityExampleSelector: @@ -183,13 +183,13 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector): @classmethod async def afrom_examples( cls, - examples: List[dict], + examples: list[dict], embeddings: Embeddings, - vectorstore_cls: Type[VectorStore], + vectorstore_cls: type[VectorStore], k: int = 4, - input_keys: Optional[List[str]] = None, + input_keys: Optional[list[str]] = None, *, - example_keys: Optional[List[str]] = None, + example_keys: Optional[list[str]] = None, vectorstore_kwargs: Optional[dict] = None, **vectorstore_cls_kwargs: Any, ) -> SemanticSimilarityExampleSelector: @@ -235,7 +235,7 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector): fetch_k: int = 20 """Number of examples to fetch to rerank.""" - def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + def select_examples(self, input_variables: dict[str, str]) -> list[dict]: """Select examples based on Max Marginal Relevance. Args: @@ -251,7 +251,7 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector): ) return self._documents_to_examples(example_docs) - async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]: + async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]: """Asynchronously select examples based on Max Marginal Relevance. Args: @@ -270,13 +270,13 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector): @classmethod def from_examples( cls, - examples: List[dict], + examples: list[dict], embeddings: Embeddings, - vectorstore_cls: Type[VectorStore], + vectorstore_cls: type[VectorStore], k: int = 4, - input_keys: Optional[List[str]] = None, + input_keys: Optional[list[str]] = None, fetch_k: int = 20, - example_keys: Optional[List[str]] = None, + example_keys: Optional[list[str]] = None, vectorstore_kwargs: Optional[dict] = None, **vectorstore_cls_kwargs: Any, ) -> MaxMarginalRelevanceExampleSelector: @@ -317,14 +317,14 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector): @classmethod async def afrom_examples( cls, - examples: List[dict], + examples: list[dict], embeddings: Embeddings, - vectorstore_cls: Type[VectorStore], + vectorstore_cls: type[VectorStore], *, k: int = 4, - input_keys: Optional[List[str]] = None, + input_keys: Optional[list[str]] = None, fetch_k: int = 20, - example_keys: Optional[List[str]] = None, + example_keys: Optional[list[str]] = None, vectorstore_kwargs: Optional[dict] = None, **vectorstore_cls_kwargs: Any, ) -> MaxMarginalRelevanceExampleSelector: diff --git a/libs/core/langchain_core/graph_vectorstores/base.py b/libs/core/langchain_core/graph_vectorstores/base.py index 4110be32e88..d3e7da93ffb 100644 --- a/libs/core/langchain_core/graph_vectorstores/base.py +++ b/libs/core/langchain_core/graph_vectorstores/base.py @@ -8,7 +8,6 @@ from typing import ( Collection, Iterable, Iterator, - List, Optional, ) @@ -68,7 +67,7 @@ class Node(Serializable): """Text contained by the node.""" metadata: dict = Field(default_factory=dict) """Metadata for the node.""" - links: List[Link] = Field(default_factory=list) + links: list[Link] = Field(default_factory=list) """Links associated with the node.""" @@ -189,7 +188,7 @@ class GraphVectorStore(VectorStore): *, ids: Optional[Iterable[str]] = None, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Run more texts through the embeddings and add to the vectorstore. The Links present in the metadata field `links` will be extracted to create @@ -237,7 +236,7 @@ class GraphVectorStore(VectorStore): *, ids: Optional[Iterable[str]] = None, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Run more texts through the embeddings and add to the vectorstore. The Links present in the metadata field `links` will be extracted to create @@ -282,7 +281,7 @@ class GraphVectorStore(VectorStore): self, documents: Iterable[Document], **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Run more documents through the embeddings and add to the vectorstore. The Links present in the document metadata field `links` will be extracted to @@ -332,7 +331,7 @@ class GraphVectorStore(VectorStore): self, documents: Iterable[Document], **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Run more documents through the embeddings and add to the vectorstore. The Links present in the document metadata field `links` will be extracted to @@ -535,7 +534,7 @@ class GraphVectorStore(VectorStore): def similarity_search( self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: return list(self.traversal_search(query, k=k, depth=0)) def max_marginal_relevance_search( @@ -545,7 +544,7 @@ class GraphVectorStore(VectorStore): fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: return list( self.mmr_traversal_search( query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0 @@ -554,10 +553,10 @@ class GraphVectorStore(VectorStore): async def asimilarity_search( self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: return [doc async for doc in self.atraversal_search(query, k=k, depth=0)] - def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]: + def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]: if search_type == "similarity": return self.similarity_search(query, **kwargs) elif search_type == "similarity_score_threshold": @@ -580,7 +579,7 @@ class GraphVectorStore(VectorStore): async def asearch( self, query: str, search_type: str, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: if search_type == "similarity": return await self.asimilarity_search(query, **kwargs) elif search_type == "similarity_score_threshold": @@ -679,7 +678,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever): def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: + ) -> list[Document]: if self.search_type == "traversal": return list(self.vectorstore.traversal_search(query, **self.search_kwargs)) elif self.search_type == "mmr_traversal": @@ -691,7 +690,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever): async def _aget_relevant_documents( self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: + ) -> list[Document]: if self.search_type == "traversal": return [ doc diff --git a/libs/core/langchain_core/indexing/api.py b/libs/core/langchain_core/indexing/api.py index 5f0fb5bba46..7a9e612c783 100644 --- a/libs/core/langchain_core/indexing/api.py +++ b/libs/core/langchain_core/indexing/api.py @@ -11,14 +11,11 @@ from typing import ( AsyncIterable, AsyncIterator, Callable, - Dict, Iterable, Iterator, - List, Literal, Optional, Sequence, - Set, TypedDict, TypeVar, Union, @@ -71,7 +68,7 @@ class _HashedDocument(Document): @model_validator(mode="before") @classmethod - def calculate_hashes(cls, values: Dict[str, Any]) -> Any: + def calculate_hashes(cls, values: dict[str, Any]) -> Any: """Root validator to calculate content and metadata hash.""" content = values.get("page_content", "") metadata = values.get("metadata", {}) @@ -125,7 +122,7 @@ class _HashedDocument(Document): ) -def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]: +def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]: """Utility batching function.""" it = iter(iterable) while True: @@ -135,9 +132,9 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]: yield chunk -async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[List[T]]: +async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T]]: """Utility batching function.""" - batch: List[T] = [] + batch: list[T] = [] async for element in iterable: if len(batch) < size: batch.append(element) @@ -171,7 +168,7 @@ def _deduplicate_in_order( hashed_documents: Iterable[_HashedDocument], ) -> Iterator[_HashedDocument]: """Deduplicate a list of hashed documents while preserving order.""" - seen: Set[str] = set() + seen: set[str] = set() for hashed_doc in hashed_documents: if hashed_doc.hash_ not in seen: @@ -349,7 +346,7 @@ def index( uids = [] docs_to_index = [] uids_to_refresh = [] - seen_docs: Set[str] = set() + seen_docs: set[str] = set() for hashed_doc, doc_exists in zip(hashed_docs, exists_batch): if doc_exists: if force_update: @@ -589,7 +586,7 @@ async def aindex( uids: list[str] = [] docs_to_index: list[Document] = [] uids_to_refresh = [] - seen_docs: Set[str] = set() + seen_docs: set[str] = set() for hashed_doc, doc_exists in zip(hashed_docs, exists_batch): if doc_exists: if force_update: diff --git a/libs/core/langchain_core/indexing/base.py b/libs/core/langchain_core/indexing/base.py index 24683a5f8eb..de4445783b1 100644 --- a/libs/core/langchain_core/indexing/base.py +++ b/libs/core/langchain_core/indexing/base.py @@ -3,7 +3,7 @@ from __future__ import annotations import abc import time from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Sequence, TypedDict +from typing import Any, Optional, Sequence, TypedDict from langchain_core._api import beta from langchain_core.documents import Document @@ -144,7 +144,7 @@ class RecordManager(ABC): """ @abstractmethod - def exists(self, keys: Sequence[str]) -> List[bool]: + def exists(self, keys: Sequence[str]) -> list[bool]: """Check if the provided keys exist in the database. Args: @@ -155,7 +155,7 @@ class RecordManager(ABC): """ @abstractmethod - async def aexists(self, keys: Sequence[str]) -> List[bool]: + async def aexists(self, keys: Sequence[str]) -> list[bool]: """Asynchronously check if the provided keys exist in the database. Args: @@ -173,7 +173,7 @@ class RecordManager(ABC): after: Optional[float] = None, group_ids: Optional[Sequence[str]] = None, limit: Optional[int] = None, - ) -> List[str]: + ) -> list[str]: """List records in the database based on the provided filters. Args: @@ -194,7 +194,7 @@ class RecordManager(ABC): after: Optional[float] = None, group_ids: Optional[Sequence[str]] = None, limit: Optional[int] = None, - ) -> List[str]: + ) -> list[str]: """Asynchronously list records in the database based on the provided filters. Args: @@ -241,7 +241,7 @@ class InMemoryRecordManager(RecordManager): super().__init__(namespace) # Each key points to a dictionary # of {'group_id': group_id, 'updated_at': timestamp} - self.records: Dict[str, _Record] = {} + self.records: dict[str, _Record] = {} self.namespace = namespace def create_schema(self) -> None: @@ -325,7 +325,7 @@ class InMemoryRecordManager(RecordManager): """ self.update(keys, group_ids=group_ids, time_at_least=time_at_least) - def exists(self, keys: Sequence[str]) -> List[bool]: + def exists(self, keys: Sequence[str]) -> list[bool]: """Check if the provided keys exist in the database. Args: @@ -336,7 +336,7 @@ class InMemoryRecordManager(RecordManager): """ return [key in self.records for key in keys] - async def aexists(self, keys: Sequence[str]) -> List[bool]: + async def aexists(self, keys: Sequence[str]) -> list[bool]: """Async check if the provided keys exist in the database. Args: @@ -354,7 +354,7 @@ class InMemoryRecordManager(RecordManager): after: Optional[float] = None, group_ids: Optional[Sequence[str]] = None, limit: Optional[int] = None, - ) -> List[str]: + ) -> list[str]: """List records in the database based on the provided filters. Args: @@ -390,7 +390,7 @@ class InMemoryRecordManager(RecordManager): after: Optional[float] = None, group_ids: Optional[Sequence[str]] = None, limit: Optional[int] = None, - ) -> List[str]: + ) -> list[str]: """Async list records in the database based on the provided filters. Args: @@ -449,9 +449,9 @@ class UpsertResponse(TypedDict): indexed to avoid this issue. """ - succeeded: List[str] + succeeded: list[str] """The IDs that were successfully indexed.""" - failed: List[str] + failed: list[str] """The IDs that failed to index.""" @@ -562,7 +562,7 @@ class DocumentIndex(BaseRetriever): ) @abc.abstractmethod - def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> DeleteResponse: + def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> DeleteResponse: """Delete by IDs or other criteria. Calling delete without any input parameters should raise a ValueError! @@ -579,7 +579,7 @@ class DocumentIndex(BaseRetriever): """ async def adelete( - self, ids: Optional[List[str]] = None, **kwargs: Any + self, ids: Optional[list[str]] = None, **kwargs: Any ) -> DeleteResponse: """Delete by IDs or other criteria. Async variant. @@ -607,7 +607,7 @@ class DocumentIndex(BaseRetriever): ids: Sequence[str], /, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: """Get documents by id. Fewer documents may be returned than requested if some IDs are not found or @@ -633,7 +633,7 @@ class DocumentIndex(BaseRetriever): ids: Sequence[str], /, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: """Get documents by id. Fewer documents may be returned than requested if some IDs are not found or diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index 0705f9e9f58..b21e54189ba 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -6,14 +6,11 @@ from typing import ( TYPE_CHECKING, Any, Callable, - Dict, List, Literal, Mapping, Optional, Sequence, - Set, - Type, TypeVar, Union, ) @@ -51,7 +48,7 @@ class LangSmithParams(TypedDict, total=False): """Temperature for generation.""" ls_max_tokens: Optional[int] """Max tokens for generation.""" - ls_stop: Optional[List[str]] + ls_stop: Optional[list[str]] """Stop words for generation.""" @@ -74,7 +71,7 @@ def get_tokenizer() -> Any: return GPT2TokenizerFast.from_pretrained("gpt2") -def _get_token_ids_default_method(text: str) -> List[int]: +def _get_token_ids_default_method(text: str) -> list[int]: """Encode the text into token IDs.""" # get the cached tokenizer tokenizer = get_tokenizer() @@ -117,11 +114,11 @@ class BaseLanguageModel( """Whether to print out response text.""" callbacks: Callbacks = Field(default=None, exclude=True) """Callbacks to add to the run trace.""" - tags: Optional[List[str]] = Field(default=None, exclude=True) + tags: Optional[list[str]] = Field(default=None, exclude=True) """Tags to add to the run trace.""" - metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True) + metadata: Optional[dict[str, Any]] = Field(default=None, exclude=True) """Metadata to add to the run trace.""" - custom_get_token_ids: Optional[Callable[[str], List[int]]] = Field( + custom_get_token_ids: Optional[Callable[[str], list[int]]] = Field( default=None, exclude=True ) """Optional encoder to use for counting tokens.""" @@ -167,8 +164,8 @@ class BaseLanguageModel( @abstractmethod def generate_prompt( self, - prompts: List[PromptValue], - stop: Optional[List[str]] = None, + prompts: list[PromptValue], + stop: Optional[list[str]] = None, callbacks: Callbacks = None, **kwargs: Any, ) -> LLMResult: @@ -202,8 +199,8 @@ class BaseLanguageModel( @abstractmethod async def agenerate_prompt( self, - prompts: List[PromptValue], - stop: Optional[List[str]] = None, + prompts: list[PromptValue], + stop: Optional[list[str]] = None, callbacks: Callbacks = None, **kwargs: Any, ) -> LLMResult: @@ -235,8 +232,8 @@ class BaseLanguageModel( """ def with_structured_output( - self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any - ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + self, schema: Union[dict, type[BaseModel]], **kwargs: Any + ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: """Not implemented on this class.""" # Implement this on child class if there is a way of steering the model to # generate responses that match a given schema. @@ -267,7 +264,7 @@ class BaseLanguageModel( @abstractmethod def predict_messages( self, - messages: List[BaseMessage], + messages: list[BaseMessage], *, stop: Optional[Sequence[str]] = None, **kwargs: Any, @@ -313,7 +310,7 @@ class BaseLanguageModel( @abstractmethod async def apredict_messages( self, - messages: List[BaseMessage], + messages: list[BaseMessage], *, stop: Optional[Sequence[str]] = None, **kwargs: Any, @@ -339,7 +336,7 @@ class BaseLanguageModel( """Get the identifying parameters.""" return self.lc_attributes - def get_token_ids(self, text: str) -> List[int]: + def get_token_ids(self, text: str) -> list[int]: """Return the ordered ids of the tokens in a text. Args: @@ -367,7 +364,7 @@ class BaseLanguageModel( """ return len(self.get_token_ids(text)) - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int: """Get the number of tokens in the messages. Useful for checking if an input fits in a model's context window. @@ -381,7 +378,7 @@ class BaseLanguageModel( return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages]) @classmethod - def _all_required_field_names(cls) -> Set: + def _all_required_field_names(cls) -> set: """DEPRECATED: Kept for backwards compatibility. Use get_pydantic_field_names. diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index f761a8ed694..869c584854f 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -15,11 +15,9 @@ from typing import ( Callable, Dict, Iterator, - List, Literal, Optional, Sequence, - Type, Union, cast, ) @@ -223,7 +221,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): @model_validator(mode="before") @classmethod - def raise_deprecation(cls, values: Dict) -> Any: + def raise_deprecation(cls, values: dict) -> Any: """Raise deprecation warning if callback_manager is used. Args: @@ -277,7 +275,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> BaseMessage: config = ensure_config(config) @@ -300,7 +298,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> BaseMessage: config = ensure_config(config) @@ -356,7 +354,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> Iterator[BaseMessageChunk]: if not self._should_stream(async_api=False, **{**kwargs, **{"stream": True}}): @@ -426,7 +424,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> AsyncIterator[BaseMessageChunk]: if not self._should_stream(async_api=True, **{**kwargs, **{"stream": True}}): @@ -499,12 +497,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): # --- Custom methods --- - def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: return {} def _get_invocation_params( self, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> dict: params = self.dict() @@ -513,7 +511,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): def _get_ls_params( self, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> LangSmithParams: """Get standard params for tracing.""" @@ -550,7 +548,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): return ls_params - def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str: + def _get_llm_string(self, stop: Optional[list[str]] = None, **kwargs: Any) -> str: if self.is_lc_serializable(): params = {**kwargs, **{"stop": stop}} param_string = str(sorted([(k, v) for k, v in params.items()])) @@ -567,12 +565,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): def generate( self, - messages: List[List[BaseMessage]], - stop: Optional[List[str]] = None, + messages: list[list[BaseMessage]], + stop: Optional[list[str]] = None, callbacks: Callbacks = None, *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, run_name: Optional[str] = None, run_id: Optional[uuid.UUID] = None, **kwargs: Any, @@ -658,12 +656,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): async def agenerate( self, - messages: List[List[BaseMessage]], - stop: Optional[List[str]] = None, + messages: list[list[BaseMessage]], + stop: Optional[list[str]] = None, callbacks: Callbacks = None, *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, run_name: Optional[str] = None, run_id: Optional[uuid.UUID] = None, **kwargs: Any, @@ -777,8 +775,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): def generate_prompt( self, - prompts: List[PromptValue], - stop: Optional[List[str]] = None, + prompts: list[PromptValue], + stop: Optional[list[str]] = None, callbacks: Callbacks = None, **kwargs: Any, ) -> LLMResult: @@ -787,8 +785,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): async def agenerate_prompt( self, - prompts: List[PromptValue], - stop: Optional[List[str]] = None, + prompts: list[PromptValue], + stop: Optional[list[str]] = None, callbacks: Callbacks = None, **kwargs: Any, ) -> LLMResult: @@ -799,8 +797,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): def _generate_with_cache( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -839,7 +837,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): run_manager=run_manager, **kwargs, ): - chunks: List[ChatGenerationChunk] = [] + chunks: list[ChatGenerationChunk] = [] for chunk in self._stream(messages, stop=stop, **kwargs): chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) if run_manager: @@ -876,8 +874,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): async def _agenerate_with_cache( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -916,7 +914,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): run_manager=run_manager, **kwargs, ): - chunks: List[ChatGenerationChunk] = [] + chunks: list[ChatGenerationChunk] = [] async for chunk in self._astream(messages, stop=stop, **kwargs): chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) if run_manager: @@ -954,8 +952,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): @abstractmethod def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -963,8 +961,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): async def _agenerate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -980,8 +978,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): def _stream( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: @@ -989,8 +987,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): async def _astream( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: @@ -1017,8 +1015,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): @deprecated("0.1.7", alternative="invoke", removal="1.0") def __call__( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, callbacks: Callbacks = None, **kwargs: Any, ) -> BaseMessage: @@ -1032,8 +1030,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): async def _call_async( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, callbacks: Callbacks = None, **kwargs: Any, ) -> BaseMessage: @@ -1048,7 +1046,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): @deprecated("0.1.7", alternative="invoke", removal="1.0") def call_as_llm( - self, message: str, stop: Optional[List[str]] = None, **kwargs: Any + self, message: str, stop: Optional[list[str]] = None, **kwargs: Any ) -> str: return self.predict(message, stop=stop, **kwargs) @@ -1069,7 +1067,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): @deprecated("0.1.7", alternative="invoke", removal="1.0") def predict_messages( self, - messages: List[BaseMessage], + messages: list[BaseMessage], *, stop: Optional[Sequence[str]] = None, **kwargs: Any, @@ -1099,7 +1097,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): @deprecated("0.1.7", alternative="ainvoke", removal="1.0") async def apredict_messages( self, - messages: List[BaseMessage], + messages: list[BaseMessage], *, stop: Optional[Sequence[str]] = None, **kwargs: Any, @@ -1115,7 +1113,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): def _llm_type(self) -> str: """Return type of chat model.""" - def dict(self, **kwargs: Any) -> Dict: + def dict(self, **kwargs: Any) -> dict: """Return a dictionary of the LLM.""" starter_dict = dict(self._identifying_params) starter_dict["_type"] = self._llm_type @@ -1123,18 +1121,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): def bind_tools( self, - tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], + tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]], # noqa: UP006 **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: raise NotImplementedError() def with_structured_output( self, - schema: Union[Dict, Type], + schema: Union[Dict, type], # noqa: UP006 *, include_raw: bool = False, **kwargs: Any, - ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: # noqa: UP006 """Model wrapper that returns outputs formatted to match the given schema. Args: @@ -1281,8 +1279,8 @@ class SimpleChatModel(BaseChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: @@ -1294,8 +1292,8 @@ class SimpleChatModel(BaseChatModel): @abstractmethod def _call( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: @@ -1303,8 +1301,8 @@ class SimpleChatModel(BaseChatModel): async def _agenerate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index d30249337d7..c1125f6202b 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -20,8 +20,6 @@ from typing import ( List, Optional, Sequence, - Tuple, - Type, Union, cast, ) @@ -76,7 +74,7 @@ def _log_error_once(msg: str) -> None: def create_base_retry_decorator( - error_types: List[Type[BaseException]], + error_types: list[type[BaseException]], max_retries: int = 1, run_manager: Optional[ Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] @@ -153,10 +151,10 @@ def _resolve_cache(cache: Union[BaseCache, bool, None]) -> Optional[BaseCache]: def get_prompts( - params: Dict[str, Any], - prompts: List[str], + params: dict[str, Any], + prompts: list[str], cache: Optional[Union[BaseCache, bool, None]] = None, -) -> Tuple[Dict[int, List], str, List[int], List[str]]: +) -> tuple[dict[int, list], str, list[int], list[str]]: """Get prompts that are already cached. Args: @@ -189,10 +187,10 @@ def get_prompts( async def aget_prompts( - params: Dict[str, Any], - prompts: List[str], + params: dict[str, Any], + prompts: list[str], cache: Optional[Union[BaseCache, bool, None]] = None, -) -> Tuple[Dict[int, List], str, List[int], List[str]]: +) -> tuple[dict[int, list], str, list[int], list[str]]: """Get prompts that are already cached. Async version. Args: @@ -225,11 +223,11 @@ async def aget_prompts( def update_cache( cache: Union[BaseCache, bool, None], - existing_prompts: Dict[int, List], + existing_prompts: dict[int, list], llm_string: str, - missing_prompt_idxs: List[int], + missing_prompt_idxs: list[int], new_results: LLMResult, - prompts: List[str], + prompts: list[str], ) -> Optional[dict]: """Update the cache and get the LLM output. @@ -259,11 +257,11 @@ def update_cache( async def aupdate_cache( cache: Union[BaseCache, bool, None], - existing_prompts: Dict[int, List], + existing_prompts: dict[int, list], llm_string: str, - missing_prompt_idxs: List[int], + missing_prompt_idxs: list[int], new_results: LLMResult, - prompts: List[str], + prompts: list[str], ) -> Optional[dict]: """Update the cache and get the LLM output. Async version. @@ -306,7 +304,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): @model_validator(mode="before") @classmethod - def raise_deprecation(cls, values: Dict) -> Any: + def raise_deprecation(cls, values: dict) -> Any: """Raise deprecation warning if callback_manager is used.""" if values.get("callback_manager") is not None: warnings.warn( @@ -324,7 +322,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): # --- Runnable methods --- @property - def OutputType(self) -> Type[str]: + def OutputType(self) -> type[str]: """Get the input type for this runnable.""" return str @@ -343,7 +341,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): def _get_ls_params( self, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> LangSmithParams: """Get standard params for tracing.""" @@ -383,7 +381,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> str: config = ensure_config(config) @@ -407,7 +405,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> str: config = ensure_config(config) @@ -425,12 +423,12 @@ class BaseLLM(BaseLanguageModel[str], ABC): def batch( self, - inputs: List[LanguageModelInput], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[LanguageModelInput], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: if not inputs: return [] @@ -472,12 +470,12 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def abatch( self, - inputs: List[LanguageModelInput], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[LanguageModelInput], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: if not inputs: return [] config = get_config_list(config, len(inputs)) @@ -521,7 +519,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> Iterator[str]: if type(self)._stream == BaseLLM._stream: @@ -583,7 +581,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> AsyncIterator[str]: if ( @@ -649,8 +647,8 @@ class BaseLLM(BaseLanguageModel[str], ABC): @abstractmethod def _generate( self, - prompts: List[str], - stop: Optional[List[str]] = None, + prompts: list[str], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: @@ -658,8 +656,8 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def _agenerate( self, - prompts: List[str], - stop: Optional[List[str]] = None, + prompts: list[str], + stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: @@ -676,7 +674,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): def _stream( self, prompt: str, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: @@ -704,7 +702,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def _astream( self, prompt: str, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[GenerationChunk]: @@ -747,9 +745,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): def generate_prompt( self, - prompts: List[PromptValue], - stop: Optional[List[str]] = None, - callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, + prompts: list[PromptValue], + stop: Optional[list[str]] = None, + callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None, **kwargs: Any, ) -> LLMResult: prompt_strings = [p.to_string() for p in prompts] @@ -757,9 +755,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def agenerate_prompt( self, - prompts: List[PromptValue], - stop: Optional[List[str]] = None, - callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, + prompts: list[PromptValue], + stop: Optional[list[str]] = None, + callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None, **kwargs: Any, ) -> LLMResult: prompt_strings = [p.to_string() for p in prompts] @@ -769,9 +767,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): def _generate_helper( self, - prompts: List[str], - stop: Optional[List[str]], - run_managers: List[CallbackManagerForLLMRun], + prompts: list[str], + stop: Optional[list[str]], + run_managers: list[CallbackManagerForLLMRun], new_arg_supported: bool, **kwargs: Any, ) -> LLMResult: @@ -802,14 +800,14 @@ class BaseLLM(BaseLanguageModel[str], ABC): def generate( self, - prompts: List[str], - stop: Optional[List[str]] = None, - callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, + prompts: list[str], + stop: Optional[list[str]] = None, + callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None, *, - tags: Optional[Union[List[str], List[List[str]]]] = None, - metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, - run_name: Optional[Union[str, List[str]]] = None, - run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None, + tags: Optional[Union[list[str], list[list[str]]]] = None, + metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None, + run_name: Optional[Union[str, list[str]]] = None, + run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None, **kwargs: Any, ) -> LLMResult: """Pass a sequence of prompts to a model and return generations. @@ -987,7 +985,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): @staticmethod def _get_run_ids_list( - run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]], prompts: list + run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]], prompts: list ) -> list: if run_id is None: return [None] * len(prompts) @@ -1002,9 +1000,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def _agenerate_helper( self, - prompts: List[str], - stop: Optional[List[str]], - run_managers: List[AsyncCallbackManagerForLLMRun], + prompts: list[str], + stop: Optional[list[str]], + run_managers: list[AsyncCallbackManagerForLLMRun], new_arg_supported: bool, **kwargs: Any, ) -> LLMResult: @@ -1044,14 +1042,14 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def agenerate( self, - prompts: List[str], - stop: Optional[List[str]] = None, - callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None, + prompts: list[str], + stop: Optional[list[str]] = None, + callbacks: Optional[Union[Callbacks, list[Callbacks]]] = None, *, - tags: Optional[Union[List[str], List[List[str]]]] = None, - metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, - run_name: Optional[Union[str, List[str]]] = None, - run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None, + tags: Optional[Union[list[str], list[list[str]]]] = None, + metadata: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None, + run_name: Optional[Union[str, list[str]]] = None, + run_id: Optional[Union[uuid.UUID, list[Optional[uuid.UUID]]]] = None, **kwargs: Any, ) -> LLMResult: """Asynchronously pass a sequence of prompts to a model and return generations. @@ -1239,11 +1237,11 @@ class BaseLLM(BaseLanguageModel[str], ABC): def __call__( self, prompt: str, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, callbacks: Callbacks = None, *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> str: """Check Cache and run the LLM on the given prompt and input. @@ -1287,11 +1285,11 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def _call_async( self, prompt: str, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, callbacks: Callbacks = None, *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> str: """Check Cache and run the LLM on the given prompt and input.""" @@ -1318,7 +1316,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): @deprecated("0.1.7", alternative="invoke", removal="1.0") def predict_messages( self, - messages: List[BaseMessage], + messages: list[BaseMessage], *, stop: Optional[Sequence[str]] = None, **kwargs: Any, @@ -1344,7 +1342,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): @deprecated("0.1.7", alternative="ainvoke", removal="1.0") async def apredict_messages( self, - messages: List[BaseMessage], + messages: list[BaseMessage], *, stop: Optional[Sequence[str]] = None, **kwargs: Any, @@ -1367,7 +1365,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): def _llm_type(self) -> str: """Return type of llm.""" - def dict(self, **kwargs: Any) -> Dict: + def dict(self, **kwargs: Any) -> dict: """Return a dictionary of the LLM.""" starter_dict = dict(self._identifying_params) starter_dict["_type"] = self._llm_type @@ -1443,7 +1441,7 @@ class LLM(BaseLLM): def _call( self, prompt: str, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: @@ -1467,7 +1465,7 @@ class LLM(BaseLLM): async def _acall( self, prompt: str, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: @@ -1500,8 +1498,8 @@ class LLM(BaseLLM): def _generate( self, - prompts: List[str], - stop: Optional[List[str]] = None, + prompts: list[str], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: @@ -1520,8 +1518,8 @@ class LLM(BaseLLM): async def _agenerate( self, - prompts: List[str], - stop: Optional[List[str]] = None, + prompts: list[str], + stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: diff --git a/libs/core/langchain_core/memory.py b/libs/core/langchain_core/memory.py index cdbe5be656d..62816b4747f 100644 --- a/libs/core/langchain_core/memory.py +++ b/libs/core/langchain_core/memory.py @@ -11,7 +11,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any from pydantic import ConfigDict @@ -55,11 +55,11 @@ class BaseMemory(Serializable, ABC): @property @abstractmethod - def memory_variables(self) -> List[str]: + def memory_variables(self) -> list[str]: """The string keys this memory class will add to chain inputs.""" @abstractmethod - def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: """Return key-value pairs given the text input to the chain. Args: @@ -69,7 +69,7 @@ class BaseMemory(Serializable, ABC): A dictionary of key-value pairs. """ - async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: """Async return key-value pairs given the text input to the chain. Args: @@ -81,7 +81,7 @@ class BaseMemory(Serializable, ABC): return await run_in_executor(None, self.load_memory_variables, inputs) @abstractmethod - def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None: """Save the context of this chain run to memory. Args: @@ -90,7 +90,7 @@ class BaseMemory(Serializable, ABC): """ async def asave_context( - self, inputs: Dict[str, Any], outputs: Dict[str, str] + self, inputs: dict[str, Any], outputs: dict[str, str] ) -> None: """Async save the context of this chain run to memory. diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index 97d95c010da..574ef1c947d 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast +from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast from pydantic import ConfigDict, Field, field_validator @@ -19,7 +19,7 @@ class BaseMessage(Serializable): Messages are the inputs and outputs of ChatModels. """ - content: Union[str, List[Union[str, Dict]]] + content: Union[str, list[Union[str, dict]]] """The string contents of the message.""" additional_kwargs: dict = Field(default_factory=dict) @@ -64,7 +64,7 @@ class BaseMessage(Serializable): return id_value def __init__( - self, content: Union[str, List[Union[str, Dict]]], **kwargs: Any + self, content: Union[str, list[Union[str, dict]]], **kwargs: Any ) -> None: """Pass in content as positional arg. @@ -85,7 +85,7 @@ class BaseMessage(Serializable): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "messages"]. """ @@ -119,9 +119,9 @@ class BaseMessage(Serializable): def merge_content( - first_content: Union[str, List[Union[str, Dict]]], - *contents: Union[str, List[Union[str, Dict]]], -) -> Union[str, List[Union[str, Dict]]]: + first_content: Union[str, list[Union[str, dict]]], + *contents: Union[str, list[Union[str, dict]]], +) -> Union[str, list[Union[str, dict]]]: """Merge two message contents. Args: @@ -163,7 +163,7 @@ class BaseMessageChunk(BaseMessage): """Message chunk, which can be concatenated with other Message chunks.""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Default is ["langchain", "schema", "messages"]. """ @@ -242,7 +242,7 @@ def message_to_dict(message: BaseMessage) -> dict: return {"type": message.type, "data": message.model_dump()} -def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]: +def messages_to_dict(messages: Sequence[BaseMessage]) -> list[dict]: """Convert a sequence of Messages to a list of dictionaries. Args: diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index c18a85a1356..ce105f2049c 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -23,7 +23,6 @@ from typing import ( Optional, Sequence, Tuple, - Type, Union, cast, overload, @@ -166,7 +165,7 @@ def _message_from_dict(message: dict) -> BaseMessage: raise ValueError(f"Got unexpected message type: {_type}") -def messages_from_dict(messages: Sequence[dict]) -> List[BaseMessage]: +def messages_from_dict(messages: Sequence[dict]) -> list[BaseMessage]: """Convert a sequence of messages from dicts to Message objects. Args: @@ -208,7 +207,7 @@ def _create_message_from_message_type( content: str, name: Optional[str] = None, tool_call_id: Optional[str] = None, - tool_calls: Optional[List[Dict[str, Any]]] = None, + tool_calls: Optional[list[dict[str, Any]]] = None, id: Optional[str] = None, **additional_kwargs: Any, ) -> BaseMessage: @@ -230,7 +229,7 @@ def _create_message_from_message_type( ValueError: if the message type is not one of "human", "user", "ai", "assistant", "system", "function", or "tool". """ - kwargs: Dict[str, Any] = {} + kwargs: dict[str, Any] = {} if name is not None: kwargs["name"] = name if tool_call_id is not None: @@ -331,7 +330,7 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage: def convert_to_messages( messages: Union[Iterable[MessageLikeRepresentation], PromptValue], -) -> List[BaseMessage]: +) -> list[BaseMessage]: """Convert a sequence of messages to a list of messages. Args: @@ -352,18 +351,18 @@ def _runnable_support(func: Callable) -> Callable: @overload def wrapped( messages: Literal[None] = None, **kwargs: Any - ) -> Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]]: ... + ) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ... @overload def wrapped( messages: Sequence[MessageLikeRepresentation], **kwargs: Any - ) -> List[BaseMessage]: ... + ) -> list[BaseMessage]: ... def wrapped( messages: Optional[Sequence[MessageLikeRepresentation]] = None, **kwargs: Any ) -> Union[ - List[BaseMessage], - Runnable[Sequence[MessageLikeRepresentation], List[BaseMessage]], + list[BaseMessage], + Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]], ]: from langchain_core.runnables.base import RunnableLambda @@ -382,11 +381,11 @@ def filter_messages( *, include_names: Optional[Sequence[str]] = None, exclude_names: Optional[Sequence[str]] = None, - include_types: Optional[Sequence[Union[str, Type[BaseMessage]]]] = None, - exclude_types: Optional[Sequence[Union[str, Type[BaseMessage]]]] = None, + include_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None, + exclude_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None, include_ids: Optional[Sequence[str]] = None, exclude_ids: Optional[Sequence[str]] = None, -) -> List[BaseMessage]: +) -> list[BaseMessage]: """Filter messages based on name, type or id. Args: @@ -438,7 +437,7 @@ def filter_messages( ] """ # noqa: E501 messages = convert_to_messages(messages) - filtered: List[BaseMessage] = [] + filtered: list[BaseMessage] = [] for msg in messages: if exclude_names and msg.name in exclude_names: continue @@ -469,7 +468,7 @@ def merge_message_runs( messages: Union[Iterable[MessageLikeRepresentation], PromptValue], *, chunk_separator: str = "\n", -) -> List[BaseMessage]: +) -> list[BaseMessage]: """Merge consecutive Messages of the same type. **NOTE**: ToolMessages are not merged, as each has a distinct tool call id that @@ -539,7 +538,7 @@ def merge_message_runs( if not messages: return [] messages = convert_to_messages(messages) - merged: List[BaseMessage] = [] + merged: list[BaseMessage] = [] for msg in messages: curr = msg.model_copy(deep=True) last = merged.pop() if merged else None @@ -569,21 +568,21 @@ def trim_messages( *, max_tokens: int, token_counter: Union[ - Callable[[List[BaseMessage]], int], + Callable[[list[BaseMessage]], int], Callable[[BaseMessage], int], BaseLanguageModel, ], strategy: Literal["first", "last"] = "last", allow_partial: bool = False, end_on: Optional[ - Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]] + Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]] ] = None, start_on: Optional[ - Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]] + Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]] ] = None, include_system: bool = False, - text_splitter: Optional[Union[Callable[[str], List[str]], TextSplitter]] = None, -) -> List[BaseMessage]: + text_splitter: Optional[Union[Callable[[str], list[str]], TextSplitter]] = None, +) -> list[BaseMessage]: """Trim messages to be below a token count. Args: @@ -875,13 +874,13 @@ def _first_max_tokens( messages: Sequence[BaseMessage], *, max_tokens: int, - token_counter: Callable[[List[BaseMessage]], int], - text_splitter: Callable[[str], List[str]], + token_counter: Callable[[list[BaseMessage]], int], + text_splitter: Callable[[str], list[str]], partial_strategy: Optional[Literal["first", "last"]] = None, end_on: Optional[ - Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]] + Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]] ] = None, -) -> List[BaseMessage]: +) -> list[BaseMessage]: messages = list(messages) idx = 0 for i in range(len(messages)): @@ -949,17 +948,17 @@ def _last_max_tokens( messages: Sequence[BaseMessage], *, max_tokens: int, - token_counter: Callable[[List[BaseMessage]], int], - text_splitter: Callable[[str], List[str]], + token_counter: Callable[[list[BaseMessage]], int], + text_splitter: Callable[[str], list[str]], allow_partial: bool = False, include_system: bool = False, start_on: Optional[ - Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]] + Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]] ] = None, end_on: Optional[ - Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]] + Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]] ] = None, -) -> List[BaseMessage]: +) -> list[BaseMessage]: messages = list(messages) if end_on: while messages and not _is_message_type(messages[-1], end_on): @@ -984,7 +983,7 @@ def _last_max_tokens( return reversed_[::-1] -_MSG_CHUNK_MAP: Dict[Type[BaseMessage], Type[BaseMessageChunk]] = { +_MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = { HumanMessage: HumanMessageChunk, AIMessage: AIMessageChunk, SystemMessage: SystemMessageChunk, @@ -1024,14 +1023,14 @@ def _chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage: ) -def _default_text_splitter(text: str) -> List[str]: +def _default_text_splitter(text: str) -> list[str]: splits = text.split("\n") return [s + "\n" for s in splits[:-1]] + splits[-1:] def _is_message_type( message: BaseMessage, - type_: Union[str, Type[BaseMessage], Sequence[Union[str, Type[BaseMessage]]]], + type_: Union[str, type[BaseMessage], Sequence[Union[str, type[BaseMessage]]]], ) -> bool: types = [type_] if isinstance(type_, (str, type)) else type_ types_str = [t for t in types if isinstance(t, str)] diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index 6e1420d9b71..a6d539ecc6b 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -4,11 +4,8 @@ from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, - Dict, Generic, - List, Optional, - Type, TypeVar, Union, ) @@ -30,7 +27,7 @@ class BaseLLMOutputParser(Generic[T], ABC): """Abstract base class for parsing the outputs of a model.""" @abstractmethod - def parse_result(self, result: List[Generation], *, partial: bool = False) -> T: + def parse_result(self, result: list[Generation], *, partial: bool = False) -> T: """Parse a list of candidate model Generations into a specific format. Args: @@ -44,7 +41,7 @@ class BaseLLMOutputParser(Generic[T], ABC): """ async def aparse_result( - self, result: List[Generation], *, partial: bool = False + self, result: list[Generation], *, partial: bool = False ) -> T: """Async parse a list of candidate model Generations into a specific format. @@ -71,7 +68,7 @@ class BaseGenerationOutputParser( return Union[str, AnyMessage] @property - def OutputType(self) -> Type[T]: + def OutputType(self) -> type[T]: """Return the output type for the parser.""" # even though mypy complains this isn't valid, # it is good enough for pydantic to build the schema from @@ -156,7 +153,7 @@ class BaseOutputParser( return Union[str, AnyMessage] @property - def OutputType(self) -> Type[T]: + def OutputType(self) -> type[T]: """Return the output type for the parser. This property is inferred from the first type argument of the class. @@ -218,7 +215,7 @@ class BaseOutputParser( run_type="parser", ) - def parse_result(self, result: List[Generation], *, partial: bool = False) -> T: + def parse_result(self, result: list[Generation], *, partial: bool = False) -> T: """Parse a list of candidate model Generations into a specific format. The return value is parsed from only the first Generation in the result, which @@ -247,7 +244,7 @@ class BaseOutputParser( """ async def aparse_result( - self, result: List[Generation], *, partial: bool = False + self, result: list[Generation], *, partial: bool = False ) -> T: """Async parse a list of candidate model Generations into a specific format. @@ -305,7 +302,7 @@ class BaseOutputParser( " This is required for serialization." ) - def dict(self, **kwargs: Any) -> Dict: + def dict(self, **kwargs: Any) -> dict: """Return dictionary representation of output parser.""" output_parser_dict = super().dict(**kwargs) try: diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index b076e65e853..38c0823de23 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -2,7 +2,7 @@ from __future__ import annotations import json from json import JSONDecodeError -from typing import Any, List, Optional, Type, TypeVar, Union +from typing import Any, Optional, TypeVar, Union import jsonpatch # type: ignore[import] import pydantic @@ -42,14 +42,14 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): describing the difference between the previous and the current object. """ - pydantic_object: Annotated[Optional[Type[TBaseModel]], SkipValidation()] = None # type: ignore + pydantic_object: Annotated[Optional[type[TBaseModel]], SkipValidation()] = None # type: ignore """The Pydantic object to use for validation. If None, no validation is performed.""" def _diff(self, prev: Optional[Any], next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch - def _get_schema(self, pydantic_object: Type[TBaseModel]) -> dict[str, Any]: + def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]: if PYDANTIC_MAJOR_VERSION == 2: if issubclass(pydantic_object, pydantic.BaseModel): return pydantic_object.model_json_schema() @@ -57,7 +57,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): return pydantic_object.model_json_schema() return pydantic_object.model_json_schema() - def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. Args: diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py index 8227fa8c613..65ec406ec0a 100644 --- a/libs/core/langchain_core/output_parsers/list.py +++ b/libs/core/langchain_core/output_parsers/list.py @@ -3,7 +3,7 @@ from __future__ import annotations import re from abc import abstractmethod from collections import deque -from typing import AsyncIterator, Deque, Iterator, List, TypeVar, Union +from typing import AsyncIterator, Iterator, List, TypeVar, Union from typing import Optional as Optional from langchain_core.messages import BaseMessage @@ -22,7 +22,7 @@ def droplastn(iter: Iterator[T], n: int) -> Iterator[T]: Yields: The elements of the iterator, except the last n elements. """ - buffer: Deque[T] = deque() + buffer: deque[T] = deque() for item in iter: buffer.append(item) if len(buffer) > n: @@ -37,7 +37,7 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]): return "list" @abstractmethod - def parse(self, text: str) -> List[str]: + def parse(self, text: str) -> list[str]: """Parse the output of an LLM call. Args: @@ -60,7 +60,7 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]): def _transform( self, input: Iterator[Union[str, BaseMessage]] - ) -> Iterator[List[str]]: + ) -> Iterator[list[str]]: buffer = "" for chunk in input: if isinstance(chunk, BaseMessage): @@ -92,7 +92,7 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]): async def _atransform( self, input: AsyncIterator[Union[str, BaseMessage]] - ) -> AsyncIterator[List[str]]: + ) -> AsyncIterator[list[str]]: buffer = "" async for chunk in input: if isinstance(chunk, BaseMessage): @@ -136,7 +136,7 @@ class CommaSeparatedListOutputParser(ListOutputParser): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Returns: @@ -152,7 +152,7 @@ class CommaSeparatedListOutputParser(ListOutputParser): "eg: `foo, bar, baz` or `foo,bar,baz`" ) - def parse(self, text: str) -> List[str]: + def parse(self, text: str) -> list[str]: """Parse the output of an LLM call. Args: @@ -180,7 +180,7 @@ class NumberedListOutputParser(ListOutputParser): "For example: \n\n1. foo\n\n2. bar\n\n3. baz" ) - def parse(self, text: str) -> List[str]: + def parse(self, text: str) -> list[str]: """Parse the output of an LLM call. Args: @@ -217,7 +217,7 @@ class MarkdownListOutputParser(ListOutputParser): """Return the format instructions for the Markdown list output.""" return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`" - def parse(self, text: str) -> List[str]: + def parse(self, text: str) -> list[str]: """Parse the output of an LLM call. Args: diff --git a/libs/core/langchain_core/outputs/chat_generation.py b/libs/core/langchain_core/outputs/chat_generation.py index d64015ed6a2..599283e60e5 100644 --- a/libs/core/langchain_core/outputs/chat_generation.py +++ b/libs/core/langchain_core/outputs/chat_generation.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Literal, Union +from typing import Literal, Union from pydantic import model_validator from typing_extensions import Self @@ -69,7 +69,7 @@ class ChatGeneration(Generation): return self @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "output"] @@ -86,12 +86,12 @@ class ChatGenerationChunk(ChatGeneration): """Type is used exclusively for serialization purposes.""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "output"] def __add__( - self, other: Union[ChatGenerationChunk, List[ChatGenerationChunk]] + self, other: Union[ChatGenerationChunk, list[ChatGenerationChunk]] ) -> ChatGenerationChunk: if isinstance(other, ChatGenerationChunk): generation_info = merge_dicts( diff --git a/libs/core/langchain_core/outputs/generation.py b/libs/core/langchain_core/outputs/generation.py index 7dbb6896d87..7022e6e03e3 100644 --- a/libs/core/langchain_core/outputs/generation.py +++ b/libs/core/langchain_core/outputs/generation.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Literal, Optional from langchain_core.load import Serializable from langchain_core.utils._merge import merge_dicts @@ -25,7 +25,7 @@ class Generation(Serializable): text: str """Generated text output.""" - generation_info: Optional[Dict[str, Any]] = None + generation_info: Optional[dict[str, Any]] = None """Raw response from the provider. May include things like the reason for finishing or token log probabilities. @@ -40,7 +40,7 @@ class Generation(Serializable): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "output"] @@ -49,7 +49,7 @@ class GenerationChunk(Generation): """Generation chunk, which can be concatenated with other Generation chunks.""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "output"] diff --git a/libs/core/langchain_core/outputs/llm_result.py b/libs/core/langchain_core/outputs/llm_result.py index 98b8fe7fe67..8a429616dcf 100644 --- a/libs/core/langchain_core/outputs/llm_result.py +++ b/libs/core/langchain_core/outputs/llm_result.py @@ -1,7 +1,7 @@ from __future__ import annotations from copy import deepcopy -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union from pydantic import BaseModel @@ -18,8 +18,8 @@ class LLMResult(BaseModel): wants to return. """ - generations: List[ - List[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]] + generations: list[ + list[Union[Generation, ChatGeneration, GenerationChunk, ChatGenerationChunk]] ] """Generated outputs. @@ -45,13 +45,13 @@ class LLMResult(BaseModel): accessing relevant information from standardized fields present in AIMessage. """ - run: Optional[List[RunInfo]] = None + run: Optional[list[RunInfo]] = None """List of metadata info for model call for each input.""" type: Literal["LLMResult"] = "LLMResult" # type: ignore[assignment] """Type is used exclusively for serialization purposes.""" - def flatten(self) -> List[LLMResult]: + def flatten(self) -> list[LLMResult]: """Flatten generations into a single list. Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult diff --git a/libs/core/langchain_core/prompt_values.py b/libs/core/langchain_core/prompt_values.py index 27bf3e5df7f..6b9f421890f 100644 --- a/libs/core/langchain_core/prompt_values.py +++ b/libs/core/langchain_core/prompt_values.py @@ -7,7 +7,7 @@ They can be used to represent text, images, or chat message pieces. from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Literal, Sequence, cast +from typing import Literal, Sequence, cast from typing_extensions import TypedDict @@ -33,7 +33,7 @@ class PromptValue(Serializable, ABC): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. This is used to determine the namespace of the object when serializing. Defaults to ["langchain", "schema", "prompt"]. @@ -45,7 +45,7 @@ class PromptValue(Serializable, ABC): """Return prompt value as string.""" @abstractmethod - def to_messages(self) -> List[BaseMessage]: + def to_messages(self) -> list[BaseMessage]: """Return prompt as a list of Messages.""" @@ -57,7 +57,7 @@ class StringPromptValue(PromptValue): type: Literal["StringPromptValue"] = "StringPromptValue" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. This is used to determine the namespace of the object when serializing. Defaults to ["langchain", "prompts", "base"]. @@ -68,7 +68,7 @@ class StringPromptValue(PromptValue): """Return prompt as string.""" return self.text - def to_messages(self) -> List[BaseMessage]: + def to_messages(self) -> list[BaseMessage]: """Return prompt as messages.""" return [HumanMessage(content=self.text)] @@ -86,12 +86,12 @@ class ChatPromptValue(PromptValue): """Return prompt as string.""" return get_buffer_string(self.messages) - def to_messages(self) -> List[BaseMessage]: + def to_messages(self) -> list[BaseMessage]: """Return prompt as a list of messages.""" return list(self.messages) @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. This is used to determine the namespace of the object when serializing. Defaults to ["langchain", "prompts", "chat"]. @@ -121,7 +121,7 @@ class ImagePromptValue(PromptValue): """Return prompt (image URL) as string.""" return self.image_url["url"] - def to_messages(self) -> List[BaseMessage]: + def to_messages(self) -> list[BaseMessage]: """Return prompt (image URL) as messages.""" return [HumanMessage(content=[cast(dict, self.image_url)])] @@ -136,7 +136,7 @@ class ChatPromptValueConcrete(ChatPromptValue): type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. This is used to determine the namespace of the object when serializing. Defaults to ["langchain", "prompts", "chat"]. diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index 62ce4cece81..531f40a6770 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -10,10 +10,8 @@ from typing import ( Callable, Dict, Generic, - List, Mapping, Optional, - Type, TypeVar, Union, ) @@ -45,14 +43,14 @@ class BasePromptTemplate( ): """Base class for all prompt templates, returning a prompt.""" - input_variables: List[str] + input_variables: list[str] """A list of the names of the variables whose values are required as inputs to the prompt.""" - optional_variables: List[str] = Field(default=[]) + optional_variables: list[str] = Field(default=[]) """optional_variables: A list of the names of the variables for placeholder or MessagePlaceholder that are optional. These variables are auto inferred from the prompt and user need not provide them.""" - input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True) + input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True) # noqa: UP006 """A dictionary of the types of the variables the prompt template expects. If not provided, all variables are assumed to be strings.""" output_parser: Optional[BaseOutputParser] = None @@ -62,9 +60,9 @@ class BasePromptTemplate( Partial variables populate the template so that you don't need to pass them in every time you call the prompt.""" - metadata: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None # noqa: UP006 """Metadata to be used for tracing.""" - tags: Optional[List[str]] = None + tags: Optional[list[str]] = None """Tags to be used for tracing.""" @model_validator(mode="after") @@ -89,7 +87,7 @@ class BasePromptTemplate( return self @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object. Returns ["langchain", "schema", "prompt_template"].""" return ["langchain", "schema", "prompt_template"] @@ -115,7 +113,7 @@ class BasePromptTemplate( def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: """Get the input schema for the prompt. Args: @@ -136,7 +134,7 @@ class BasePromptTemplate( field_definitions={**required_input_variables, **optional_input_variables}, ) - def _validate_input(self, inner_input: Any) -> Dict: + def _validate_input(self, inner_input: Any) -> dict: if not isinstance(inner_input, dict): if len(self.input_variables) == 1: var_name = self.input_variables[0] @@ -163,18 +161,18 @@ class BasePromptTemplate( raise KeyError(msg) return inner_input - def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue: + def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue: _inner_input = self._validate_input(inner_input) return self.format_prompt(**_inner_input) async def _aformat_prompt_with_error_handling( - self, inner_input: Dict + self, inner_input: dict ) -> PromptValue: _inner_input = self._validate_input(inner_input) return await self.aformat_prompt(**_inner_input) def invoke( - self, input: Dict, config: Optional[RunnableConfig] = None + self, input: dict, config: Optional[RunnableConfig] = None ) -> PromptValue: """Invoke the prompt. @@ -199,7 +197,7 @@ class BasePromptTemplate( ) async def ainvoke( - self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any + self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> PromptValue: """Async invoke the prompt. @@ -261,7 +259,7 @@ class BasePromptTemplate( prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs} return type(self)(**prompt_dict) - def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]: + def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]: # Get partial params: partial_kwargs = { k: v if not callable(v) else v() for k, v in self.partial_variables.items() @@ -307,7 +305,7 @@ class BasePromptTemplate( """Return the prompt type key.""" raise NotImplementedError - def dict(self, **kwargs: Any) -> Dict: + def dict(self, **kwargs: Any) -> dict: """Return dictionary representation of prompt. Args: @@ -369,7 +367,7 @@ class BasePromptTemplate( raise ValueError(f"{save_path} must be json or yaml") -def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> Dict: +def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> dict: base_info = {"page_content": doc.page_content, **doc.metadata} missing_metadata = set(prompt.input_variables).difference(base_info) if len(missing_metadata) > 0: diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 0c3ec402cb3..c5e5f1b2792 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -6,12 +6,10 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import ( Any, - Dict, List, Literal, Optional, Sequence, - Set, Tuple, Type, TypedDict, @@ -60,12 +58,12 @@ class BaseMessagePromptTemplate(Serializable, ABC): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "chat"] @abstractmethod - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + def format_messages(self, **kwargs: Any) -> list[BaseMessage]: """Format messages from kwargs. Should return a list of BaseMessages. Args: @@ -75,7 +73,7 @@ class BaseMessagePromptTemplate(Serializable, ABC): List of BaseMessages. """ - async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: + async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]: """Async format messages from kwargs. Should return a list of BaseMessages. @@ -89,7 +87,7 @@ class BaseMessagePromptTemplate(Serializable, ABC): @property @abstractmethod - def input_variables(self) -> List[str]: + def input_variables(self) -> list[str]: """Input variables for this prompt template. Returns: @@ -210,7 +208,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): Defaults to None.""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "chat"] @@ -223,7 +221,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): variable_name=variable_name, optional=optional, **kwargs ) - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + def format_messages(self, **kwargs: Any) -> list[BaseMessage]: """Format messages from kwargs. Args: @@ -251,7 +249,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): return value @property - def input_variables(self) -> List[str]: + def input_variables(self) -> list[str]: """Input variables for this prompt template. Returns: @@ -292,16 +290,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): """Additional keyword arguments to pass to the prompt template.""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "chat"] @classmethod def from_template( - cls: Type[MessagePromptTemplateT], + cls: type[MessagePromptTemplateT], template: str, template_format: str = "f-string", - partial_variables: Optional[Dict[str, Any]] = None, + partial_variables: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> MessagePromptTemplateT: """Create a class from a string template. @@ -329,9 +327,9 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): @classmethod def from_template_file( - cls: Type[MessagePromptTemplateT], + cls: type[MessagePromptTemplateT], template_file: Union[str, Path], - input_variables: List[str], + input_variables: list[str], **kwargs: Any, ) -> MessagePromptTemplateT: """Create a class from a template file. @@ -369,7 +367,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): """ return self.format(**kwargs) - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + def format_messages(self, **kwargs: Any) -> list[BaseMessage]: """Format messages from kwargs. Args: @@ -380,7 +378,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): """ return [self.format(**kwargs)] - async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: + async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]: """Async format messages from kwargs. Args: @@ -392,7 +390,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): return [await self.aformat(**kwargs)] @property - def input_variables(self) -> List[str]: + def input_variables(self) -> list[str]: """ Input variables for this prompt template. @@ -423,7 +421,7 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate): """Role of the message.""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "chat"] @@ -462,37 +460,37 @@ _StringImageMessagePromptTemplateT = TypeVar( class _TextTemplateParam(TypedDict, total=False): - text: Union[str, Dict] + text: Union[str, dict] class _ImageTemplateParam(TypedDict, total=False): - image_url: Union[str, Dict] + image_url: Union[str, dict] class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): """Human message prompt template. This is a message sent from the user.""" prompt: Union[ - StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]] + StringPromptTemplate, list[Union[StringPromptTemplate, ImagePromptTemplate]] ] """Prompt template.""" additional_kwargs: dict = Field(default_factory=dict) """Additional keyword arguments to pass to the prompt template.""" - _msg_class: Type[BaseMessage] + _msg_class: type[BaseMessage] @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "chat"] @classmethod def from_template( - cls: Type[_StringImageMessagePromptTemplateT], - template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]], + cls: type[_StringImageMessagePromptTemplateT], + template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]], template_format: str = "f-string", *, - partial_variables: Optional[Dict[str, Any]] = None, + partial_variables: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> _StringImageMessagePromptTemplateT: """Create a class from a string template. @@ -511,7 +509,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): ValueError: If the template is not a string or list of strings. """ if isinstance(template, str): - prompt: Union[StringPromptTemplate, List] = PromptTemplate.from_template( + prompt: Union[StringPromptTemplate, list] = PromptTemplate.from_template( template, template_format=template_format, partial_variables=partial_variables, @@ -574,9 +572,9 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): @classmethod def from_template_file( - cls: Type[_StringImageMessagePromptTemplateT], + cls: type[_StringImageMessagePromptTemplateT], template_file: Union[str, Path], - input_variables: List[str], + input_variables: list[str], **kwargs: Any, ) -> _StringImageMessagePromptTemplateT: """Create a class from a template file. @@ -593,7 +591,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): template = f.read() return cls.from_template(template, input_variables=input_variables, **kwargs) - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + def format_messages(self, **kwargs: Any) -> list[BaseMessage]: """Format messages from kwargs. Args: @@ -604,7 +602,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): """ return [self.format(**kwargs)] - async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: + async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]: """Async format messages from kwargs. Args: @@ -616,7 +614,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): return [await self.aformat(**kwargs)] @property - def input_variables(self) -> List[str]: + def input_variables(self) -> list[str]: """ Input variables for this prompt template. @@ -642,7 +640,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): content=text, additional_kwargs=self.additional_kwargs ) else: - content: List = [] + content: list = [] for prompt in self.prompt: inputs = {var: kwargs[var] for var in prompt.input_variables} if isinstance(prompt, StringPromptTemplate): @@ -670,7 +668,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): content=text, additional_kwargs=self.additional_kwargs ) else: - content: List = [] + content: list = [] for prompt in self.prompt: inputs = {var: kwargs[var] for var in prompt.input_variables} if isinstance(prompt, StringPromptTemplate): @@ -703,16 +701,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): class HumanMessagePromptTemplate(_StringImageMessagePromptTemplate): """Human message prompt template. This is a message sent from the user.""" - _msg_class: Type[BaseMessage] = HumanMessage + _msg_class: type[BaseMessage] = HumanMessage class AIMessagePromptTemplate(_StringImageMessagePromptTemplate): """AI message prompt template. This is a message sent from the AI.""" - _msg_class: Type[BaseMessage] = AIMessage + _msg_class: type[BaseMessage] = AIMessage @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "chat"] @@ -722,10 +720,10 @@ class SystemMessagePromptTemplate(_StringImageMessagePromptTemplate): This is a message that is not sent to the user. """ - _msg_class: Type[BaseMessage] = SystemMessage + _msg_class: type[BaseMessage] = SystemMessage @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "chat"] @@ -734,7 +732,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC): """Base class for chat prompt templates.""" @property - def lc_attributes(self) -> Dict: + def lc_attributes(self) -> dict: """ Return a list of attribute names that should be included in the serialized kwargs. These attributes must be accepted by the @@ -791,10 +789,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC): return ChatPromptValue(messages=messages) @abstractmethod - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + def format_messages(self, **kwargs: Any) -> list[BaseMessage]: """Format kwargs into a list of messages.""" - async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: + async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]: """Async format kwargs into a list of messages.""" return self.format_messages(**kwargs) @@ -935,7 +933,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): """ # noqa: E501 - messages: Annotated[List[MessageLike], SkipValidation()] + messages: Annotated[list[MessageLike], SkipValidation()] """List of messages consisting of either message prompt templates or messages.""" validate_template: bool = False """Whether or not to try validating the template.""" @@ -999,9 +997,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate): ] # Automatically infer input variables from messages - input_vars: Set[str] = set() - optional_variables: Set[str] = set() - partial_vars: Dict[str, Any] = {} + input_vars: set[str] = set() + optional_variables: set[str] = set() + partial_vars: dict[str, Any] = {} for _message in _messages: if isinstance(_message, MessagesPlaceholder) and _message.optional: partial_vars[_message.variable_name] = [] @@ -1022,7 +1020,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): cast(Type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs) @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "chat"] @@ -1071,7 +1069,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): messages = values["messages"] input_vars = set() optional_variables = set() - input_types: Dict[str, Any] = values.get("input_types", {}) + input_types: dict[str, Any] = values.get("input_types", {}) for message in messages: if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)): input_vars.update(message.input_variables) @@ -1125,7 +1123,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): @classmethod @deprecated("0.0.1", alternative="from_messages classmethod", pending=True) def from_role_strings( - cls, string_messages: List[Tuple[str, str]] + cls, string_messages: list[tuple[str, str]] ) -> ChatPromptTemplate: """Create a chat prompt template from a list of (role, template) tuples. @@ -1145,7 +1143,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): @classmethod @deprecated("0.0.1", alternative="from_messages classmethod", pending=True) def from_strings( - cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]] + cls, string_messages: list[tuple[type[BaseMessagePromptTemplate], str]] ) -> ChatPromptTemplate: """Create a chat prompt template from a list of (role class, template) tuples. @@ -1200,7 +1198,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): """ return cls(messages, template_format=template_format) - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + def format_messages(self, **kwargs: Any) -> list[BaseMessage]: """Format the chat template into a list of finalized messages. Args: @@ -1224,7 +1222,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): raise ValueError(f"Unexpected input: {message_template}") return result - async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: + async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]: """Async format the chat template into a list of finalized messages. Args: diff --git a/libs/core/langchain_core/prompts/few_shot.py b/libs/core/langchain_core/prompts/few_shot.py index a28c87cf0c8..ac20af37845 100644 --- a/libs/core/langchain_core/prompts/few_shot.py +++ b/libs/core/langchain_core/prompts/few_shot.py @@ -3,7 +3,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from pydantic import ( BaseModel, @@ -31,7 +31,7 @@ from langchain_core.prompts.string import ( class _FewShotPromptTemplateMixin(BaseModel): """Prompt template that contains few shot examples.""" - examples: Optional[List[dict]] = None + examples: Optional[list[dict]] = None """Examples to format into the prompt. Either this or example_selector should be provided.""" @@ -46,7 +46,7 @@ class _FewShotPromptTemplateMixin(BaseModel): @model_validator(mode="before") @classmethod - def check_examples_and_selector(cls, values: Dict) -> Any: + def check_examples_and_selector(cls, values: dict) -> Any: """Check that one and only one of examples/example_selector are provided. Args: @@ -73,7 +73,7 @@ class _FewShotPromptTemplateMixin(BaseModel): return values - def _get_examples(self, **kwargs: Any) -> List[dict]: + def _get_examples(self, **kwargs: Any) -> list[dict]: """Get the examples to use for formatting the prompt. Args: @@ -94,7 +94,7 @@ class _FewShotPromptTemplateMixin(BaseModel): "One of 'examples' and 'example_selector' should be provided" ) - async def _aget_examples(self, **kwargs: Any) -> List[dict]: + async def _aget_examples(self, **kwargs: Any) -> list[dict]: """Async get the examples to use for formatting the prompt. Args: @@ -363,7 +363,7 @@ class FewShotChatMessagePromptTemplate( chain.invoke({"input": "What's 3+3?"}) """ - input_variables: List[str] = Field(default_factory=list) + input_variables: list[str] = Field(default_factory=list) """A list of the names of the variables the prompt template will use to pass to the example_selector, if provided.""" @@ -380,7 +380,7 @@ class FewShotChatMessagePromptTemplate( extra="forbid", ) - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + def format_messages(self, **kwargs: Any) -> list[BaseMessage]: """Format kwargs into a list of messages. Args: @@ -402,7 +402,7 @@ class FewShotChatMessagePromptTemplate( ] return messages - async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: + async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]: """Async format kwargs into a list of messages. Args: diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index 9824721210f..659bf1de258 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -4,7 +4,7 @@ from __future__ import annotations import warnings from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from pydantic import BaseModel, model_validator @@ -54,13 +54,13 @@ class PromptTemplate(StringPromptTemplate): """ @property - def lc_attributes(self) -> Dict[str, Any]: + def lc_attributes(self) -> dict[str, Any]: return { "template_format": self.template_format, } @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "prompt"] @@ -76,7 +76,7 @@ class PromptTemplate(StringPromptTemplate): @model_validator(mode="before") @classmethod - def pre_init_validation(cls, values: Dict) -> Any: + def pre_init_validation(cls, values: dict) -> Any: """Check that template and input variables are consistent.""" if values.get("template") is None: # Will let pydantic fail with a ValidationError if template @@ -183,9 +183,9 @@ class PromptTemplate(StringPromptTemplate): @classmethod def from_examples( cls, - examples: List[str], + examples: list[str], suffix: str, - input_variables: List[str], + input_variables: list[str], example_separator: str = "\n\n", prefix: str = "", **kwargs: Any, @@ -215,7 +215,7 @@ class PromptTemplate(StringPromptTemplate): def from_file( cls, template_file: Union[str, Path], - input_variables: Optional[List[str]] = None, + input_variables: Optional[list[str]] = None, encoding: Optional[str] = None, **kwargs: Any, ) -> PromptTemplate: @@ -249,7 +249,7 @@ class PromptTemplate(StringPromptTemplate): template: str, *, template_format: str = "f-string", - partial_variables: Optional[Dict[str, Any]] = None, + partial_variables: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> PromptTemplate: """Load a prompt template from a template. diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index 2c08549d0bd..9aea16cf754 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -5,7 +5,7 @@ from __future__ import annotations import warnings from abc import ABC from string import Formatter -from typing import Any, Callable, Dict, List, Set, Tuple, Type +from typing import Any, Callable, Dict from pydantic import BaseModel, create_model @@ -60,7 +60,7 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str: return SandboxedEnvironment().from_string(template).render(**kwargs) -def validate_jinja2(template: str, input_variables: List[str]) -> None: +def validate_jinja2(template: str, input_variables: list[str]) -> None: """ Validate that the input variables are valid for the template. Issues a warning if missing or extra variables are found. @@ -85,7 +85,7 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None: warnings.warn(warning_message.strip(), stacklevel=7) -def _get_jinja2_variables_from_template(template: str) -> Set[str]: +def _get_jinja2_variables_from_template(template: str) -> set[str]: try: from jinja2 import Environment, meta except ImportError as e: @@ -114,7 +114,7 @@ def mustache_formatter(template: str, **kwargs: Any) -> str: def mustache_template_vars( template: str, -) -> Set[str]: +) -> set[str]: """Get the variables from a mustache template. Args: @@ -123,7 +123,7 @@ def mustache_template_vars( Returns: The variables from the template. """ - vars: Set[str] = set() + vars: set[str] = set() section_depth = 0 for type, key in mustache.tokenize(template): if type == "end": @@ -144,7 +144,7 @@ Defs = Dict[str, "Defs"] def mustache_schema( template: str, -) -> Type[BaseModel]: +) -> type[BaseModel]: """Get the variables from a mustache template. Args: @@ -154,8 +154,8 @@ def mustache_schema( The variables from the template as a Pydantic model. """ fields = {} - prefix: Tuple[str, ...] = () - section_stack: List[Tuple[str, ...]] = [] + prefix: tuple[str, ...] = () + section_stack: list[tuple[str, ...]] = [] for type, key in mustache.tokenize(template): if key == ".": continue @@ -178,7 +178,7 @@ def mustache_schema( return _create_model_recursive("PromptInput", defs) -def _create_model_recursive(name: str, defs: Defs) -> Type: +def _create_model_recursive(name: str, defs: Defs) -> type: return create_model( # type: ignore[call-overload] name, **{ @@ -188,20 +188,20 @@ def _create_model_recursive(name: str, defs: Defs) -> Type: ) -DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { +DEFAULT_FORMATTER_MAPPING: dict[str, Callable] = { "f-string": formatter.format, "mustache": mustache_formatter, "jinja2": jinja2_formatter, } -DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = { +DEFAULT_VALIDATOR_MAPPING: dict[str, Callable] = { "f-string": formatter.validate_input_variables, "jinja2": validate_jinja2, } def check_valid_template( - template: str, template_format: str, input_variables: List[str] + template: str, template_format: str, input_variables: list[str] ) -> None: """Check that template string is valid. @@ -230,7 +230,7 @@ def check_valid_template( ) from exc -def get_template_variables(template: str, template_format: str) -> List[str]: +def get_template_variables(template: str, template_format: str) -> list[str]: """Get the variables from the template. Args: @@ -262,7 +262,7 @@ class StringPromptTemplate(BasePromptTemplate, ABC): """String prompt that exposes the format method, returning a prompt.""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "base"] diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index f02354ce3d2..3d7f2493d7a 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -24,7 +24,7 @@ from __future__ import annotations import warnings from abc import ABC, abstractmethod from inspect import signature -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional from pydantic import ConfigDict from typing_extensions import TypedDict @@ -132,14 +132,14 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): _new_arg_supported: bool = False _expects_other_args: bool = False - tags: Optional[List[str]] = None + tags: Optional[list[str]] = None """Optional list of tags associated with the retriever. Defaults to None. These tags will be associated with each call to this retriever, and passed as arguments to the handlers defined in `callbacks`. You can use these to eg identify a specific instance of a retriever with its use case. """ - metadata: Optional[Dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None """Optional metadata associated with the retriever. Defaults to None. This metadata will be associated with each call to this retriever, and passed as arguments to the handlers defined in `callbacks`. @@ -200,7 +200,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): def invoke( self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: """Invoke the retriever to get relevant documents. Main entry point for synchronous retriever invocations. @@ -263,7 +263,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): input: str, config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: """Asynchronously invoke the retriever to get relevant documents. Main entry point for asynchronous retriever invocations. @@ -324,7 +324,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): @abstractmethod def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: + ) -> list[Document]: """Get documents relevant to a query. Args: @@ -336,7 +336,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): async def _aget_relevant_documents( self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: + ) -> list[Document]: """Asynchronously get documents relevant to a query. Args: @@ -358,11 +358,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): query: str, *, callbacks: Callbacks = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, run_name: Optional[str] = None, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: """Retrieve documents relevant to a query. Users should favor using `.invoke` or `.batch` rather than @@ -402,11 +402,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): query: str, *, callbacks: Callbacks = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, run_name: Optional[str] = None, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: """Asynchronously get documents relevant to a query. Users should favor using `.ainvoke` or `.abatch` rather than diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index a61f3c79a41..3bcc4b37c5d 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -27,8 +27,6 @@ from typing import ( Optional, Protocol, Sequence, - Set, - Tuple, Type, TypeVar, Union, @@ -273,7 +271,7 @@ class Runnable(Generic[Input, Output], ABC): return name_ @property - def InputType(self) -> Type[Input]: + def InputType(self) -> type[Input]: """The type of input this Runnable accepts specified as a type annotation.""" # First loop through all parent classes and if any of them is # a pydantic model, we will pick up the generic parameterization @@ -298,7 +296,7 @@ class Runnable(Generic[Input, Output], ABC): ) @property - def OutputType(self) -> Type[Output]: + def OutputType(self) -> type[Output]: """The type of output this Runnable produces specified as a type annotation.""" # First loop through bases -- this will help generic # any pydantic models. @@ -319,13 +317,13 @@ class Runnable(Generic[Input, Output], ABC): ) @property - def input_schema(self) -> Type[BaseModel]: + def input_schema(self) -> type[BaseModel]: """The type of input this Runnable accepts specified as a pydantic model.""" return self.get_input_schema() def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: """Get a pydantic model that can be used to validate input to the Runnable. Runnables that leverage the configurable_fields and configurable_alternatives @@ -360,7 +358,7 @@ class Runnable(Generic[Input, Output], ABC): def get_input_jsonschema( self, config: Optional[RunnableConfig] = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Get a JSON schema that represents the input to the Runnable. Args: @@ -387,13 +385,13 @@ class Runnable(Generic[Input, Output], ABC): return self.get_input_schema(config).model_json_schema() @property - def output_schema(self) -> Type[BaseModel]: + def output_schema(self) -> type[BaseModel]: """The type of output this Runnable produces specified as a pydantic model.""" return self.get_output_schema() def get_output_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: """Get a pydantic model that can be used to validate output to the Runnable. Runnables that leverage the configurable_fields and configurable_alternatives @@ -428,7 +426,7 @@ class Runnable(Generic[Input, Output], ABC): def get_output_jsonschema( self, config: Optional[RunnableConfig] = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Get a JSON schema that represents the output of the Runnable. Args: @@ -455,13 +453,13 @@ class Runnable(Generic[Input, Output], ABC): return self.get_output_schema(config).model_json_schema() @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: """List configurable fields for this Runnable.""" return [] def config_schema( self, *, include: Optional[Sequence[str]] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: """The type of config this Runnable accepts specified as a pydantic model. To mark a field as configurable, see the `configurable_fields` @@ -509,7 +507,7 @@ class Runnable(Generic[Input, Output], ABC): def get_config_jsonschema( self, *, include: Optional[Sequence[str]] = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Get a JSON schema that represents the output of the Runnable. Args: @@ -544,7 +542,7 @@ class Runnable(Generic[Input, Output], ABC): def get_prompts( self, config: Optional[RunnableConfig] = None - ) -> List[BasePromptTemplate]: + ) -> list[BasePromptTemplate]: """Return a list of prompts used by this Runnable.""" from langchain_core.prompts.base import BasePromptTemplate @@ -614,7 +612,7 @@ class Runnable(Generic[Input, Output], ABC): """ return RunnableSequence(self, *others, name=name) - def pick(self, keys: Union[str, List[str]]) -> RunnableSerializable[Any, Any]: + def pick(self, keys: Union[str, list[str]]) -> RunnableSerializable[Any, Any]: """Pick keys from the dict output of this Runnable. Pick single key: @@ -670,11 +668,11 @@ class Runnable(Generic[Input, Output], ABC): def assign( self, **kwargs: Union[ - Runnable[Dict[str, Any], Any], - Callable[[Dict[str, Any]], Any], + Runnable[dict[str, Any], Any], + Callable[[dict[str, Any]], Any], Mapping[ str, - Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]], + Union[Runnable[dict[str, Any], Any], Callable[[dict[str, Any]], Any]], ], ], ) -> RunnableSerializable[Any, Any]: @@ -710,7 +708,7 @@ class Runnable(Generic[Input, Output], ABC): """ from langchain_core.runnables.passthrough import RunnableAssign - return self | RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs)) + return self | RunnableAssign(RunnableParallel[dict[str, Any]](kwargs)) """ --- Public API --- """ @@ -744,12 +742,12 @@ class Runnable(Generic[Input, Output], ABC): def batch( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: """Default implementation runs invoke in parallel using a thread pool executor. The default implementation of batch works well for IO bound runnables. @@ -786,7 +784,7 @@ class Runnable(Generic[Input, Output], ABC): *, return_exceptions: Literal[False] = False, **kwargs: Any, - ) -> Iterator[Tuple[int, Output]]: ... + ) -> Iterator[tuple[int, Output]]: ... @overload def batch_as_completed( @@ -796,7 +794,7 @@ class Runnable(Generic[Input, Output], ABC): *, return_exceptions: Literal[True], **kwargs: Any, - ) -> Iterator[Tuple[int, Union[Output, Exception]]]: ... + ) -> Iterator[tuple[int, Union[Output, Exception]]]: ... def batch_as_completed( self, @@ -805,7 +803,7 @@ class Runnable(Generic[Input, Output], ABC): *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> Iterator[Tuple[int, Union[Output, Exception]]]: + ) -> Iterator[tuple[int, Union[Output, Exception]]]: """Run invoke in parallel on a list of inputs, yielding results as they complete.""" @@ -816,7 +814,7 @@ class Runnable(Generic[Input, Output], ABC): def invoke( i: int, input: Input, config: RunnableConfig - ) -> Tuple[int, Union[Output, Exception]]: + ) -> tuple[int, Union[Output, Exception]]: if return_exceptions: try: out: Union[Output, Exception] = self.invoke(input, config, **kwargs) @@ -848,12 +846,12 @@ class Runnable(Generic[Input, Output], ABC): async def abatch( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: """Default implementation runs ainvoke in parallel using asyncio.gather. The default implementation of batch works well for IO bound runnables. @@ -902,7 +900,7 @@ class Runnable(Generic[Input, Output], ABC): *, return_exceptions: Literal[False] = False, **kwargs: Optional[Any], - ) -> AsyncIterator[Tuple[int, Output]]: ... + ) -> AsyncIterator[tuple[int, Output]]: ... @overload def abatch_as_completed( @@ -912,7 +910,7 @@ class Runnable(Generic[Input, Output], ABC): *, return_exceptions: Literal[True], **kwargs: Optional[Any], - ) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: ... + ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: ... async def abatch_as_completed( self, @@ -921,7 +919,7 @@ class Runnable(Generic[Input, Output], ABC): *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: + ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: """Run ainvoke in parallel on a list of inputs, yielding results as they complete. @@ -947,7 +945,7 @@ class Runnable(Generic[Input, Output], ABC): async def ainvoke( i: int, input: Input, config: RunnableConfig - ) -> Tuple[int, Union[Output, Exception]]: + ) -> tuple[int, Union[Output, Exception]]: if return_exceptions: try: out: Union[Output, Exception] = await self.ainvoke( @@ -1699,8 +1697,8 @@ class Runnable(Generic[Input, Output], ABC): def with_types( self, *, - input_type: Optional[Type[Input]] = None, - output_type: Optional[Type[Output]] = None, + input_type: Optional[type[Input]] = None, + output_type: Optional[type[Output]] = None, ) -> Runnable[Input, Output]: """ Bind input and output types to a Runnable, returning a new Runnable. @@ -1722,7 +1720,7 @@ class Runnable(Generic[Input, Output], ABC): def with_retry( self, *, - retry_if_exception_type: Tuple[Type[BaseException], ...] = (Exception,), + retry_if_exception_type: tuple[type[BaseException], ...] = (Exception,), wait_exponential_jitter: bool = True, stop_after_attempt: int = 3, ) -> Runnable[Input, Output]: @@ -1789,7 +1787,7 @@ class Runnable(Generic[Input, Output], ABC): max_attempt_number=stop_after_attempt, ) - def map(self) -> Runnable[List[Input], List[Output]]: + def map(self) -> Runnable[list[Input], list[Output]]: """ Return a new Runnable that maps a list of inputs to a list of outputs, by calling invoke() with each input. @@ -1815,7 +1813,7 @@ class Runnable(Generic[Input, Output], ABC): self, fallbacks: Sequence[Runnable[Input, Output]], *, - exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,), + exceptions_to_handle: tuple[type[BaseException], ...] = (Exception,), exception_key: Optional[str] = None, ) -> RunnableWithFallbacksT[Input, Output]: """Add fallbacks to a Runnable, returning a new Runnable. @@ -1893,7 +1891,7 @@ class Runnable(Generic[Input, Output], ABC): input: Input, config: Optional[RunnableConfig], run_type: Optional[str] = None, - serialized: Optional[Dict[str, Any]] = None, + serialized: Optional[dict[str, Any]] = None, **kwargs: Optional[Any], ) -> Output: """Helper method to transform an Input value to an Output value, @@ -1942,7 +1940,7 @@ class Runnable(Generic[Input, Output], ABC): input: Input, config: Optional[RunnableConfig], run_type: Optional[str] = None, - serialized: Optional[Dict[str, Any]] = None, + serialized: Optional[dict[str, Any]] = None, **kwargs: Optional[Any], ) -> Output: """Helper method to transform an Input value to an Output value, @@ -1977,23 +1975,23 @@ class Runnable(Generic[Input, Output], ABC): def _batch_with_config( self, func: Union[ - Callable[[List[Input]], List[Union[Exception, Output]]], + Callable[[list[Input]], list[Union[Exception, Output]]], Callable[ - [List[Input], List[CallbackManagerForChainRun]], - List[Union[Exception, Output]], + [list[Input], list[CallbackManagerForChainRun]], + list[Union[Exception, Output]], ], Callable[ - [List[Input], List[CallbackManagerForChainRun], List[RunnableConfig]], - List[Union[Exception, Output]], + [list[Input], list[CallbackManagerForChainRun], list[RunnableConfig]], + list[Union[Exception, Output]], ], ], - input: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + input: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, run_type: Optional[str] = None, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" if not input: @@ -2045,27 +2043,27 @@ class Runnable(Generic[Input, Output], ABC): async def _abatch_with_config( self, func: Union[ - Callable[[List[Input]], Awaitable[List[Union[Exception, Output]]]], + Callable[[list[Input]], Awaitable[list[Union[Exception, Output]]]], Callable[ - [List[Input], List[AsyncCallbackManagerForChainRun]], - Awaitable[List[Union[Exception, Output]]], + [list[Input], list[AsyncCallbackManagerForChainRun]], + Awaitable[list[Union[Exception, Output]]], ], Callable[ [ - List[Input], - List[AsyncCallbackManagerForChainRun], - List[RunnableConfig], + list[Input], + list[AsyncCallbackManagerForChainRun], + list[RunnableConfig], ], - Awaitable[List[Union[Exception, Output]]], + Awaitable[list[Union[Exception, Output]]], ], ], - input: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + input: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, run_type: Optional[str] = None, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" if not input: @@ -2073,7 +2071,7 @@ class Runnable(Generic[Input, Output], ABC): configs = get_config_list(config, len(input)) callback_managers = [get_async_callback_manager_for_config(c) for c in configs] - run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( + run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather( *( callback_manager.on_chain_start( None, @@ -2106,7 +2104,7 @@ class Runnable(Generic[Input, Output], ABC): raise else: first_exception: Optional[Exception] = None - coros: List[Awaitable[None]] = [] + coros: list[Awaitable[None]] = [] for run_manager, out in zip(run_managers, output): if isinstance(out, Exception): first_exception = first_exception or out @@ -2333,11 +2331,11 @@ class Runnable(Generic[Input, Output], ABC): @beta_decorator.beta(message="This API is in beta and may change in the future.") def as_tool( self, - args_schema: Optional[Type[BaseModel]] = None, + args_schema: Optional[type[BaseModel]] = None, *, name: Optional[str] = None, description: Optional[str] = None, - arg_types: Optional[Dict[str, Type]] = None, + arg_types: Optional[dict[str, type]] = None, ) -> BaseTool: """Create a BaseTool from a Runnable. @@ -2573,8 +2571,8 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): def _seq_input_schema( - steps: List[Runnable[Any, Any]], config: Optional[RunnableConfig] -) -> Type[BaseModel]: + steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig] +) -> type[BaseModel]: from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick first = steps[0] @@ -2599,8 +2597,8 @@ def _seq_input_schema( def _seq_output_schema( - steps: List[Runnable[Any, Any]], config: Optional[RunnableConfig] -) -> Type[BaseModel]: + steps: list[Runnable[Any, Any]], config: Optional[RunnableConfig] +) -> type[BaseModel]: from langchain_core.runnables.passthrough import RunnableAssign, RunnablePick last = steps[-1] @@ -2730,7 +2728,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): # the last type. first: Runnable[Input, Any] """The first Runnable in the sequence.""" - middle: List[Runnable[Any, Any]] = Field(default_factory=list) + middle: list[Runnable[Any, Any]] = Field(default_factory=list) """The middle Runnables in the sequence.""" last: Runnable[Any, Output] """The last Runnable in the sequence.""" @@ -2740,7 +2738,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): *steps: RunnableLike, name: Optional[str] = None, first: Optional[Runnable[Any, Any]] = None, - middle: Optional[List[Runnable[Any, Any]]] = None, + middle: Optional[list[Runnable[Any, Any]]] = None, last: Optional[Runnable[Any, Any]] = None, ) -> None: """Create a new RunnableSequence. @@ -2755,7 +2753,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): Raises: ValueError: If the sequence has less than 2 steps. """ - steps_flat: List[Runnable] = [] + steps_flat: list[Runnable] = [] if not steps: if first is not None and last is not None: steps_flat = [first] + (middle or []) + [last] @@ -2776,12 +2774,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]): ) @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @property - def steps(self) -> List[Runnable[Any, Any]]: + def steps(self) -> list[Runnable[Any, Any]]: """All the Runnables that make up the sequence in order. Returns: @@ -2804,18 +2802,18 @@ class RunnableSequence(RunnableSerializable[Input, Output]): ) @property - def InputType(self) -> Type[Input]: + def InputType(self) -> type[Input]: """The type of the input to the Runnable.""" return self.first.InputType @property - def OutputType(self) -> Type[Output]: + def OutputType(self) -> type[Output]: """The type of the output of the Runnable.""" return self.last.OutputType def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: """Get the input schema of the Runnable. Args: @@ -2828,7 +2826,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def get_output_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: """Get the output schema of the Runnable. Args: @@ -2840,7 +2838,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): return _seq_output_schema(self.steps, config) @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: """Get the config specs of the Runnable. Returns: @@ -2862,8 +2860,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]): [tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)], lambda x: x[1], ) - next_deps: Set[str] = set() - deps_by_pos: Dict[int, Set[str]] = {} + next_deps: set[str] = set() + deps_by_pos: dict[int, set[str]] = {} for pos, specs in specs_by_pos: deps_by_pos[pos] = next_deps next_deps = next_deps | {spec[0].id for spec in specs} @@ -3065,12 +3063,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def batch( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: from langchain_core.beta.runnables.context import config_with_context from langchain_core.callbacks.manager import CallbackManager @@ -3111,7 +3109,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): # Track which inputs (by index) failed so far # If an input has failed it will be present in this map, # and the value will be the exception that was raised. - failed_inputs_map: Dict[int, Exception] = {} + failed_inputs_map: dict[int, Exception] = {} for stepidx, step in enumerate(self.steps): # Assemble the original indexes of the remaining inputs # (i.e. the ones that haven't failed yet) @@ -3191,12 +3189,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]): async def abatch( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: from langchain_core.beta.runnables.context import aconfig_with_context from langchain_core.callbacks.manager import AsyncCallbackManager @@ -3221,7 +3219,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): for config in configs ] # start the root runs, one per input - run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( + run_managers: list[AsyncCallbackManagerForChainRun] = await asyncio.gather( *( cm.on_chain_start( None, @@ -3240,7 +3238,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): # Track which inputs (by index) failed so far # If an input has failed it will be present in this map, # and the value will be the exception that was raised. - failed_inputs_map: Dict[int, Exception] = {} + failed_inputs_map: dict[int, Exception] = {} for stepidx, step in enumerate(self.steps): # Assemble the original indexes of the remaining inputs # (i.e. the ones that haven't failed yet) @@ -3305,7 +3303,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): raise else: first_exception: Optional[Exception] = None - coros: List[Awaitable[None]] = [] + coros: list[Awaitable[None]] = [] for run_manager, out in zip(run_managers, inputs): if isinstance(out, Exception): first_exception = first_exception or out @@ -3533,7 +3531,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @@ -3567,7 +3565,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: """Get the input schema of the Runnable. Args: @@ -3596,7 +3594,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): def get_output_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: """Get the output schema of the Runnable. Args: @@ -3609,7 +3607,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): return create_model_v2(self.get_name("Output"), field_definitions=fields) @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: """Get the config specs of the Runnable. Returns: @@ -3662,7 +3660,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): def invoke( self, input: Input, config: Optional[RunnableConfig] = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: from langchain_core.callbacks.manager import CallbackManager # setup callbacks @@ -3724,7 +3722,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> Dict[str, Any]: + ) -> dict[str, Any]: # setup callbacks config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) @@ -3829,7 +3827,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): input: Iterator[Input], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> Iterator[Dict[str, Any]]: + ) -> Iterator[dict[str, Any]]: yield from self._transform_stream_with_config( input, self._transform, config, **kwargs ) @@ -3839,7 +3837,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> Iterator[Dict[str, Any]]: + ) -> Iterator[dict[str, Any]]: yield from self.transform(iter([input]), config) async def _atransform( @@ -3898,7 +3896,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): input: AsyncIterator[Input], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> AsyncIterator[Dict[str, Any]]: + ) -> AsyncIterator[dict[str, Any]]: async for chunk in self._atransform_stream_with_config( input, self._atransform, config, **kwargs ): @@ -3909,7 +3907,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): input: Input, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> AsyncIterator[Dict[str, Any]]: + ) -> AsyncIterator[dict[str, Any]]: async def input_aiter() -> AsyncIterator[Input]: yield input @@ -4064,7 +4062,7 @@ class RunnableGenerator(Runnable[Input, Output]): def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: # Override the default implementation. # For a runnable generator, we need to bring to provide the # module of the underlying function when creating the model. @@ -4100,7 +4098,7 @@ class RunnableGenerator(Runnable[Input, Output]): def get_output_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: # Override the default implementation. # For a runnable generator, we need to bring to provide the # module of the underlying function when creating the model. @@ -4346,7 +4344,7 @@ class RunnableLambda(Runnable[Input, Output]): def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: """The pydantic schema for the input to this Runnable. Args: @@ -4414,7 +4412,7 @@ class RunnableLambda(Runnable[Input, Output]): def get_output_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: # Override the default implementation. # For a runnable lambda, we need to bring to provide the # module of the underlying function when creating the model. @@ -4435,7 +4433,7 @@ class RunnableLambda(Runnable[Input, Output]): ) @property - def deps(self) -> List[Runnable]: + def deps(self) -> list[Runnable]: """The dependencies of this Runnable. Returns: @@ -4449,7 +4447,7 @@ class RunnableLambda(Runnable[Input, Output]): else: objects = [] - deps: List[Runnable] = [] + deps: list[Runnable] = [] for obj in objects: if isinstance(obj, Runnable): deps.append(obj) @@ -4458,7 +4456,7 @@ class RunnableLambda(Runnable[Input, Output]): return deps @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: return get_unique_config_specs( spec for dep in self.deps for spec in dep.config_specs ) @@ -4944,7 +4942,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: return create_model_v2( self.get_name("Input"), root=( @@ -4962,12 +4960,12 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): ) @property - def OutputType(self) -> Type[List[Output]]: + def OutputType(self) -> type[list[Output]]: return List[self.bound.OutputType] # type: ignore[name-defined] def get_output_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: schema = self.bound.get_output_schema(config) return create_model_v2( self.get_name("Output"), @@ -4983,7 +4981,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): ) @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: return self.bound.config_specs def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: @@ -4994,42 +4992,42 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] def _invoke( self, - inputs: List[Input], + inputs: list[Input], run_manager: CallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, - ) -> List[Output]: + ) -> list[Output]: configs = [ patch_config(config, callbacks=run_manager.get_child()) for _ in inputs ] return self.bound.batch(inputs, configs, **kwargs) def invoke( - self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> List[Output]: + self, input: list[Input], config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> list[Output]: return self._call_with_config(self._invoke, input, config, **kwargs) async def _ainvoke( self, - inputs: List[Input], + inputs: list[Input], run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, - ) -> List[Output]: + ) -> list[Output]: configs = [ patch_config(config, callbacks=run_manager.get_child()) for _ in inputs ] return await self.bound.abatch(inputs, configs, **kwargs) async def ainvoke( - self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> List[Output]: + self, input: list[Input], config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> list[Output]: return await self._acall_with_config(self._ainvoke, input, config, **kwargs) async def astream_events( @@ -5074,7 +5072,7 @@ class RunnableEach(RunnableEachBase[Input, Output]): """ @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @@ -5181,7 +5179,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): config: RunnableConfig = Field(default_factory=dict) """The config to bind to the underlying Runnable.""" - config_factories: List[Callable[[RunnableConfig], RunnableConfig]] = Field( + config_factories: list[Callable[[RunnableConfig], RunnableConfig]] = Field( default_factory=list ) """The config factories to bind to the underlying Runnable.""" @@ -5210,10 +5208,10 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): kwargs: Optional[Mapping[str, Any]] = None, config: Optional[RunnableConfig] = None, config_factories: Optional[ - List[Callable[[RunnableConfig], RunnableConfig]] + list[Callable[[RunnableConfig], RunnableConfig]] ] = None, - custom_input_type: Optional[Union[Type[Input], BaseModel]] = None, - custom_output_type: Optional[Union[Type[Output], BaseModel]] = None, + custom_input_type: Optional[Union[type[Input], BaseModel]] = None, + custom_output_type: Optional[Union[type[Output], BaseModel]] = None, **other_kwargs: Any, ) -> None: """Create a RunnableBinding from a Runnable and kwargs. @@ -5255,7 +5253,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): return self.bound.get_name(suffix, name=name) @property - def InputType(self) -> Type[Input]: + def InputType(self) -> type[Input]: return ( cast(Type[Input], self.custom_input_type) if self.custom_input_type is not None @@ -5263,7 +5261,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): ) @property - def OutputType(self) -> Type[Output]: + def OutputType(self) -> type[Output]: return ( cast(Type[Output], self.custom_output_type) if self.custom_output_type is not None @@ -5272,20 +5270,20 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: if self.custom_input_type is not None: return super().get_input_schema(config) return self.bound.get_input_schema(merge_configs(self.config, config)) def get_output_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: if self.custom_output_type is not None: return super().get_output_schema(config) return self.bound.get_output_schema(merge_configs(self.config, config)) @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: return self.bound.config_specs def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: @@ -5296,7 +5294,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @@ -5330,12 +5328,12 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): def batch( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: if isinstance(config, list): configs = cast( List[RunnableConfig], @@ -5352,12 +5350,12 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): async def abatch( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: if isinstance(config, list): configs = cast( List[RunnableConfig], @@ -5380,7 +5378,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): *, return_exceptions: Literal[False] = False, **kwargs: Any, - ) -> Iterator[Tuple[int, Output]]: ... + ) -> Iterator[tuple[int, Output]]: ... @overload def batch_as_completed( @@ -5390,7 +5388,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): *, return_exceptions: Literal[True], **kwargs: Any, - ) -> Iterator[Tuple[int, Union[Output, Exception]]]: ... + ) -> Iterator[tuple[int, Union[Output, Exception]]]: ... def batch_as_completed( self, @@ -5399,7 +5397,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> Iterator[Tuple[int, Union[Output, Exception]]]: + ) -> Iterator[tuple[int, Union[Output, Exception]]]: if isinstance(config, Sequence): configs = cast( List[RunnableConfig], @@ -5431,7 +5429,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): *, return_exceptions: Literal[False] = False, **kwargs: Optional[Any], - ) -> AsyncIterator[Tuple[int, Output]]: ... + ) -> AsyncIterator[tuple[int, Output]]: ... @overload def abatch_as_completed( @@ -5441,7 +5439,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): *, return_exceptions: Literal[True], **kwargs: Optional[Any], - ) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: ... + ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: ... async def abatch_as_completed( self, @@ -5450,7 +5448,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: + ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: if isinstance(config, Sequence): configs = cast( List[RunnableConfig], @@ -5590,7 +5588,7 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): """ @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @@ -5678,8 +5676,8 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): def with_types( self, - input_type: Optional[Union[Type[Input], BaseModel]] = None, - output_type: Optional[Union[Type[Output], BaseModel]] = None, + input_type: Optional[Union[type[Input], BaseModel]] = None, + output_type: Optional[Union[type[Output], BaseModel]] = None, ) -> Runnable[Input, Output]: return self.__class__( bound=self.bound, diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 31cb635fdb8..69e58f194ae 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -12,7 +12,6 @@ from typing import ( Any, Awaitable, Callable, - Dict, Generator, Iterable, Iterator, @@ -56,13 +55,13 @@ class EmptyDict(TypedDict, total=False): class RunnableConfig(TypedDict, total=False): """Configuration for a Runnable.""" - tags: List[str] + tags: list[str] """ Tags for this call and any sub-calls (eg. a Chain calling an LLM). You can use these to filter calls. """ - metadata: Dict[str, Any] + metadata: dict[str, Any] """ Metadata for this call and any sub-calls (eg. a Chain calling an LLM). Keys should be strings, values should be JSON-serializable. @@ -90,7 +89,7 @@ class RunnableConfig(TypedDict, total=False): Maximum number of times a call can recurse. If not provided, defaults to 25. """ - configurable: Dict[str, Any] + configurable: dict[str, Any] """ Runtime values for attributes previously made configurable on this Runnable, or sub-Runnables, through .configurable_fields() or .configurable_alternatives(). @@ -205,7 +204,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: def get_config_list( config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]], length: int -) -> List[RunnableConfig]: +) -> list[RunnableConfig]: """Get a list of configs from a single config or a list of configs. It is useful for subclasses overriding batch() or abatch(). @@ -255,7 +254,7 @@ def patch_config( recursion_limit: Optional[int] = None, max_concurrency: Optional[int] = None, run_name: Optional[str] = None, - configurable: Optional[Dict[str, Any]] = None, + configurable: Optional[dict[str, Any]] = None, ) -> RunnableConfig: """Patch a config with new values. diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index fdcd7d0e959..7f0c36687a6 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -8,12 +8,10 @@ from typing import ( Any, AsyncIterator, Callable, - Dict, Iterator, List, Optional, Sequence, - Tuple, Type, Union, cast, @@ -69,27 +67,27 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @property - def InputType(self) -> Type[Input]: + def InputType(self) -> type[Input]: return self.default.InputType @property - def OutputType(self) -> Type[Output]: + def OutputType(self) -> type[Output]: return self.default.OutputType def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: runnable, config = self.prepare(config) return runnable.get_input_schema(config) def get_output_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: runnable, config = self.prepare(config) return runnable.get_output_schema(config) @@ -109,7 +107,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): def prepare( self, config: Optional[RunnableConfig] = None - ) -> Tuple[Runnable[Input, Output], RunnableConfig]: + ) -> tuple[Runnable[Input, Output], RunnableConfig]: """Prepare the Runnable for invocation. Args: @@ -127,7 +125,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): @abstractmethod def _prepare( self, config: Optional[RunnableConfig] = None - ) -> Tuple[Runnable[Input, Output], RunnableConfig]: ... + ) -> tuple[Runnable[Input, Output], RunnableConfig]: ... def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any @@ -143,12 +141,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): def batch( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: configs = get_config_list(config, len(inputs)) prepared = [self.prepare(c) for c in configs] @@ -164,7 +162,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): return [] def invoke( - prepared: Tuple[Runnable[Input, Output], RunnableConfig], + prepared: tuple[Runnable[Input, Output], RunnableConfig], input: Input, ) -> Union[Output, Exception]: bound, config = prepared @@ -185,12 +183,12 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): async def abatch( self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[Input], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: configs = get_config_list(config, len(inputs)) prepared = [self.prepare(c) for c in configs] @@ -206,7 +204,7 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): return [] async def ainvoke( - prepared: Tuple[Runnable[Input, Output], RunnableConfig], + prepared: tuple[Runnable[Input, Output], RunnableConfig], input: Input, ) -> Union[Output, Exception]: bound, config = prepared @@ -362,15 +360,15 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]): ) """ - fields: Dict[str, AnyConfigurableField] + fields: dict[str, AnyConfigurableField] @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: """Get the configuration specs for the RunnableConfigurableFields. Returns: @@ -412,7 +410,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]): def _prepare( self, config: Optional[RunnableConfig] = None - ) -> Tuple[Runnable[Input, Output], RunnableConfig]: + ) -> tuple[Runnable[Input, Output], RunnableConfig]: config = ensure_config(config) specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()} configurable_fields = { @@ -467,7 +465,7 @@ _enums_for_spec: WeakValueDictionary[ Union[ ConfigurableFieldSingleOption, ConfigurableFieldMultiOption, ConfigurableField ], - Type[StrEnum], + type[StrEnum], ] = WeakValueDictionary() _enums_for_spec_lock = threading.Lock() @@ -532,7 +530,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): which: ConfigurableField """The ConfigurableField to use to choose between alternatives.""" - alternatives: Dict[ + alternatives: dict[ str, Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]], ] @@ -547,12 +545,12 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): the alternative named "gpt3" becomes "model==gpt3/temperature".""" @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: with _enums_for_spec_lock: if which_enum := _enums_for_spec.get(self.which): pass @@ -612,7 +610,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): def _prepare( self, config: Optional[RunnableConfig] = None - ) -> Tuple[Runnable[Input, Output], RunnableConfig]: + ) -> tuple[Runnable[Input, Output], RunnableConfig]: config = ensure_config(config) which = config.get("configurable", {}).get(self.which.id, self.default_key) # remap configurable keys for the chosen alternative diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index d7db269e66d..bcf64c95fc1 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -8,14 +8,10 @@ from typing import ( TYPE_CHECKING, Any, Callable, - Dict, - List, NamedTuple, Optional, Protocol, Sequence, - Tuple, - Type, TypedDict, Union, overload, @@ -106,8 +102,8 @@ class Node(NamedTuple): id: str name: str - data: Union[Type[BaseModel], RunnableType] - metadata: Optional[Dict[str, Any]] + data: Union[type[BaseModel], RunnableType] + metadata: Optional[dict[str, Any]] def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node: """Return a copy of the node with optional new id and name. @@ -179,7 +175,7 @@ class MermaidDrawMethod(Enum): API = "api" # Uses Mermaid.INK API to render the graph -def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str: +def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str: """Convert the data of a node to a string. Args: @@ -202,7 +198,7 @@ def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str: def node_data_json( node: Node, *, with_schemas: bool = False -) -> Dict[str, Union[str, Dict[str, Any]]]: +) -> dict[str, Union[str, dict[str, Any]]]: """Convert the data of a node to a JSON-serializable format. Args: @@ -217,7 +213,7 @@ def node_data_json( from langchain_core.runnables.base import Runnable, RunnableSerializable if isinstance(node.data, RunnableSerializable): - json: Dict[str, Any] = { + json: dict[str, Any] = { "type": "runnable", "data": { "id": node.data.lc_id(), @@ -265,10 +261,10 @@ class Graph: edges: List of edges in the graph. Defaults to an empty list. """ - nodes: Dict[str, Node] = field(default_factory=dict) - edges: List[Edge] = field(default_factory=list) + nodes: dict[str, Node] = field(default_factory=dict) + edges: list[Edge] = field(default_factory=list) - def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]: + def to_json(self, *, with_schemas: bool = False) -> dict[str, list[dict[str, Any]]]: """Convert the graph to a JSON-serializable format. Args: @@ -282,7 +278,7 @@ class Graph: node.id: i if is_uuid(node.id) else node.id for i, node in enumerate(self.nodes.values()) } - edges: List[Dict[str, Any]] = [] + edges: list[dict[str, Any]] = [] for edge in self.edges: edge_dict = { "source": stable_node_ids[edge.source], @@ -315,10 +311,10 @@ class Graph: def add_node( self, - data: Union[Type[BaseModel], RunnableType], + data: Union[type[BaseModel], RunnableType], id: Optional[str] = None, *, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, ) -> Node: """Add a node to the graph and return it. @@ -386,7 +382,7 @@ class Graph: def extend( self, graph: Graph, *, prefix: str = "" - ) -> Tuple[Optional[Node], Optional[Node]]: + ) -> tuple[Optional[Node], Optional[Node]]: """Add all nodes and edges from another graph. Note this doesn't check for duplicates, nor does it connect the graphs. @@ -622,7 +618,7 @@ def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]: If there is no such node, or there are multiple, return None. When drawing the graph, this node would be the origin.""" targets = {edge.target for edge in graph.edges if edge.source not in exclude} - found: List[Node] = [] + found: list[Node] = [] for node in graph.nodes.values(): if node.id not in exclude and node.id not in targets: found.append(node) @@ -635,7 +631,7 @@ def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]: If there is no such node, or there are multiple, return None. When drawing the graph, this node would be the destination.""" sources = {edge.source for edge in graph.edges if edge.target not in exclude} - found: List[Node] = [] + found: list[Node] = [] for node in graph.nodes.values(): if node.id not in exclude and node.id not in sources: found.append(node) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index da297ec3bca..3d405bebefa 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -6,10 +6,8 @@ from typing import ( Any, Callable, Dict, - List, Optional, Sequence, - Type, Union, ) @@ -238,7 +236,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): history_factory_config: Sequence[ConfigurableFieldSpec] @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @@ -366,7 +364,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): self._history_chain = history_chain @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: """Get the configuration specs for the RunnableWithMessageHistory.""" return get_unique_config_specs( super().config_specs + list(self.history_factory_config) @@ -374,10 +372,10 @@ class RunnableWithMessageHistory(RunnableBindingBase): def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: from langchain_core.messages import BaseMessage - fields: Dict = {} + fields: dict = {} if self.input_messages_key and self.history_messages_key: fields[self.input_messages_key] = ( Union[str, BaseMessage, Sequence[BaseMessage]], @@ -398,13 +396,13 @@ class RunnableWithMessageHistory(RunnableBindingBase): ) @property - def OutputType(self) -> Type[Output]: + def OutputType(self) -> type[Output]: output_type = self._history_chain.OutputType return output_type def get_output_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: """Get a pydantic model that can be used to validate output to the Runnable. Runnables that leverage the configurable_fields and configurable_alternatives @@ -430,15 +428,15 @@ class RunnableWithMessageHistory(RunnableBindingBase): module_name=self.__class__.__module__, ) - def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool: + def _is_not_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool: return False - async def _is_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool: + async def _is_async(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> bool: return True def _get_input_messages( self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] - ) -> List[BaseMessage]: + ) -> list[BaseMessage]: from langchain_core.messages import BaseMessage # If dictionary, try to pluck the single key representing messages @@ -481,7 +479,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): def _get_output_messages( self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict] - ) -> List[BaseMessage]: + ) -> list[BaseMessage]: from langchain_core.messages import BaseMessage # If dictionary, try to pluck the single key representing messages @@ -514,7 +512,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): f"Got {output_val}." ) - def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]: + def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]: hist: BaseChatMessageHistory = config["configurable"]["message_history"] messages = hist.messages.copy() @@ -527,8 +525,8 @@ class RunnableWithMessageHistory(RunnableBindingBase): return messages async def _aenter_history( - self, input: Dict[str, Any], config: RunnableConfig - ) -> List[BaseMessage]: + self, input: dict[str, Any], config: RunnableConfig + ) -> list[BaseMessage]: hist: BaseChatMessageHistory = config["configurable"]["message_history"] messages = (await hist.aget_messages()).copy() @@ -621,7 +619,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): return config -def _get_parameter_names(callable_: GetSessionHistoryCallable) -> List[str]: +def _get_parameter_names(callable_: GetSessionHistoryCallable) -> list[str]: """Get the parameter names of the callable.""" sig = inspect.signature(callable_) return list(sig.parameters.keys()) diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index a613533674d..3295374f6a2 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -13,10 +13,8 @@ from typing import ( Callable, Dict, Iterator, - List, Mapping, Optional, - Type, Union, cast, ) @@ -144,7 +142,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): # {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20} """ - input_type: Optional[Type[Other]] = None + input_type: Optional[type[Other]] = None func: Optional[ Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]] @@ -180,7 +178,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): ] ] = None, *, - input_type: Optional[Type[Other]] = None, + input_type: Optional[type[Other]] = None, **kwargs: Any, ) -> None: if inspect.iscoroutinefunction(func): @@ -194,7 +192,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @@ -210,11 +208,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): def assign( cls, **kwargs: Union[ - Runnable[Dict[str, Any], Any], - Callable[[Dict[str, Any]], Any], + Runnable[dict[str, Any], Any], + Callable[[dict[str, Any]], Any], Mapping[ str, - Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]], + Union[Runnable[dict[str, Any], Any], Callable[[dict[str, Any]], Any]], ], ], ) -> RunnableAssign: @@ -228,7 +226,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): A Runnable that merges the Dict input with the output produced by the mapping argument. """ - return RunnableAssign(RunnableParallel[Dict[str, Any]](kwargs)) + return RunnableAssign(RunnableParallel[dict[str, Any]](kwargs)) def invoke( self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any @@ -392,9 +390,9 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): # returns {'input': 5, 'add_step': {'added': 15}} """ - mapper: RunnableParallel[Dict[str, Any]] + mapper: RunnableParallel[dict[str, Any]] - def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None: + def __init__(self, mapper: RunnableParallel[dict[str, Any]], **kwargs: Any) -> None: super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg] @classmethod @@ -402,7 +400,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @@ -418,7 +416,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: map_input_schema = self.mapper.get_input_schema(config) if not issubclass(map_input_schema, RootModel): # ie. it's a dict @@ -428,7 +426,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def get_output_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: map_input_schema = self.mapper.get_input_schema(config) map_output_schema = self.mapper.get_output_schema(config) if not issubclass(map_input_schema, RootModel) and not issubclass( @@ -453,7 +451,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): return super().get_output_schema(config) @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: return self.mapper.config_specs def get_graph(self, config: RunnableConfig | None = None) -> Graph: @@ -470,11 +468,11 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def _invoke( self, - input: Dict[str, Any], + input: dict[str, Any], run_manager: CallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: assert isinstance( input, dict ), "The input to RunnablePassthrough.assign() must be a dict." @@ -490,19 +488,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def invoke( self, - input: Dict[str, Any], + input: dict[str, Any], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return self._call_with_config(self._invoke, input, config, **kwargs) async def _ainvoke( self, - input: Dict[str, Any], + input: dict[str, Any], run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: assert isinstance( input, dict ), "The input to RunnablePassthrough.assign() must be a dict." @@ -518,19 +516,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): async def ainvoke( self, - input: Dict[str, Any], + input: dict[str, Any], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return await self._acall_with_config(self._ainvoke, input, config, **kwargs) def _transform( self, - input: Iterator[Dict[str, Any]], + input: Iterator[dict[str, Any]], run_manager: CallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, - ) -> Iterator[Dict[str, Any]]: + ) -> Iterator[dict[str, Any]]: # collect mapper keys mapper_keys = set(self.mapper.steps__.keys()) # create two streams, one for the map and one for the passthrough @@ -572,21 +570,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def transform( self, - input: Iterator[Dict[str, Any]], + input: Iterator[dict[str, Any]], config: Optional[RunnableConfig] = None, **kwargs: Any | None, - ) -> Iterator[Dict[str, Any]]: + ) -> Iterator[dict[str, Any]]: yield from self._transform_stream_with_config( input, self._transform, config, **kwargs ) async def _atransform( self, - input: AsyncIterator[Dict[str, Any]], + input: AsyncIterator[dict[str, Any]], run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, **kwargs: Any, - ) -> AsyncIterator[Dict[str, Any]]: + ) -> AsyncIterator[dict[str, Any]]: # collect mapper keys mapper_keys = set(self.mapper.steps__.keys()) # create two streams, one for the map and one for the passthrough @@ -622,10 +620,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): async def atransform( self, - input: AsyncIterator[Dict[str, Any]], + input: AsyncIterator[dict[str, Any]], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> AsyncIterator[Dict[str, Any]]: + ) -> AsyncIterator[dict[str, Any]]: async for chunk in self._atransform_stream_with_config( input, self._atransform, config, **kwargs ): @@ -633,19 +631,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def stream( self, - input: Dict[str, Any], + input: dict[str, Any], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> Iterator[Dict[str, Any]]: + ) -> Iterator[dict[str, Any]]: return self.transform(iter([input]), config, **kwargs) async def astream( self, - input: Dict[str, Any], + input: dict[str, Any], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> AsyncIterator[Dict[str, Any]]: - async def input_aiter() -> AsyncIterator[Dict[str, Any]]: + ) -> AsyncIterator[dict[str, Any]]: + async def input_aiter() -> AsyncIterator[dict[str, Any]]: yield input async for chunk in self.atransform(input_aiter(), config, **kwargs): @@ -683,9 +681,9 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): print(output_data) # Output: {'name': 'John', 'age': 30} """ - keys: Union[str, List[str]] + keys: Union[str, list[str]] - def __init__(self, keys: Union[str, List[str]], **kwargs: Any) -> None: + def __init__(self, keys: Union[str, list[str]], **kwargs: Any) -> None: super().__init__(keys=keys, **kwargs) # type: ignore[call-arg] @classmethod @@ -693,7 +691,7 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @@ -707,7 +705,7 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): ) return super().get_name(suffix, name=name) - def _pick(self, input: Dict[str, Any]) -> Any: + def _pick(self, input: dict[str, Any]) -> Any: assert isinstance( input, dict ), "The input to RunnablePassthrough.assign() must be a dict." @@ -723,36 +721,36 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def _invoke( self, - input: Dict[str, Any], - ) -> Dict[str, Any]: + input: dict[str, Any], + ) -> dict[str, Any]: return self._pick(input) def invoke( self, - input: Dict[str, Any], + input: dict[str, Any], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return self._call_with_config(self._invoke, input, config, **kwargs) async def _ainvoke( self, - input: Dict[str, Any], - ) -> Dict[str, Any]: + input: dict[str, Any], + ) -> dict[str, Any]: return self._pick(input) async def ainvoke( self, - input: Dict[str, Any], + input: dict[str, Any], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return await self._acall_with_config(self._ainvoke, input, config, **kwargs) def _transform( self, - input: Iterator[Dict[str, Any]], - ) -> Iterator[Dict[str, Any]]: + input: Iterator[dict[str, Any]], + ) -> Iterator[dict[str, Any]]: for chunk in input: picked = self._pick(chunk) if picked is not None: @@ -760,18 +758,18 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def transform( self, - input: Iterator[Dict[str, Any]], + input: Iterator[dict[str, Any]], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> Iterator[Dict[str, Any]]: + ) -> Iterator[dict[str, Any]]: yield from self._transform_stream_with_config( input, self._transform, config, **kwargs ) async def _atransform( self, - input: AsyncIterator[Dict[str, Any]], - ) -> AsyncIterator[Dict[str, Any]]: + input: AsyncIterator[dict[str, Any]], + ) -> AsyncIterator[dict[str, Any]]: async for chunk in input: picked = self._pick(chunk) if picked is not None: @@ -779,10 +777,10 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): async def atransform( self, - input: AsyncIterator[Dict[str, Any]], + input: AsyncIterator[dict[str, Any]], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> AsyncIterator[Dict[str, Any]]: + ) -> AsyncIterator[dict[str, Any]]: async for chunk in self._atransform_stream_with_config( input, self._atransform, config, **kwargs ): @@ -790,19 +788,19 @@ class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def stream( self, - input: Dict[str, Any], + input: dict[str, Any], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> Iterator[Dict[str, Any]]: + ) -> Iterator[dict[str, Any]]: return self.transform(iter([input]), config, **kwargs) async def astream( self, - input: Dict[str, Any], + input: dict[str, Any], config: Optional[RunnableConfig] = None, **kwargs: Any, - ) -> AsyncIterator[Dict[str, Any]]: - async def input_aiter() -> AsyncIterator[Dict[str, Any]]: + ) -> AsyncIterator[dict[str, Any]]: + async def input_aiter() -> AsyncIterator[dict[str, Any]]: yield input async for chunk in self.atransform(input_aiter(), config, **kwargs): diff --git a/libs/core/langchain_core/runnables/router.py b/libs/core/langchain_core/runnables/router.py index b8a5faaaa87..8b0691e9e0c 100644 --- a/libs/core/langchain_core/runnables/router.py +++ b/libs/core/langchain_core/runnables/router.py @@ -71,7 +71,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): runnables: Mapping[str, Runnable[Any, Output]] @property - def config_specs(self) -> List[ConfigurableFieldSpec]: + def config_specs(self) -> list[ConfigurableFieldSpec]: return get_unique_config_specs( spec for step in self.runnables.values() for spec in step.config_specs ) @@ -94,7 +94,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): return True @classmethod - def get_lc_namespace(cls) -> List[str]: + def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @@ -125,12 +125,12 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): def batch( self, - inputs: List[RouterInput], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[RouterInput], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: if not inputs: return [] @@ -160,12 +160,12 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): async def abatch( self, - inputs: List[RouterInput], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + inputs: list[RouterInput], + config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None, *, return_exceptions: bool = False, **kwargs: Optional[Any], - ) -> List[Output]: + ) -> list[Output]: if not inputs: return [] diff --git a/libs/core/langchain_core/runnables/schema.py b/libs/core/langchain_core/runnables/schema.py index c5ea213a6a2..a13d8f6e4db 100644 --- a/libs/core/langchain_core/runnables/schema.py +++ b/libs/core/langchain_core/runnables/schema.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Literal, Sequence, Union +from typing import Any, Literal, Sequence, Union from typing_extensions import NotRequired, TypedDict @@ -110,7 +110,7 @@ class BaseStreamEvent(TypedDict): Each child Runnable that gets invoked as part of the execution of a parent Runnable is assigned its own unique ID. """ - tags: NotRequired[List[str]] + tags: NotRequired[list[str]] """Tags associated with the Runnable that generated this event. Tags are always inherited from parent Runnables. @@ -118,7 +118,7 @@ class BaseStreamEvent(TypedDict): Tags can either be bound to a Runnable using `.with_config({"tags": ["hello"]})` or passed at run time using `.astream_events(..., {"tags": ["hello"]})`. """ - metadata: NotRequired[Dict[str, Any]] + metadata: NotRequired[dict[str, Any]] """Metadata associated with the Runnable that generated this event. Metadata can either be bound to a Runnable using diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 20d82be3fd6..4b1763500e8 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -18,13 +18,11 @@ from typing import ( Coroutine, Dict, Iterable, - List, Mapping, NamedTuple, Optional, Protocol, Sequence, - Set, TypeVar, Union, ) @@ -126,7 +124,7 @@ def asyncio_accepts_context() -> bool: class IsLocalDict(ast.NodeVisitor): """Check if a name is a local dict.""" - def __init__(self, name: str, keys: Set[str]) -> None: + def __init__(self, name: str, keys: set[str]) -> None: """Initialize the visitor. Args: @@ -181,7 +179,7 @@ class IsFunctionArgDict(ast.NodeVisitor): """Check if the first argument of a function is a dict.""" def __init__(self) -> None: - self.keys: Set[str] = set() + self.keys: set[str] = set() def visit_Lambda(self, node: ast.Lambda) -> Any: """Visit a lambda function. @@ -230,8 +228,8 @@ class NonLocals(ast.NodeVisitor): """Get nonlocal variables accessed.""" def __init__(self) -> None: - self.loads: Set[str] = set() - self.stores: Set[str] = set() + self.loads: set[str] = set() + self.stores: set[str] = set() def visit_Name(self, node: ast.Name) -> Any: """Visit a name node. @@ -271,7 +269,7 @@ class FunctionNonLocals(ast.NodeVisitor): """Get the nonlocal variables accessed of a function.""" def __init__(self) -> None: - self.nonlocals: Set[str] = set() + self.nonlocals: set[str] = set() def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: """Visit a function definition. @@ -335,7 +333,7 @@ class GetLambdaSource(ast.NodeVisitor): self.source = ast.unparse(node) -def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]: +def get_function_first_arg_dict_keys(func: Callable) -> Optional[list[str]]: """Get the keys of the first argument of a function if it is a dict. Args: @@ -378,7 +376,7 @@ def get_lambda_source(func: Callable) -> Optional[str]: return name -def get_function_nonlocals(func: Callable) -> List[Any]: +def get_function_nonlocals(func: Callable) -> list[Any]: """Get the nonlocal variables accessed by a function. Args: @@ -392,7 +390,7 @@ def get_function_nonlocals(func: Callable) -> List[Any]: tree = ast.parse(textwrap.dedent(code)) visitor = FunctionNonLocals() visitor.visit(tree) - values: List[Any] = [] + values: list[Any] = [] closure = inspect.getclosurevars(func) candidates = {**closure.globals, **closure.nonlocals} for k, v in candidates.items(): @@ -608,12 +606,12 @@ class ConfigurableFieldSpec(NamedTuple): description: Optional[str] = None default: Any = None is_shared: bool = False - dependencies: Optional[List[str]] = None + dependencies: Optional[list[str]] = None def get_unique_config_specs( specs: Iterable[ConfigurableFieldSpec], -) -> List[ConfigurableFieldSpec]: +) -> list[ConfigurableFieldSpec]: """Get the unique config specs from a sequence of config specs. Args: @@ -628,7 +626,7 @@ def get_unique_config_specs( grouped = groupby( sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id ) - unique: List[ConfigurableFieldSpec] = [] + unique: list[ConfigurableFieldSpec] = [] for id, dupes in grouped: first = next(dupes) others = list(dupes) diff --git a/libs/core/langchain_core/structured_query.py b/libs/core/langchain_core/structured_query.py index fcde3be9b27..0e58c88dcc4 100644 --- a/libs/core/langchain_core/structured_query.py +++ b/libs/core/langchain_core/structured_query.py @@ -4,7 +4,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from enum import Enum -from typing import Any, List, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Union from pydantic import BaseModel @@ -142,10 +142,10 @@ class Operation(FilterDirective): """ operator: Operator - arguments: List[FilterDirective] + arguments: list[FilterDirective] def __init__( - self, operator: Operator, arguments: List[FilterDirective], **kwargs: Any + self, operator: Operator, arguments: list[FilterDirective], **kwargs: Any ) -> None: # super exists from BaseModel super().__init__( # type: ignore[call-arg] diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index f528a80e565..38b54cb7e9a 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -13,12 +13,9 @@ from typing import ( Any, Callable, Dict, - List, Literal, Optional, Sequence, - Tuple, - Type, TypeVar, Union, cast, @@ -78,11 +75,11 @@ class SchemaAnnotationError(TypeError): """Raised when 'args_schema' is missing or has an incorrect type annotation.""" -def _is_annotated_type(typ: Type[Any]) -> bool: +def _is_annotated_type(typ: type[Any]) -> bool: return get_origin(typ) is Annotated -def _get_annotation_description(arg_type: Type) -> str | None: +def _get_annotation_description(arg_type: type) -> str | None: if _is_annotated_type(arg_type): annotated_args = get_args(arg_type) for annotation in annotated_args[1:]: @@ -92,7 +89,7 @@ def _get_annotation_description(arg_type: Type) -> str | None: def _get_filtered_args( - inferred_model: Type[BaseModel], + inferred_model: type[BaseModel], func: Callable, *, filter_args: Sequence[str], @@ -112,7 +109,7 @@ def _get_filtered_args( def _parse_python_function_docstring( function: Callable, annotations: dict, error_on_invalid_docstring: bool = False -) -> Tuple[str, dict]: +) -> tuple[str, dict]: """Parse the function and argument descriptions from the docstring of a function. Assumes the function docstring follows Google Python style guide. @@ -141,7 +138,7 @@ def _infer_arg_descriptions( *, parse_docstring: bool = False, error_on_invalid_docstring: bool = False, -) -> Tuple[str, dict]: +) -> tuple[str, dict]: """Infer argument descriptions from a function's docstring.""" if hasattr(inspect, "get_annotations"): # This is for python < 3.10 @@ -218,7 +215,7 @@ def create_schema_from_function( parse_docstring: bool = False, error_on_invalid_docstring: bool = False, include_injected: bool = True, -) -> Type[BaseModel]: +) -> type[BaseModel]: """Create a pydantic schema from a function's signature. Args: @@ -273,7 +270,7 @@ def create_schema_from_function( filter_args_ = filter_args else: # Handle classmethods and instance methods - existing_params: List[str] = list(sig.parameters.keys()) + existing_params: list[str] = list(sig.parameters.keys()) if existing_params and existing_params[0] in ("self", "cls") and in_class: filter_args_ = [existing_params[0]] + list(FILTERED_ARGS) else: @@ -395,13 +392,13 @@ class ChildTool(BaseTool): description="Callback manager to add to the run trace.", ) ) - tags: Optional[List[str]] = None + tags: Optional[list[str]] = None """Optional list of tags associated with the tool. Defaults to None. These tags will be associated with each call to this tool, and passed as arguments to the handlers defined in `callbacks`. You can use these to eg identify a specific instance of a tool with its use case. """ - metadata: Optional[Dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None """Optional metadata associated with the tool. Defaults to None. This metadata will be associated with each call to this tool, and passed as arguments to the handlers defined in `callbacks`. @@ -451,7 +448,7 @@ class ChildTool(BaseTool): return self.get_input_schema().model_json_schema()["properties"] @property - def tool_call_schema(self) -> Type[BaseModel]: + def tool_call_schema(self) -> type[BaseModel]: full_schema = self.get_input_schema() fields = [] for name, type_ in _get_all_basemodel_annotations(full_schema).items(): @@ -465,7 +462,7 @@ class ChildTool(BaseTool): def get_input_schema( self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + ) -> type[BaseModel]: """The tool's input schema. Args: @@ -481,7 +478,7 @@ class ChildTool(BaseTool): def invoke( self, - input: Union[str, Dict, ToolCall], + input: Union[str, dict, ToolCall], config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: @@ -490,7 +487,7 @@ class ChildTool(BaseTool): async def ainvoke( self, - input: Union[str, Dict, ToolCall], + input: Union[str, dict, ToolCall], config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: @@ -499,7 +496,7 @@ class ChildTool(BaseTool): # --- Tool --- - def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]: + def _parse_input(self, tool_input: Union[str, dict]) -> Union[str, dict[str, Any]]: """Convert tool input to a pydantic model. Args: @@ -536,7 +533,7 @@ class ChildTool(BaseTool): @model_validator(mode="before") @classmethod - def raise_deprecation(cls, values: Dict) -> Any: + def raise_deprecation(cls, values: dict) -> Any: """Raise deprecation warning if callback_manager is used. Args: @@ -574,7 +571,7 @@ class ChildTool(BaseTool): kwargs["run_manager"] = kwargs["run_manager"].get_sync() return await run_in_executor(None, self._run, *args, **kwargs) - def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]: tool_input = self._parse_input(tool_input) # For backwards compatibility, if run_input is a string, # pass as a positional argument. @@ -585,14 +582,14 @@ class ChildTool(BaseTool): def run( self, - tool_input: Union[str, Dict[str, Any]], + tool_input: Union[str, dict[str, Any]], verbose: Optional[bool] = None, start_color: Optional[str] = "green", color: Optional[str] = "green", callbacks: Callbacks = None, *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, run_name: Optional[str] = None, run_id: Optional[uuid.UUID] = None, config: Optional[RunnableConfig] = None, @@ -696,14 +693,14 @@ class ChildTool(BaseTool): async def arun( self, - tool_input: Union[str, Dict], + tool_input: Union[str, dict], verbose: Optional[bool] = None, start_color: Optional[str] = "green", color: Optional[str] = "green", callbacks: Callbacks = None, *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, run_name: Optional[str] = None, run_id: Optional[uuid.UUID] = None, config: Optional[RunnableConfig] = None, @@ -866,7 +863,7 @@ def _prep_run_args( input: Union[str, dict, ToolCall], config: Optional[RunnableConfig], **kwargs: Any, -) -> Tuple[Union[str, Dict], Dict]: +) -> tuple[Union[str, dict], dict]: config = ensure_config(config) if _is_tool_call(input): tool_call_id: Optional[str] = cast(ToolCall, input)["id"] @@ -933,7 +930,7 @@ def _stringify(content: Any) -> str: return str(content) -def _get_type_hints(func: Callable) -> Optional[Dict[str, Type]]: +def _get_type_hints(func: Callable) -> Optional[dict[str, type]]: if isinstance(func, functools.partial): func = func.func try: @@ -956,7 +953,7 @@ class InjectedToolArg: """Annotation for a Tool arg that is **not** meant to be generated by a model.""" -def _is_injected_arg_type(type_: Type) -> bool: +def _is_injected_arg_type(type_: type) -> bool: return any( isinstance(arg, InjectedToolArg) or (isinstance(arg, type) and issubclass(arg, InjectedToolArg)) @@ -966,10 +963,10 @@ def _is_injected_arg_type(type_: Type) -> bool: def _get_all_basemodel_annotations( cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True -) -> Dict[str, Type]: +) -> dict[str, type]: # cls has no subscript: cls = FooBar if isinstance(cls, type): - annotations: Dict[str, Type] = {} + annotations: dict[str, type] = {} for name, param in inspect.signature(cls).parameters.items(): # Exclude hidden init args added by pydantic Config. For example if # BaseModel(extra="allow") then "extra_data" will part of init sig. @@ -979,7 +976,7 @@ def _get_all_basemodel_annotations( ) and name not in fields: continue annotations[name] = param.annotation - orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple()) + orig_bases: tuple = getattr(cls, "__orig_bases__", tuple()) # cls has subscript: cls = FooBar[int] else: annotations = _get_all_basemodel_annotations( @@ -1011,7 +1008,7 @@ def _get_all_basemodel_annotations( # parent_origin = Baz, # generic_type_vars = (type vars in Baz) # generic_map = {type var in Baz: str} - generic_type_vars: Tuple = getattr(parent_origin, "__parameters__", tuple()) + generic_type_vars: tuple = getattr(parent_origin, "__parameters__", tuple()) generic_map = { type_var: t for type_var, t in zip(generic_type_vars, get_args(parent)) } @@ -1027,10 +1024,10 @@ def _get_all_basemodel_annotations( def _replace_type_vars( - type_: Type, - generic_map: Optional[Dict[TypeVar, Type]] = None, + type_: type, + generic_map: Optional[dict[TypeVar, type]] = None, default_to_bound: bool = True, -) -> Type: +) -> type: generic_map = generic_map or {} if isinstance(type_, TypeVar): if type_ in generic_map: @@ -1043,7 +1040,7 @@ def _replace_type_vars( new_args = tuple( _replace_type_vars(arg, generic_map, default_to_bound) for arg in args ) - return _py_38_safe_origin(origin)[new_args] + return _py_38_safe_origin(origin)[new_args] # type: ignore[index] else: return type_ @@ -1052,5 +1049,5 @@ class BaseToolkit(BaseModel, ABC): """Base Toolkit representing a collection of related tools.""" @abstractmethod - def get_tools(self) -> List[BaseTool]: + def get_tools(self) -> list[BaseTool]: """Get the tools in the toolkit.""" diff --git a/libs/core/langchain_core/tools/render.py b/libs/core/langchain_core/tools/render.py index 6d93c3fc2cb..e8c05a06402 100644 --- a/libs/core/langchain_core/tools/render.py +++ b/libs/core/langchain_core/tools/render.py @@ -8,7 +8,7 @@ from langchain_core.tools.base import BaseTool ToolsRenderer = Callable[[List[BaseTool]], str] -def render_text_description(tools: List[BaseTool]) -> str: +def render_text_description(tools: list[BaseTool]) -> str: """Render the tool name and description in plain text. Args: @@ -36,7 +36,7 @@ def render_text_description(tools: List[BaseTool]) -> str: return "\n".join(descriptions) -def render_text_description_and_args(tools: List[BaseTool]) -> str: +def render_text_description_and_args(tools: list[BaseTool]) -> str: """Render the tool name, description, and args in plain text. Args: diff --git a/libs/core/langchain_core/tools/simple.py b/libs/core/langchain_core/tools/simple.py index 94be27bb895..cace7e82ec7 100644 --- a/libs/core/langchain_core/tools/simple.py +++ b/libs/core/langchain_core/tools/simple.py @@ -5,10 +5,7 @@ from typing import ( Any, Awaitable, Callable, - Dict, Optional, - Tuple, - Type, Union, ) @@ -40,7 +37,7 @@ class Tool(BaseTool): async def ainvoke( self, - input: Union[str, Dict, ToolCall], + input: Union[str, dict, ToolCall], config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: @@ -65,7 +62,7 @@ class Tool(BaseTool): # assume it takes a single string input. return {"tool_input": {"type": "string"}} - def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + def _to_args_and_kwargs(self, tool_input: Union[str, dict]) -> tuple[tuple, dict]: """Convert tool input to pydantic model.""" args, kwargs = super()._to_args_and_kwargs(tool_input) # For backwards compatibility. The tool must be run with a single input @@ -131,7 +128,7 @@ class Tool(BaseTool): name: str, # We keep these required to support backwards compatibility description: str, return_direct: bool = False, - args_schema: Optional[Type[BaseModel]] = None, + args_schema: Optional[type[BaseModel]] = None, coroutine: Optional[ Callable[..., Awaitable[Any]] ] = None, # This is last for compatibility, but should be after func diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index a30defdd77b..fb8c69e92bc 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -6,11 +6,8 @@ from typing import ( Any, Awaitable, Callable, - Dict, - List, Literal, Optional, - Type, Union, ) @@ -50,7 +47,7 @@ class StructuredTool(BaseTool): # TODO: Is this needed? async def ainvoke( self, - input: Union[str, Dict, ToolCall], + input: Union[str, dict, ToolCall], config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: @@ -112,7 +109,7 @@ class StructuredTool(BaseTool): name: Optional[str] = None, description: Optional[str] = None, return_direct: bool = False, - args_schema: Optional[Type[BaseModel]] = None, + args_schema: Optional[type[BaseModel]] = None, infer_schema: bool = True, *, response_format: Literal["content", "content_and_artifact"] = "content", @@ -209,7 +206,7 @@ class StructuredTool(BaseTool): ) -def _filter_schema_args(func: Callable) -> List[str]: +def _filter_schema_args(func: Callable) -> list[str]: filter_args = list(FILTERED_ARGS) if config_param := _get_runnable_config_param(func): filter_args.append(config_param) diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index a0b94afb7b7..11bb8b3f1b8 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -8,8 +8,6 @@ from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, Sequence, Union, @@ -52,13 +50,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], *, run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, ) -> Run: @@ -93,13 +91,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): def on_llm_start( self, - serialized: Dict[str, Any], - prompts: List[str], + serialized: dict[str, Any], + prompts: list[str], *, run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, ) -> Run: @@ -238,13 +236,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): def on_chain_start( self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], + serialized: dict[str, Any], + inputs: dict[str, Any], *, run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, run_type: Optional[str] = None, name: Optional[str] = None, **kwargs: Any, @@ -282,10 +280,10 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): def on_chain_end( self, - outputs: Dict[str, Any], + outputs: dict[str, Any], *, run_id: UUID, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Run: """End a trace for a chain run. @@ -313,7 +311,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): self, error: BaseException, *, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, run_id: UUID, **kwargs: Any, ) -> Run: @@ -340,15 +338,15 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): def on_tool_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], input_str: str, *, run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Run: """Start a trace for a tool run. @@ -429,13 +427,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): def on_retriever_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, ) -> Run: @@ -556,13 +554,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): async def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, ) -> Any: @@ -585,13 +583,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): async def on_llm_start( self, - serialized: Dict[str, Any], - prompts: List[str], + serialized: dict[str, Any], + prompts: list[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: llm_run = self._create_llm_run( @@ -642,7 +640,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: llm_run = self._complete_llm_run( @@ -658,7 +656,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: llm_run = self._errored_llm_run( @@ -670,13 +668,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): async def on_chain_start( self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], + serialized: dict[str, Any], + inputs: dict[str, Any], *, run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, run_type: Optional[str] = None, name: Optional[str] = None, **kwargs: Any, @@ -697,10 +695,10 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): async def on_chain_end( self, - outputs: Dict[str, Any], + outputs: dict[str, Any], *, run_id: UUID, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: chain_run = self._complete_chain_run( @@ -716,7 +714,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): self, error: BaseException, *, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, run_id: UUID, **kwargs: Any, ) -> None: @@ -731,15 +729,15 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): async def on_tool_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], input_str: str, *, run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: tool_run = self._create_tool_run( @@ -776,7 +774,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: tool_run = self._errored_tool_run( @@ -788,13 +786,13 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): async def on_retriever_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, ) -> None: @@ -819,7 +817,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: retrieval_run = self._errored_retrieval_run( @@ -839,7 +837,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: retrieval_run = self._complete_retrieval_run( diff --git a/libs/core/langchain_core/tracers/context.py b/libs/core/langchain_core/tracers/context.py index 01659fe42aa..c252dad050c 100644 --- a/libs/core/langchain_core/tracers/context.py +++ b/libs/core/langchain_core/tracers/context.py @@ -6,10 +6,7 @@ from typing import ( TYPE_CHECKING, Any, Generator, - List, Optional, - Tuple, - Type, Union, cast, ) @@ -53,7 +50,7 @@ def tracing_v2_enabled( project_name: Optional[str] = None, *, example_id: Optional[Union[str, UUID]] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, client: Optional[LangSmithClient] = None, ) -> Generator[LangChainTracer, None, None]: """Instruct LangChain to log all runs in context to LangSmith. @@ -169,11 +166,11 @@ def _get_tracer_project() -> str: ) -_configure_hooks: List[ - Tuple[ +_configure_hooks: list[ + tuple[ ContextVar[Optional[BaseCallbackHandler]], bool, - Optional[Type[BaseCallbackHandler]], + Optional[type[BaseCallbackHandler]], Optional[str], ] ] = [] @@ -182,7 +179,7 @@ _configure_hooks: List[ def register_configure_hook( context_var: ContextVar[Optional[Any]], inheritable: bool, - handle_class: Optional[Type[BaseCallbackHandler]] = None, + handle_class: Optional[type[BaseCallbackHandler]] = None, env_var: Optional[str] = None, ) -> None: """Register a configure hook. diff --git a/libs/core/langchain_core/tracers/core.py b/libs/core/langchain_core/tracers/core.py index 15e23023bb0..53bc09a3e54 100644 --- a/libs/core/langchain_core/tracers/core.py +++ b/libs/core/langchain_core/tracers/core.py @@ -11,13 +11,9 @@ from typing import ( TYPE_CHECKING, Any, Coroutine, - Dict, - List, Literal, Optional, Sequence, - Set, - Tuple, Union, cast, ) @@ -81,9 +77,9 @@ class _TracerCore(ABC): """ super().__init__(**kwargs) self._schema_format = _schema_format # For internal use only API will change. - self.run_map: Dict[str, Run] = {} + self.run_map: dict[str, Run] = {} """Map of run ID to run. Cleared on run end.""" - self.order_map: Dict[UUID, Tuple[UUID, str]] = {} + self.order_map: dict[UUID, tuple[UUID, str]] = {} """Map of run ID to (trace_id, dotted_order). Cleared when tracer GCed.""" @abstractmethod @@ -137,7 +133,7 @@ class _TracerCore(ABC): self.run_map[str(run.id)] = run def _get_run( - self, run_id: UUID, run_type: Union[str, Set[str], None] = None + self, run_id: UUID, run_type: Union[str, set[str], None] = None ) -> Run: try: run = self.run_map[str(run_id)] @@ -145,7 +141,7 @@ class _TracerCore(ABC): raise TracerException(f"No indexed run ID {run_id}.") from exc if isinstance(run_type, str): - run_types: Union[Set[str], None] = {run_type} + run_types: Union[set[str], None] = {run_type} else: run_types = run_type if run_types is not None and run.run_type not in run_types: @@ -157,12 +153,12 @@ class _TracerCore(ABC): def _create_chat_model_run( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, ) -> Run: @@ -200,12 +196,12 @@ class _TracerCore(ABC): def _create_llm_run( self, - serialized: Dict[str, Any], - prompts: List[str], + serialized: dict[str, Any], + prompts: list[str], run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, ) -> Run: @@ -239,7 +235,7 @@ class _TracerCore(ABC): Append token event to LLM run and return the run. """ llm_run = self._get_run(run_id, run_type={"llm", "chat_model"}) - event_kwargs: Dict[str, Any] = {"token": token} + event_kwargs: dict[str, Any] = {"token": token} if chunk: event_kwargs["chunk"] = chunk llm_run.events.append( @@ -258,7 +254,7 @@ class _TracerCore(ABC): **kwargs: Any, ) -> Run: llm_run = self._get_run(run_id) - retry_d: Dict[str, Any] = { + retry_d: dict[str, Any] = { "slept": retry_state.idle_for, "attempt": retry_state.attempt_number, } @@ -306,12 +302,12 @@ class _TracerCore(ABC): def _create_chain_run( self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], + serialized: dict[str, Any], + inputs: dict[str, Any], run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, run_type: Optional[str] = None, name: Optional[str] = None, **kwargs: Any, @@ -358,9 +354,9 @@ class _TracerCore(ABC): def _complete_chain_run( self, - outputs: Dict[str, Any], + outputs: dict[str, Any], run_id: UUID, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Run: """Update a chain run with outputs and end time.""" @@ -375,7 +371,7 @@ class _TracerCore(ABC): def _errored_chain_run( self, error: BaseException, - inputs: Optional[Dict[str, Any]], + inputs: Optional[dict[str, Any]], run_id: UUID, **kwargs: Any, ) -> Run: @@ -389,14 +385,14 @@ class _TracerCore(ABC): def _create_tool_run( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], input_str: str, run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Run: """Create a tool run.""" @@ -428,7 +424,7 @@ class _TracerCore(ABC): def _complete_tool_run( self, - output: Dict[str, Any], + output: dict[str, Any], run_id: UUID, **kwargs: Any, ) -> Run: @@ -454,12 +450,12 @@ class _TracerCore(ABC): def _create_retrieval_run( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], query: str, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, ) -> Run: diff --git a/libs/core/langchain_core/tracers/evaluation.py b/libs/core/langchain_core/tracers/evaluation.py index ed6a2bfdb40..687e5994ef3 100644 --- a/libs/core/langchain_core/tracers/evaluation.py +++ b/libs/core/langchain_core/tracers/evaluation.py @@ -6,7 +6,7 @@ import logging import threading import weakref from concurrent.futures import Future, ThreadPoolExecutor, wait -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, List, Optional, Sequence, Union, cast from uuid import UUID import langsmith @@ -96,7 +96,7 @@ class EvaluatorCallbackHandler(BaseTracer): self.futures: weakref.WeakSet[Future] = weakref.WeakSet() self.skip_unfinished = skip_unfinished self.project_name = project_name - self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {} + self.logged_eval_results: dict[tuple[str, str], list[EvaluationResult]] = {} self.lock = threading.Lock() global _TRACERS _TRACERS.add(self) @@ -152,7 +152,7 @@ class EvaluatorCallbackHandler(BaseTracer): def _select_eval_results( self, results: Union[EvaluationResult, EvaluationResults], - ) -> List[EvaluationResult]: + ) -> list[EvaluationResult]: if isinstance(results, EvaluationResult): results_ = [results] elif isinstance(results, dict) and "results" in results: @@ -169,10 +169,10 @@ class EvaluatorCallbackHandler(BaseTracer): evaluator_response: Union[EvaluationResult, EvaluationResults], run: Run, source_run_id: Optional[UUID] = None, - ) -> List[EvaluationResult]: + ) -> list[EvaluationResult]: results = self._select_eval_results(evaluator_response) for res in results: - source_info_: Dict[str, Any] = {} + source_info_: dict[str, Any] = {} if res.evaluator_info: source_info_ = {**res.evaluator_info, **source_info_} run_id_ = getattr(res, "target_run_id", None) diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index a34a5e3528f..23739dcb731 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -8,7 +8,6 @@ from typing import ( TYPE_CHECKING, Any, AsyncIterator, - Dict, Iterator, List, Optional, @@ -66,14 +65,14 @@ class RunInfo(TypedDict): """ name: str - tags: List[str] - metadata: Dict[str, Any] + tags: list[str] + metadata: dict[str, Any] run_type: str inputs: NotRequired[Any] parent_run_id: Optional[UUID] -def _assign_name(name: Optional[str], serialized: Optional[Dict[str, Any]]) -> str: +def _assign_name(name: Optional[str], serialized: Optional[dict[str, Any]]) -> str: """Assign a name to a run.""" if name is not None: return name @@ -107,15 +106,15 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand # Map of run ID to run info. # the entry corresponding to a given run id is cleaned # up when each corresponding run ends. - self.run_map: Dict[UUID, RunInfo] = {} + self.run_map: dict[UUID, RunInfo] = {} # The callback event that corresponds to the end of a parent run # may be invoked BEFORE the callback event that corresponds to the end # of a child run, which results in clean up of run_map. # So we keep track of the mapping between children and parent run IDs # in a separate container. This container is GCed when the tracer is GCed. - self.parent_map: Dict[UUID, Optional[UUID]] = {} + self.parent_map: dict[UUID, Optional[UUID]] = {} - self.is_tapped: Dict[UUID, Any] = {} + self.is_tapped: dict[UUID, Any] = {} # Filter which events will be sent over the queue. self.root_event_filter = _RootEventFilter( @@ -132,7 +131,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand self.send_stream = memory_stream.get_send_stream() self.receive_stream = memory_stream.get_receive_stream() - def _get_parent_ids(self, run_id: UUID) -> List[str]: + def _get_parent_ids(self, run_id: UUID) -> list[str]: """Get the parent IDs of a run (non-recursively) cast to strings.""" parent_ids = [] @@ -269,8 +268,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand self, run_id: UUID, *, - tags: Optional[List[str]], - metadata: Optional[Dict[str, Any]], + tags: Optional[list[str]], + metadata: Optional[dict[str, Any]], parent_run_id: Optional[UUID], name_: str, run_type: str, @@ -296,13 +295,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand async def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], *, run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, ) -> None: @@ -337,13 +336,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand async def on_llm_start( self, - serialized: Dict[str, Any], - prompts: List[str], + serialized: dict[str, Any], + prompts: list[str], *, run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, ) -> None: @@ -384,8 +383,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand data: Any, *, run_id: UUID, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: """Generate a custom astream event.""" @@ -456,7 +455,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand run_info = self.run_map.pop(run_id) inputs_ = run_info["inputs"] - generations: Union[List[List[GenerationChunk]], List[List[ChatGenerationChunk]]] + generations: Union[list[list[GenerationChunk]], list[list[ChatGenerationChunk]]] output: Union[dict, BaseMessage] = {} if run_info["run_type"] == "chat_model": @@ -504,13 +503,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand async def on_chain_start( self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], + serialized: dict[str, Any], + inputs: dict[str, Any], *, run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, run_type: Optional[str] = None, name: Optional[str] = None, **kwargs: Any, @@ -552,10 +551,10 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand async def on_chain_end( self, - outputs: Dict[str, Any], + outputs: dict[str, Any], *, run_id: UUID, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: """End a trace for a chain run.""" @@ -586,15 +585,15 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand async def on_tool_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], input_str: str, *, run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, - inputs: Optional[Dict[str, Any]] = None, + inputs: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: """Start a trace for a tool run.""" @@ -653,13 +652,13 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand async def on_retriever_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, ) -> None: diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py index 78b731a8252..15b3b436d89 100644 --- a/libs/core/langchain_core/tracers/langchain.py +++ b/libs/core/langchain_core/tracers/langchain.py @@ -6,7 +6,7 @@ import logging import warnings from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from uuid import UUID from langsmith import Client @@ -91,7 +91,7 @@ class LangChainTracer(BaseTracer): example_id: Optional[Union[UUID, str]] = None, project_name: Optional[str] = None, client: Optional[Client] = None, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: """Initialize the LangChain tracer. @@ -114,13 +114,13 @@ class LangChainTracer(BaseTracer): def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], *, run_id: UUID, - tags: Optional[List[str]] = None, + tags: Optional[list[str]] = None, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any, ) -> Run: @@ -194,7 +194,7 @@ class LangChainTracer(BaseTracer): ) raise ValueError("Failed to get run URL.") - def _get_tags(self, run: Run) -> List[str]: + def _get_tags(self, run: Run) -> list[str]: """Get combined tags for a run.""" tags = set(run.tags or []) tags.update(self.tags or []) diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index ac154df6225..a48f4233594 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -7,9 +7,7 @@ from collections import defaultdict from typing import ( Any, AsyncIterator, - Dict, Iterator, - List, Literal, Optional, Sequence, @@ -42,16 +40,16 @@ class LogEntry(TypedDict): """Name of the object being run.""" type: str """Type of the object being run, eg. prompt, chain, llm, etc.""" - tags: List[str] + tags: list[str] """List of tags for the run.""" - metadata: Dict[str, Any] + metadata: dict[str, Any] """Key-value pairs of metadata for the run.""" start_time: str """ISO-8601 timestamp of when the run started.""" - streamed_output_str: List[str] + streamed_output_str: list[str] """List of LLM tokens streamed by this run, if applicable.""" - streamed_output: List[Any] + streamed_output: list[Any] """List of output chunks streamed by this run, if available.""" inputs: NotRequired[Optional[Any]] """Inputs to this run. Not available currently via astream_log.""" @@ -69,7 +67,7 @@ class RunState(TypedDict): id: str """ID of the run.""" - streamed_output: List[Any] + streamed_output: list[Any] """List of output chunks streamed by Runnable.stream()""" final_output: Optional[Any] """Final output of the run, usually the result of aggregating (`+`) streamed_output. @@ -83,7 +81,7 @@ class RunState(TypedDict): # Do we want tags/metadata on the root run? Client kinda knows it in most situations # tags: List[str] - logs: Dict[str, LogEntry] + logs: dict[str, LogEntry] """Map of run names to sub-runs. If filters were supplied, this list will contain only the runs that matched the filters.""" @@ -91,14 +89,14 @@ class RunState(TypedDict): class RunLogPatch: """Patch to the run log.""" - ops: List[Dict[str, Any]] + ops: list[dict[str, Any]] """List of jsonpatch operations, which describe how to create the run state from an empty dict. This is the minimal representation of the log, designed to be serialized as JSON and sent over the wire to reconstruct the log on the other side. Reconstruction of the state can be done with any jsonpatch-compliant library, see https://jsonpatch.com for more information.""" - def __init__(self, *ops: Dict[str, Any]) -> None: + def __init__(self, *ops: dict[str, Any]) -> None: self.ops = list(ops) def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog: @@ -127,7 +125,7 @@ class RunLog(RunLogPatch): state: RunState """Current state of the log, obtained from applying all ops in sequence.""" - def __init__(self, *ops: Dict[str, Any], state: RunState) -> None: + def __init__(self, *ops: dict[str, Any], state: RunState) -> None: super().__init__(*ops) self.state = state @@ -219,14 +217,14 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): self.lock = threading.Lock() self.send_stream = memory_stream.get_send_stream() self.receive_stream = memory_stream.get_receive_stream() - self._key_map_by_run_id: Dict[UUID, str] = {} - self._counter_map_by_name: Dict[str, int] = defaultdict(int) + self._key_map_by_run_id: dict[UUID, str] = {} + self._counter_map_by_name: dict[str, int] = defaultdict(int) self.root_id: Optional[UUID] = None def __aiter__(self) -> AsyncIterator[RunLogPatch]: return self.receive_stream.__aiter__() - def send(self, *ops: Dict[str, Any]) -> bool: + def send(self, *ops: dict[str, Any]) -> bool: """Send a patch to the stream, return False if the stream is closed. Args: @@ -477,7 +475,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): def _get_standardized_inputs( run: Run, schema_format: Literal["original", "streaming_events"] -) -> Optional[Dict[str, Any]]: +) -> Optional[dict[str, Any]]: """Extract standardized inputs from a run. Standardizes the inputs based on the type of the runnable used. @@ -631,7 +629,7 @@ async def _astream_log_implementation( except TypeError: prev_final_output = None final_output = chunk - patches: List[Dict[str, Any]] = [] + patches: list[dict[str, Any]] = [] if with_streamed_output_list: patches.append( { diff --git a/libs/core/langchain_core/tracers/schemas.py b/libs/core/langchain_core/tracers/schemas.py index 8e3692d7ac1..72ae7d14b17 100644 --- a/libs/core/langchain_core/tracers/schemas.py +++ b/libs/core/langchain_core/tracers/schemas.py @@ -4,7 +4,7 @@ from __future__ import annotations import datetime import warnings -from typing import Any, Dict, List, Optional, Type +from typing import Any, Optional from uuid import UUID from langsmith.schemas import RunBase as BaseRunV2 @@ -18,7 +18,7 @@ from langchain_core._api import deprecated @deprecated("0.1.0", alternative="Use string instead.", removal="1.0") -def RunTypeEnum() -> Type[RunTypeEnumDep]: +def RunTypeEnum() -> type[RunTypeEnumDep]: """RunTypeEnum.""" warnings.warn( "RunTypeEnum is deprecated. Please directly use a string instead" @@ -35,7 +35,7 @@ class TracerSessionV1Base(BaseModelV1): start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow) name: Optional[str] = None - extra: Optional[Dict[str, Any]] = None + extra: Optional[dict[str, Any]] = None @deprecated("0.1.0", removal="1.0") @@ -72,10 +72,10 @@ class BaseRun(BaseModelV1): parent_uuid: Optional[str] = None start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow) end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow) - extra: Optional[Dict[str, Any]] = None + extra: Optional[dict[str, Any]] = None execution_order: int child_execution_order: int - serialized: Dict[str, Any] + serialized: dict[str, Any] session_id: int error: Optional[str] = None @@ -84,7 +84,7 @@ class BaseRun(BaseModelV1): class LLMRun(BaseRun): """Class for LLMRun.""" - prompts: List[str] + prompts: list[str] # Temporarily, remove but we will completely remove LLMRun # response: Optional[LLMResult] = None @@ -93,11 +93,11 @@ class LLMRun(BaseRun): class ChainRun(BaseRun): """Class for ChainRun.""" - inputs: Dict[str, Any] - outputs: Optional[Dict[str, Any]] = None - child_llm_runs: List[LLMRun] = FieldV1(default_factory=list) - child_chain_runs: List[ChainRun] = FieldV1(default_factory=list) - child_tool_runs: List[ToolRun] = FieldV1(default_factory=list) + inputs: dict[str, Any] + outputs: Optional[dict[str, Any]] = None + child_llm_runs: list[LLMRun] = FieldV1(default_factory=list) + child_chain_runs: list[ChainRun] = FieldV1(default_factory=list) + child_tool_runs: list[ToolRun] = FieldV1(default_factory=list) @deprecated("0.1.0", alternative="Run", removal="1.0") @@ -107,9 +107,9 @@ class ToolRun(BaseRun): tool_input: str output: Optional[str] = None action: str - child_llm_runs: List[LLMRun] = FieldV1(default_factory=list) - child_chain_runs: List[ChainRun] = FieldV1(default_factory=list) - child_tool_runs: List[ToolRun] = FieldV1(default_factory=list) + child_llm_runs: list[LLMRun] = FieldV1(default_factory=list) + child_chain_runs: list[ChainRun] = FieldV1(default_factory=list) + child_tool_runs: list[ToolRun] = FieldV1(default_factory=list) # Begin V2 API Schemas @@ -126,9 +126,9 @@ class Run(BaseRunV2): dotted_order: The dotted order. """ - child_runs: List[Run] = FieldV1(default_factory=list) - tags: Optional[List[str]] = FieldV1(default_factory=list) - events: List[Dict[str, Any]] = FieldV1(default_factory=list) + child_runs: list[Run] = FieldV1(default_factory=list) + tags: Optional[list[str]] = FieldV1(default_factory=list) + events: list[dict[str, Any]] = FieldV1(default_factory=list) trace_id: Optional[UUID] = None dotted_order: Optional[str] = None diff --git a/libs/core/langchain_core/utils/_merge.py b/libs/core/langchain_core/utils/_merge.py index d058b8041aa..36e823c9144 100644 --- a/libs/core/langchain_core/utils/_merge.py +++ b/libs/core/langchain_core/utils/_merge.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Any, Optional -def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]: +def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]: """Merge many dicts, handling specific scenarios where a key exists in both dictionaries but has a value of None in 'left'. In such cases, the method uses the value from 'right' for that key in the merged dictionary. @@ -69,7 +69,7 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any] return merged -def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]: +def merge_lists(left: Optional[list], *others: Optional[list]) -> Optional[list]: """Add many lists, handling None. Args: diff --git a/libs/core/langchain_core/utils/env.py b/libs/core/langchain_core/utils/env.py index 6c5cff88819..ee03ce9d3a9 100644 --- a/libs/core/langchain_core/utils/env.py +++ b/libs/core/langchain_core/utils/env.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union def env_var_is_set(env_var: str) -> bool: @@ -22,8 +22,8 @@ def env_var_is_set(env_var: str) -> bool: def get_from_dict_or_env( - data: Dict[str, Any], - key: Union[str, List[str]], + data: dict[str, Any], + key: Union[str, list[str]], env_key: str, default: Optional[str] = None, ) -> str: diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 005ec23a555..234281966ce 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -85,7 +85,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict: removal="1.0", ) def convert_pydantic_to_openai_function( - model: Type, + model: type, *, name: Optional[str] = None, description: Optional[str] = None, @@ -106,8 +106,10 @@ def convert_pydantic_to_openai_function( """ if hasattr(model, "model_json_schema"): schema = model.model_json_schema() # Pydantic 2 - else: + elif hasattr(model, "schema"): schema = model.schema() # Pydantic 1 + else: + raise TypeError("Model must be a Pydantic model.") schema = dereference_refs(schema) if "definitions" in schema: # pydantic 1 schema.pop("definitions", None) @@ -128,7 +130,7 @@ def convert_pydantic_to_openai_function( removal="1.0", ) def convert_pydantic_to_openai_tool( - model: Type[BaseModel], + model: type[BaseModel], *, name: Optional[str] = None, description: Optional[str] = None, @@ -194,8 +196,8 @@ def convert_python_function_to_openai_function( ) -def _convert_typed_dict_to_openai_function(typed_dict: Type) -> FunctionDescription: - visited: Dict = {} +def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescription: + visited: dict = {} from pydantic.v1 import BaseModel model = cast( @@ -209,11 +211,11 @@ _MAX_TYPED_DICT_RECURSION = 25 def _convert_any_typed_dicts_to_pydantic( - type_: Type, + type_: type, *, - visited: Dict, + visited: dict, depth: int = 0, -) -> Type: +) -> type: from pydantic.v1 import Field as Field_v1 from pydantic.v1 import create_model as create_model_v1 @@ -267,9 +269,9 @@ def _convert_any_typed_dicts_to_pydantic( subscriptable_origin = _py_38_safe_origin(origin) type_args = tuple( _convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited) - for arg in type_args + for arg in type_args # type: ignore[index] ) - return subscriptable_origin[type_args] + return subscriptable_origin[type_args] # type: ignore[index] else: return type_ @@ -333,10 +335,10 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription: def convert_to_openai_function( - function: Union[Dict[str, Any], Type, Callable, BaseTool], + function: Union[dict[str, Any], type, Callable, BaseTool], *, strict: Optional[bool] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Convert a raw function/class to an OpenAI function. .. versionchanged:: 0.2.29 @@ -411,10 +413,10 @@ def convert_to_openai_function( def convert_to_openai_tool( - tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], + tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool], *, strict: Optional[bool] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Convert a raw function/class to an OpenAI tool. .. versionchanged:: 0.2.29 @@ -441,13 +443,13 @@ def convert_to_openai_tool( if isinstance(tool, dict) and tool.get("type") == "function" and "function" in tool: return tool oai_function = convert_to_openai_function(tool, strict=strict) - oai_tool: Dict[str, Any] = {"type": "function", "function": oai_function} + oai_tool: dict[str, Any] = {"type": "function", "function": oai_function} return oai_tool def tool_example_to_messages( - input: str, tool_calls: List[BaseModel], tool_outputs: Optional[List[str]] = None -) -> List[BaseMessage]: + input: str, tool_calls: list[BaseModel], tool_outputs: Optional[list[str]] = None +) -> list[BaseMessage]: """Convert an example into a list of messages that can be fed into an LLM. This code is an adapter that converts a single example to a list of messages @@ -511,7 +513,7 @@ def tool_example_to_messages( tool_example_to_messages(txt, [tool_call]) ) """ - messages: List[BaseMessage] = [HumanMessage(content=input)] + messages: list[BaseMessage] = [HumanMessage(content=input)] openai_tool_calls = [] for tool_call in tool_calls: openai_tool_calls.append( @@ -540,10 +542,10 @@ def tool_example_to_messages( def _parse_google_docstring( docstring: Optional[str], - args: List[str], + args: list[str], *, error_on_invalid_docstring: bool = False, -) -> Tuple[str, dict]: +) -> tuple[str, dict]: """Parse the function and argument descriptions from the docstring of a function. Assumes the function docstring follows Google Python style guide. @@ -590,12 +592,12 @@ def _parse_google_docstring( return description, arg_descriptions -def _py_38_safe_origin(origin: Type) -> Type: - origin_union_type_map: Dict[Type, Any] = ( +def _py_38_safe_origin(origin: type) -> type: + origin_union_type_map: dict[type, Any] = ( {types.UnionType: Union} if hasattr(types, "UnionType") else {} ) - origin_map: Dict[Type, Any] = { + origin_map: dict[type, Any] = { dict: Dict, list: List, tuple: Tuple, @@ -610,8 +612,8 @@ def _py_38_safe_origin(origin: Type) -> Type: def _recursive_set_additional_properties_false( - schema: Dict[str, Any], -) -> Dict[str, Any]: + schema: dict[str, Any], +) -> dict[str, Any]: if isinstance(schema, dict): # Check if 'required' is a key at the current level or if the schema is empty, # in which case additionalProperties still needs to be specified. diff --git a/libs/core/langchain_core/utils/json.py b/libs/core/langchain_core/utils/json.py index a45d59cd232..c721f5fcae3 100644 --- a/libs/core/langchain_core/utils/json.py +++ b/libs/core/langchain_core/utils/json.py @@ -2,7 +2,7 @@ from __future__ import annotations import json import re -from typing import Any, Callable, List +from typing import Any, Callable from langchain_core.exceptions import OutputParserException @@ -163,7 +163,7 @@ def _parse_json( return parser(json_str) -def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict: +def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: """ Parse a JSON string from a Markdown string and check that it contains the expected keys. diff --git a/libs/core/langchain_core/utils/json_schema.py b/libs/core/langchain_core/utils/json_schema.py index be5c8e25ba8..172a741be66 100644 --- a/libs/core/langchain_core/utils/json_schema.py +++ b/libs/core/langchain_core/utils/json_schema.py @@ -1,7 +1,7 @@ from __future__ import annotations from copy import deepcopy -from typing import Any, Dict, List, Optional, Sequence, Set +from typing import Any, Optional, Sequence def _retrieve_ref(path: str, schema: dict) -> dict: @@ -24,9 +24,9 @@ def _retrieve_ref(path: str, schema: dict) -> dict: def _dereference_refs_helper( obj: Any, - full_schema: Dict[str, Any], + full_schema: dict[str, Any], skip_keys: Sequence[str], - processed_refs: Optional[Set[str]] = None, + processed_refs: Optional[set[str]] = None, ) -> Any: if processed_refs is None: processed_refs = set() @@ -63,8 +63,8 @@ def _dereference_refs_helper( def _infer_skip_keys( - obj: Any, full_schema: dict, processed_refs: Optional[Set[str]] = None -) -> List[str]: + obj: Any, full_schema: dict, processed_refs: Optional[set[str]] = None +) -> list[str]: if processed_refs is None: processed_refs = set() diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index f8b8a93cb12..3eb742529ef 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -16,7 +16,6 @@ from typing import ( Mapping, Optional, Sequence, - Tuple, Union, cast, ) @@ -45,7 +44,7 @@ class ChevronError(SyntaxError): # -def grab_literal(template: str, l_del: str) -> Tuple[str, str]: +def grab_literal(template: str, l_del: str) -> tuple[str, str]: """Parse a literal from the template. Args: @@ -124,7 +123,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: return False -def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]: +def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], str]: """Parse a tag from a template. Args: @@ -201,7 +200,7 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s def tokenize( template: str, def_ldel: str = "{{", def_rdel: str = "}}" -) -> Iterator[Tuple[str, str]]: +) -> Iterator[tuple[str, str]]: """Tokenize a mustache template. Tokenizes a mustache template in a generator fashion, @@ -427,13 +426,13 @@ def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str: # # The main rendering function # -g_token_cache: Dict[str, List[Tuple[str, str]]] = {} +g_token_cache: dict[str, list[tuple[str, str]]] = {} EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({}) def render( - template: Union[str, List[Tuple[str, str]]] = "", + template: Union[str, list[tuple[str, str]]] = "", data: Mapping[str, Any] = EMPTY_DICT, partials_dict: Mapping[str, str] = EMPTY_DICT, padding: str = "", @@ -490,7 +489,7 @@ def render( if isinstance(template, Sequence) and not isinstance(template, str): # Then we don't need to tokenize it # But it does need to be a generator - tokens: Iterator[Tuple[str, str]] = (token for token in template) + tokens: Iterator[tuple[str, str]] = (token for token in template) else: if template in g_token_cache: tokens = (token for token in g_token_cache[template]) @@ -561,7 +560,7 @@ def render( if callable(scope): # Generate template text from tags text = "" - tags: List[Tuple[str, str]] = [] + tags: list[tuple[str, str]] = [] for token in tokens: if token == ("end", key): break diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index 74602b383d4..ac44ac1ebcc 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -11,7 +11,6 @@ from typing import ( Any, Callable, Dict, - List, Optional, Type, TypeVar, @@ -71,7 +70,7 @@ else: TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel) -def is_pydantic_v1_subclass(cls: Type) -> bool: +def is_pydantic_v1_subclass(cls: type) -> bool: """Check if the installed Pydantic version is 1.x-like.""" if PYDANTIC_MAJOR_VERSION == 1: return True @@ -83,14 +82,14 @@ def is_pydantic_v1_subclass(cls: Type) -> bool: return False -def is_pydantic_v2_subclass(cls: Type) -> bool: +def is_pydantic_v2_subclass(cls: type) -> bool: """Check if the installed Pydantic version is 1.x-like.""" from pydantic import BaseModel return PYDANTIC_MAJOR_VERSION == 2 and issubclass(cls, BaseModel) -def is_basemodel_subclass(cls: Type) -> bool: +def is_basemodel_subclass(cls: type) -> bool: """Check if the given class is a subclass of Pydantic BaseModel. Check if the given class is a subclass of any of the following: @@ -166,7 +165,7 @@ def pre_init(func: Callable) -> Any: @root_validator(pre=True) @wraps(func) - def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]: + def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]: """Decorator to run a function before model initialization. Args: @@ -218,12 +217,12 @@ class _IgnoreUnserializable(GenerateJsonSchema): def _create_subset_model_v1( name: str, - model: Type[BaseModel], + model: type[BaseModel], field_names: list, *, descriptions: Optional[dict] = None, fn_description: Optional[str] = None, -) -> Type[BaseModel]: +) -> type[BaseModel]: """Create a pydantic model with only a subset of model's fields.""" if PYDANTIC_MAJOR_VERSION == 1: from pydantic import create_model @@ -256,12 +255,12 @@ def _create_subset_model_v1( def _create_subset_model_v2( name: str, - model: Type[pydantic.BaseModel], - field_names: List[str], + model: type[pydantic.BaseModel], + field_names: list[str], *, descriptions: Optional[dict] = None, fn_description: Optional[str] = None, -) -> Type[pydantic.BaseModel]: +) -> type[pydantic.BaseModel]: """Create a pydantic model with a subset of the model fields.""" from pydantic import create_model from pydantic.fields import FieldInfo @@ -299,11 +298,11 @@ def _create_subset_model_v2( def _create_subset_model( name: str, model: TypeBaseModel, - field_names: List[str], + field_names: list[str], *, descriptions: Optional[dict] = None, fn_description: Optional[str] = None, -) -> Type[BaseModel]: +) -> type[BaseModel]: """Create subset model using the same pydantic version as the input model.""" if PYDANTIC_MAJOR_VERSION == 1: return _create_subset_model_v1( @@ -344,25 +343,25 @@ if PYDANTIC_MAJOR_VERSION == 2: from pydantic.v1 import BaseModel as BaseModelV1 @overload - def get_fields(model: Type[BaseModelV2]) -> Dict[str, FieldInfoV2]: ... + def get_fields(model: type[BaseModelV2]) -> dict[str, FieldInfoV2]: ... @overload - def get_fields(model: BaseModelV2) -> Dict[str, FieldInfoV2]: ... + def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ... @overload - def get_fields(model: Type[BaseModelV1]) -> Dict[str, FieldInfoV1]: ... + def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ... @overload - def get_fields(model: BaseModelV1) -> Dict[str, FieldInfoV1]: ... + def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ... def get_fields( model: Union[ BaseModelV2, BaseModelV1, - Type[BaseModelV2], - Type[BaseModelV1], + type[BaseModelV2], + type[BaseModelV1], ], - ) -> Union[Dict[str, FieldInfoV2], Dict[str, FieldInfoV1]]: + ) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]: """Get the field names of a Pydantic model.""" if hasattr(model, "model_fields"): return model.model_fields # type: ignore @@ -375,8 +374,8 @@ elif PYDANTIC_MAJOR_VERSION == 1: from pydantic import BaseModel as BaseModelV1_ def get_fields( # type: ignore[no-redef] - model: Union[Type[BaseModelV1_], BaseModelV1_], - ) -> Dict[str, FieldInfoV1]: + model: Union[type[BaseModelV1_], BaseModelV1_], + ) -> dict[str, FieldInfoV1]: """Get the field names of a Pydantic model.""" return model.__fields__ # type: ignore else: @@ -394,14 +393,14 @@ def _create_root_model( type_: Any, module_name: Optional[str] = None, default_: object = NO_DEFAULT, -) -> Type[BaseModel]: +) -> type[BaseModel]: """Create a base class.""" def schema( - cls: Type[BaseModel], + cls: type[BaseModel], by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: # Complains about schema not being defined in superclass schema_ = super(cls, cls).schema( # type: ignore[misc] by_alias=by_alias, ref_template=ref_template @@ -410,12 +409,12 @@ def _create_root_model( return schema_ def model_json_schema( - cls: Type[BaseModel], + cls: type[BaseModel], by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE, schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema, mode: JsonSchemaMode = "validation", - ) -> Dict[str, Any]: + ) -> dict[str, Any]: # Complains about model_json_schema not being defined in superclass schema_ = super(cls, cls).model_json_schema( # type: ignore[misc] by_alias=by_alias, @@ -452,7 +451,7 @@ def _create_root_model_cached( *, module_name: Optional[str] = None, default_: object = NO_DEFAULT, -) -> Type[BaseModel]: +) -> type[BaseModel]: return _create_root_model( model_name, type_, default_=default_, module_name=module_name ) @@ -462,7 +461,7 @@ def _create_root_model_cached( def _create_model_cached( __model_name: str, **field_definitions: Any, -) -> Type[BaseModel]: +) -> type[BaseModel]: return _create_model_base( __model_name, __config__=_SchemaConfig, @@ -474,7 +473,7 @@ def create_model( __model_name: str, __module_name: Optional[str] = None, **field_definitions: Any, -) -> Type[BaseModel]: +) -> type[BaseModel]: """Create a pydantic model with the given field definitions. Please use create_model_v2 instead of this function. @@ -513,7 +512,7 @@ def create_model( _RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")} -def _remap_field_definitions(field_definitions: Dict[str, Any]) -> Dict[str, Any]: +def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]: """This remaps fields to avoid colliding with internal pydantic fields.""" from pydantic import Field from pydantic.fields import FieldInfo @@ -547,9 +546,9 @@ def create_model_v2( model_name: str, *, module_name: Optional[str] = None, - field_definitions: Optional[Dict[str, Any]] = None, + field_definitions: Optional[dict[str, Any]] = None, root: Optional[Any] = None, -) -> Type[BaseModel]: +) -> type[BaseModel]: """Create a pydantic model with the given field definitions. Attention: diff --git a/libs/core/langchain_core/vectorstores/base.py b/libs/core/langchain_core/vectorstores/base.py index 737b5c32d7d..5fca50cfb76 100644 --- a/libs/core/langchain_core/vectorstores/base.py +++ b/libs/core/langchain_core/vectorstores/base.py @@ -32,13 +32,9 @@ from typing import ( Callable, ClassVar, Collection, - Dict, Iterable, - List, Optional, Sequence, - Tuple, - Type, TypeVar, ) @@ -66,13 +62,13 @@ class VectorStore(ABC): def add_texts( self, texts: Iterable[str], - metadatas: Optional[List[dict]] = None, + metadatas: Optional[list[dict]] = None, # One of the kwargs should be `ids` which is a list of ids # associated with the texts. # This is not yet enforced in the type signature for backwards compatibility # with existing implementations. **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Run more texts through the embeddings and add to the vectorstore. Args: @@ -124,7 +120,7 @@ class VectorStore(ABC): ) return None - def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: + def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]: """Delete by vector ID or other criteria. Args: @@ -138,7 +134,7 @@ class VectorStore(ABC): raise NotImplementedError("delete method must be implemented by subclass.") - def get_by_ids(self, ids: Sequence[str], /) -> List[Document]: + def get_by_ids(self, ids: Sequence[str], /) -> list[Document]: """Get documents by their IDs. The returned documents are expected to have the ID field set to the ID of the @@ -167,7 +163,7 @@ class VectorStore(ABC): ) # Implementations should override this method to provide an async native version. - async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]: + async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]: """Async get documents by their IDs. The returned documents are expected to have the ID field set to the ID of the @@ -194,7 +190,7 @@ class VectorStore(ABC): return await run_in_executor(None, self.get_by_ids, ids) async def adelete( - self, ids: Optional[List[str]] = None, **kwargs: Any + self, ids: Optional[list[str]] = None, **kwargs: Any ) -> Optional[bool]: """Async delete by vector ID or other criteria. @@ -211,9 +207,9 @@ class VectorStore(ABC): async def aadd_texts( self, texts: Iterable[str], - metadatas: Optional[List[dict]] = None, + metadatas: Optional[list[dict]] = None, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Async run more texts through the embeddings and add to the vectorstore. Args: @@ -254,7 +250,7 @@ class VectorStore(ABC): return await self.aadd_documents(docs, **kwargs) return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs) - def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: + def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]: """Add or update documents in the vectorstore. Args: @@ -287,8 +283,8 @@ class VectorStore(ABC): ) async def aadd_documents( - self, documents: List[Document], **kwargs: Any - ) -> List[str]: + self, documents: list[Document], **kwargs: Any + ) -> list[str]: """Async run more documents through the embeddings and add to the vectorstore. @@ -318,7 +314,7 @@ class VectorStore(ABC): return await run_in_executor(None, self.add_documents, documents, **kwargs) - def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]: + def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]: """Return docs most similar to query using a specified search type. Args: @@ -352,7 +348,7 @@ class VectorStore(ABC): async def asearch( self, query: str, search_type: str, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: """Async return docs most similar to query using a specified search type. Args: @@ -386,7 +382,7 @@ class VectorStore(ABC): @abstractmethod def similarity_search( self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: """Return docs most similar to query. Args: @@ -442,7 +438,7 @@ class VectorStore(ABC): def similarity_search_with_score( self, *args: Any, **kwargs: Any - ) -> List[Tuple[Document, float]]: + ) -> list[tuple[Document, float]]: """Run similarity search with distance. Args: @@ -456,7 +452,7 @@ class VectorStore(ABC): async def asimilarity_search_with_score( self, *args: Any, **kwargs: Any - ) -> List[Tuple[Document, float]]: + ) -> list[tuple[Document, float]]: """Async run similarity search with distance. Args: @@ -479,7 +475,7 @@ class VectorStore(ABC): query: str, k: int = 4, **kwargs: Any, - ) -> List[Tuple[Document, float]]: + ) -> list[tuple[Document, float]]: """ Default similarity search with relevance scores. Modify if necessary in subclass. @@ -506,7 +502,7 @@ class VectorStore(ABC): query: str, k: int = 4, **kwargs: Any, - ) -> List[Tuple[Document, float]]: + ) -> list[tuple[Document, float]]: """ Default similarity search with relevance scores. Modify if necessary in subclass. @@ -533,7 +529,7 @@ class VectorStore(ABC): query: str, k: int = 4, **kwargs: Any, - ) -> List[Tuple[Document, float]]: + ) -> list[tuple[Document, float]]: """Return docs and relevance scores in the range [0, 1]. 0 is dissimilar, 1 is most similar. @@ -581,7 +577,7 @@ class VectorStore(ABC): query: str, k: int = 4, **kwargs: Any, - ) -> List[Tuple[Document, float]]: + ) -> list[tuple[Document, float]]: """Async return docs and relevance scores in the range [0, 1]. 0 is dissimilar, 1 is most similar. @@ -626,7 +622,7 @@ class VectorStore(ABC): async def asimilarity_search( self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: """Async return docs most similar to query. Args: @@ -644,8 +640,8 @@ class VectorStore(ABC): return await run_in_executor(None, self.similarity_search, query, k=k, **kwargs) def similarity_search_by_vector( - self, embedding: List[float], k: int = 4, **kwargs: Any - ) -> List[Document]: + self, embedding: list[float], k: int = 4, **kwargs: Any + ) -> list[Document]: """Return docs most similar to embedding vector. Args: @@ -659,8 +655,8 @@ class VectorStore(ABC): raise NotImplementedError async def asimilarity_search_by_vector( - self, embedding: List[float], k: int = 4, **kwargs: Any - ) -> List[Document]: + self, embedding: list[float], k: int = 4, **kwargs: Any + ) -> list[Document]: """Async return docs most similar to embedding vector. Args: @@ -686,7 +682,7 @@ class VectorStore(ABC): fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity @@ -715,7 +711,7 @@ class VectorStore(ABC): fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: """Async return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity @@ -750,12 +746,12 @@ class VectorStore(ABC): def max_marginal_relevance_search_by_vector( self, - embedding: List[float], + embedding: list[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity @@ -779,12 +775,12 @@ class VectorStore(ABC): async def amax_marginal_relevance_search_by_vector( self, - embedding: List[float], + embedding: list[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: """Async return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity @@ -816,8 +812,8 @@ class VectorStore(ABC): @classmethod def from_documents( - cls: Type[VST], - documents: List[Document], + cls: type[VST], + documents: list[Document], embedding: Embeddings, **kwargs: Any, ) -> VST: @@ -837,8 +833,8 @@ class VectorStore(ABC): @classmethod async def afrom_documents( - cls: Type[VST], - documents: List[Document], + cls: type[VST], + documents: list[Document], embedding: Embeddings, **kwargs: Any, ) -> VST: @@ -859,10 +855,10 @@ class VectorStore(ABC): @classmethod @abstractmethod def from_texts( - cls: Type[VST], - texts: List[str], + cls: type[VST], + texts: list[str], embedding: Embeddings, - metadatas: Optional[List[dict]] = None, + metadatas: Optional[list[dict]] = None, **kwargs: Any, ) -> VST: """Return VectorStore initialized from texts and embeddings. @@ -880,10 +876,10 @@ class VectorStore(ABC): @classmethod async def afrom_texts( - cls: Type[VST], - texts: List[str], + cls: type[VST], + texts: list[str], embedding: Embeddings, - metadatas: Optional[List[dict]] = None, + metadatas: Optional[list[dict]] = None, **kwargs: Any, ) -> VST: """Async return VectorStore initialized from texts and embeddings. @@ -902,7 +898,7 @@ class VectorStore(ABC): None, cls.from_texts, texts, embedding, metadatas, **kwargs ) - def _get_retriever_tags(self) -> List[str]: + def _get_retriever_tags(self) -> list[str]: """Get tags for retriever.""" tags = [self.__class__.__name__] if self.embeddings: @@ -991,7 +987,7 @@ class VectorStoreRetriever(BaseRetriever): @model_validator(mode="before") @classmethod - def validate_search_type(cls, values: Dict) -> Any: + def validate_search_type(cls, values: dict) -> Any: """Validate search type. Args: @@ -1040,7 +1036,7 @@ class VectorStoreRetriever(BaseRetriever): def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: + ) -> list[Document]: if self.search_type == "similarity": docs = self.vectorstore.similarity_search(query, **self.search_kwargs) elif self.search_type == "similarity_score_threshold": @@ -1060,7 +1056,7 @@ class VectorStoreRetriever(BaseRetriever): async def _aget_relevant_documents( self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: + ) -> list[Document]: if self.search_type == "similarity": docs = await self.vectorstore.asimilarity_search( query, **self.search_kwargs @@ -1080,7 +1076,7 @@ class VectorStoreRetriever(BaseRetriever): raise ValueError(f"search_type of {self.search_type} not allowed.") return docs - def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: + def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]: """Add documents to the vectorstore. Args: @@ -1093,8 +1089,8 @@ class VectorStoreRetriever(BaseRetriever): return self.vectorstore.add_documents(documents, **kwargs) async def aadd_documents( - self, documents: List[Document], **kwargs: Any - ) -> List[str]: + self, documents: list[Document], **kwargs: Any + ) -> list[str]: """Async add documents to the vectorstore. Args: diff --git a/libs/core/langchain_core/vectorstores/in_memory.py b/libs/core/langchain_core/vectorstores/in_memory.py index 147415178e1..d359f619685 100644 --- a/libs/core/langchain_core/vectorstores/in_memory.py +++ b/libs/core/langchain_core/vectorstores/in_memory.py @@ -7,12 +7,9 @@ from typing import ( TYPE_CHECKING, Any, Callable, - Dict, Iterator, - List, Optional, Sequence, - Tuple, ) from langchain_core._api import deprecated @@ -153,7 +150,7 @@ class InMemoryVectorStore(VectorStore): """ # TODO: would be nice to change to # Dict[str, Document] at some point (will be a breaking change) - self.store: Dict[str, Dict[str, Any]] = {} + self.store: dict[str, dict[str, Any]] = {} self.embedding = embedding @property @@ -170,10 +167,10 @@ class InMemoryVectorStore(VectorStore): def add_documents( self, - documents: List[Document], - ids: Optional[List[str]] = None, + documents: list[Document], + ids: Optional[list[str]] = None, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Add documents to the store.""" texts = [doc.page_content for doc in documents] vectors = self.embedding.embed_documents(texts) @@ -204,8 +201,8 @@ class InMemoryVectorStore(VectorStore): return ids_ async def aadd_documents( - self, documents: List[Document], ids: Optional[List[str]] = None, **kwargs: Any - ) -> List[str]: + self, documents: list[Document], ids: Optional[list[str]] = None, **kwargs: Any + ) -> list[str]: """Add documents to the store.""" texts = [doc.page_content for doc in documents] vectors = await self.embedding.aembed_documents(texts) @@ -219,7 +216,7 @@ class InMemoryVectorStore(VectorStore): id_iterator: Iterator[Optional[str]] = ( iter(ids) if ids else iter(doc.id for doc in documents) ) - ids_: List[str] = [] + ids_: list[str] = [] for doc, vector in zip(documents, vectors): doc_id = next(id_iterator) @@ -234,7 +231,7 @@ class InMemoryVectorStore(VectorStore): return ids_ - def get_by_ids(self, ids: Sequence[str], /) -> List[Document]: + def get_by_ids(self, ids: Sequence[str], /) -> list[Document]: """Get documents by their ids. Args: @@ -313,7 +310,7 @@ class InMemoryVectorStore(VectorStore): "failed": [], } - async def aget_by_ids(self, ids: Sequence[str], /) -> List[Document]: + async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]: """Async get documents by their ids. Args: @@ -326,11 +323,11 @@ class InMemoryVectorStore(VectorStore): def _similarity_search_with_score_by_vector( self, - embedding: List[float], + embedding: list[float], k: int = 4, filter: Optional[Callable[[Document], bool]] = None, **kwargs: Any, - ) -> List[Tuple[Document, float, List[float]]]: + ) -> list[tuple[Document, float, list[float]]]: result = [] for doc in self.store.values(): vector = doc["vector"] @@ -351,11 +348,11 @@ class InMemoryVectorStore(VectorStore): def similarity_search_with_score_by_vector( self, - embedding: List[float], + embedding: list[float], k: int = 4, filter: Optional[Callable[[Document], bool]] = None, **kwargs: Any, - ) -> List[Tuple[Document, float]]: + ) -> list[tuple[Document, float]]: return [ (doc, similarity) for doc, similarity, _ in self._similarity_search_with_score_by_vector( @@ -368,7 +365,7 @@ class InMemoryVectorStore(VectorStore): query: str, k: int = 4, **kwargs: Any, - ) -> List[Tuple[Document, float]]: + ) -> list[tuple[Document, float]]: embedding = self.embedding.embed_query(query) docs = self.similarity_search_with_score_by_vector( embedding, @@ -379,7 +376,7 @@ class InMemoryVectorStore(VectorStore): async def asimilarity_search_with_score( self, query: str, k: int = 4, **kwargs: Any - ) -> List[Tuple[Document, float]]: + ) -> list[tuple[Document, float]]: embedding = await self.embedding.aembed_query(query) docs = self.similarity_search_with_score_by_vector( embedding, @@ -390,10 +387,10 @@ class InMemoryVectorStore(VectorStore): def similarity_search_by_vector( self, - embedding: List[float], + embedding: list[float], k: int = 4, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: docs_and_scores = self.similarity_search_with_score_by_vector( embedding, k, @@ -402,18 +399,18 @@ class InMemoryVectorStore(VectorStore): return [doc for doc, _ in docs_and_scores] async def asimilarity_search_by_vector( - self, embedding: List[float], k: int = 4, **kwargs: Any - ) -> List[Document]: + self, embedding: list[float], k: int = 4, **kwargs: Any + ) -> list[Document]: return self.similarity_search_by_vector(embedding, k, **kwargs) def similarity_search( self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: return [doc for doc, _ in self.similarity_search_with_score(query, k, **kwargs)] async def asimilarity_search( self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: return [ doc for doc, _ in await self.asimilarity_search_with_score(query, k, **kwargs) @@ -421,12 +418,12 @@ class InMemoryVectorStore(VectorStore): def max_marginal_relevance_search_by_vector( self, - embedding: List[float], + embedding: list[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: prefetch_hits = self._similarity_search_with_score_by_vector( embedding=embedding, k=fetch_k, @@ -456,7 +453,7 @@ class InMemoryVectorStore(VectorStore): fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: embedding_vector = self.embedding.embed_query(query) return self.max_marginal_relevance_search_by_vector( embedding_vector, @@ -473,7 +470,7 @@ class InMemoryVectorStore(VectorStore): fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, - ) -> List[Document]: + ) -> list[Document]: embedding_vector = await self.embedding.aembed_query(query) return self.max_marginal_relevance_search_by_vector( embedding_vector, @@ -486,9 +483,9 @@ class InMemoryVectorStore(VectorStore): @classmethod def from_texts( cls, - texts: List[str], + texts: list[str], embedding: Embeddings, - metadatas: Optional[List[dict]] = None, + metadatas: Optional[list[dict]] = None, **kwargs: Any, ) -> InMemoryVectorStore: store = cls( @@ -500,9 +497,9 @@ class InMemoryVectorStore(VectorStore): @classmethod async def afrom_texts( cls, - texts: List[str], + texts: list[str], embedding: Embeddings, - metadatas: Optional[List[dict]] = None, + metadatas: Optional[list[dict]] = None, **kwargs: Any, ) -> InMemoryVectorStore: store = cls( diff --git a/libs/core/langchain_core/vectorstores/utils.py b/libs/core/langchain_core/vectorstores/utils.py index 73b9aac9cbe..89fe0149e82 100644 --- a/libs/core/langchain_core/vectorstores/utils.py +++ b/libs/core/langchain_core/vectorstores/utils.py @@ -76,7 +76,7 @@ def maximal_marginal_relevance( embedding_list: list, lambda_mult: float = 0.5, k: int = 4, -) -> List[int]: +) -> list[int]: """Calculate maximal marginal relevance. Args: diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index 5a3065d44a4..79f414b1bcb 100644 --- a/libs/core/poetry.lock +++ b/libs/core/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -1186,7 +1186,7 @@ develop = true [package.dependencies] httpx = "^0.27.0" -langchain-core = ">=0.3.0.dev1" +langchain-core = "^0.3.0" pytest = ">=7,<9" syrupy = "^4" @@ -1196,7 +1196,7 @@ url = "../standard-tests" [[package]] name = "langchain-text-splitters" -version = "0.3.0.dev1" +version = "0.3.0" description = "LangChain text splitting utilities" optional = false python-versions = ">=3.9,<4.0" @@ -1204,7 +1204,7 @@ files = [] develop = true [package.dependencies] -langchain-core = "^0.3.0.dev1" +langchain-core = "^0.3.0" [package.source] type = "directory" diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 89e46905f9b..96abaa5db89 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -41,8 +41,8 @@ python = ">=3.12.4" [tool.poetry.extras] [tool.ruff.lint] -select = [ "B", "E", "F", "I", "T201", "UP",] -ignore = [ "UP006", "UP007",] +select = ["B", "E", "F", "I", "T201", "UP"] +ignore = ["UP007"] [tool.coverage.run] omit = [ "tests/*",] diff --git a/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py b/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py index f1f04c526cc..748f7b25c88 100644 --- a/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py +++ b/libs/core/tests/unit_tests/tracers/test_async_base_tracer.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Any, List +from typing import Any from uuid import uuid4 import pytest @@ -26,7 +26,7 @@ class FakeAsyncTracer(AsyncBaseTracer): def __init__(self) -> None: """Initialize the tracer.""" super().__init__() - self.runs: List[Run] = [] + self.runs: list[Run] = [] async def _persist_run(self, run: Run) -> None: self.runs.append(run) diff --git a/libs/core/tests/unit_tests/tracers/test_base_tracer.py b/libs/core/tests/unit_tests/tracers/test_base_tracer.py index b9ac507eb17..e11b5552a3b 100644 --- a/libs/core/tests/unit_tests/tracers/test_base_tracer.py +++ b/libs/core/tests/unit_tests/tracers/test_base_tracer.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Any, List +from typing import Any from unittest.mock import MagicMock from uuid import uuid4 @@ -30,7 +30,7 @@ class FakeTracer(BaseTracer): def __init__(self) -> None: """Initialize the tracer.""" super().__init__() - self.runs: List[Run] = [] + self.runs: list[Run] = [] def _persist_run(self, run: Run) -> None: """Persist a run.""" diff --git a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py index 971315752b8..91983b15887 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py +++ b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py @@ -7,7 +7,7 @@ the relevant methods. from __future__ import annotations import uuid -from typing import Any, Dict, Iterable, List, Optional, Sequence +from typing import Any, Iterable, Optional, Sequence from langchain_core.documents import Document from langchain_core.embeddings import Embeddings @@ -18,19 +18,19 @@ class CustomAddTextsVectorstore(VectorStore): """A vectorstore that only implements add texts.""" def __init__(self) -> None: - self.store: Dict[str, Document] = {} + self.store: dict[str, Document] = {} def add_texts( self, texts: Iterable[str], - metadatas: Optional[List[dict]] = None, + metadatas: Optional[list[dict]] = None, # One of the kwargs should be `ids` which is a list of ids # associated with the texts. # This is not yet enforced in the type signature for backwards compatibility # with existing implementations. - ids: Optional[List[str]] = None, + ids: Optional[list[str]] = None, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: if not isinstance(texts, list): texts = list(texts) ids_iter = iter(ids or []) @@ -46,14 +46,14 @@ class CustomAddTextsVectorstore(VectorStore): ids_.append(id_) return ids_ - def get_by_ids(self, ids: Sequence[str], /) -> List[Document]: + def get_by_ids(self, ids: Sequence[str], /) -> list[Document]: return [self.store[id] for id in ids if id in self.store] def from_texts( # type: ignore cls, - texts: List[str], + texts: list[str], embedding: Embeddings, - metadatas: Optional[List[dict]] = None, + metadatas: Optional[list[dict]] = None, **kwargs: Any, ) -> CustomAddTextsVectorstore: vectorstore = CustomAddTextsVectorstore() @@ -62,7 +62,7 @@ class CustomAddTextsVectorstore(VectorStore): def similarity_search( self, query: str, k: int = 4, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: raise NotImplementedError()